[spirv-reader] Handle old-style storage buffers

Old way:
  - struct decorated with BufferBlock
  - Uniform storage class

New way
  - struct decorated with Block
  - StorageBuffer storage class

Also fixes the result type for an access chain.

Bug: tint:99
Change-Id: I2324ba94bb19b369d206313de798bdfec6099fe0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24605
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
David Neto 2020-07-20 13:24:26 +00:00 committed by dan sinclair
parent 13904a612d
commit 6982c22eee
9 changed files with 451 additions and 46 deletions

View File

@ -38,7 +38,7 @@ ast::PipelineStage EnumConverter::ToPipelineStage(SpvExecutionModel model) {
return ast::PipelineStage::kNone;
}
ast::StorageClass EnumConverter::ToStorageClass(SpvStorageClass sc) {
ast::StorageClass EnumConverter::ToStorageClass(const SpvStorageClass sc) {
switch (sc) {
case SpvStorageClassInput:
return ast::StorageClass::kInput;

View File

@ -44,7 +44,7 @@ class EnumConverter {
/// On failure, logs an error and returns kNone
/// @param sc the SPIR-V storage class
/// @returns a Tint AST storage class
ast::StorageClass ToStorageClass(SpvStorageClass sc);
ast::StorageClass ToStorageClass(const SpvStorageClass sc);
/// Converts a SPIR-V Builtin value a Tint Builtin.
/// On failure, logs an error and returns kNone

View File

@ -631,7 +631,10 @@ bool FunctionEmitter::EmitBody() {
return false;
}
RegisterValuesNeedingNamedOrHoistedDefinition();
if (!RegisterLocallyDefinedValues()) {
return false;
}
FindValuesNeedingNamedOrHoistedDefinition();
if (!EmitFunctionVariables()) {
return false;
@ -2419,10 +2422,10 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
for (auto id : sorted_by_index(block_info.hoisted_ids)) {
const auto* def_inst = def_use_mgr_->GetDef(id);
assert(def_inst);
AddStatement(
std::make_unique<ast::VariableDeclStatement>(parser_impl_.MakeVariable(
id, ast::StorageClass::kFunction,
parser_impl_.ConvertType(def_inst->type_id()))));
auto* ast_type =
RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
AddStatement(std::make_unique<ast::VariableDeclStatement>(
parser_impl_.MakeVariable(id, ast::StorageClass::kFunction, ast_type)));
// Save this as an already-named value.
identifier_values_.insert(id);
}
@ -2580,12 +2583,14 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
expr.type = expr.type->AsPointer()->type();
return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr));
}
case SpvOpCopyObject:
case SpvOpCopyObject: {
// Arguably, OpCopyObject is purely combinatorial. On the other hand,
// it exists to make a new name for something. So we choose to make
// a new named constant definition.
return EmitConstDefOrWriteToHoistedVar(
inst, MakeExpression(inst.GetSingleWordInOperand(0)));
auto expr = MakeExpression(inst.GetSingleWordInOperand(0));
expr.type = RemapStorageClass(expr.type, result_id);
return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr));
}
case SpvOpPhi: {
// Emit a read from the associated state variable.
auto expr = TypedExpression(
@ -2754,7 +2759,6 @@ TypedExpression FunctionEmitter::MakeAccessChain(
// ever-deeper nested indexing expressions. Start off with an expression
// for the base, and then bury that inside nested indexing expressions.
TypedExpression current_expr(MakeOperand(inst, 0));
const auto constants = constant_mgr_->GetOperandConstants(&inst);
static const char* swizzles[] = {"x", "y", "z", "w"};
@ -2803,9 +2807,10 @@ TypedExpression FunctionEmitter::MakeAccessChain(
// Skip past the member index that gets us to Position.
first_index = first_index + 1;
// Replace the gl_PerVertex reference with the gl_Position reference
ptr_ty_id = builtin_position_info.member_pointer_type_id;
current_expr.expr =
std::make_unique<ast::IdentifierExpression>(namer_.Name(base_id));
ptr_ty_id = builtin_position_info.member_pointer_type_id;
current_expr.type = parser_impl_.ConvertType(ptr_ty_id);
}
}
@ -2815,6 +2820,7 @@ TypedExpression FunctionEmitter::MakeAccessChain(
<< " base pointer is not of pointer type";
return {};
}
SpvStorageClass storage_class = ptr_type->AsPointer()->storage_class();
const auto* pointee_type = ptr_type->AsPointer()->pointee_type();
for (uint32_t index = first_index; index < num_in_operands; ++index) {
const auto* index_const =
@ -2904,9 +2910,13 @@ TypedExpression FunctionEmitter::MakeAccessChain(
<< type_mgr_->GetId(pointee_type) << " " << pointee_type->str();
return {};
}
current_expr.reset(TypedExpression(
parser_impl_.ConvertType(type_mgr_->GetId(pointee_type)),
std::move(next_expr)));
const auto pointee_type_id = type_mgr_->GetId(pointee_type);
const auto pointer_type_id =
type_mgr_->FindPointerToType(pointee_type_id, storage_class);
auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id);
assert(ast_pointer_type);
assert(ast_pointer_type->IsPointer);
current_expr.reset(TypedExpression(ast_pointer_type, std::move(next_expr)));
}
return current_expr;
}
@ -3074,7 +3084,7 @@ TypedExpression FunctionEmitter::MakeVectorShuffle(
result_type, std::move(values))};
}
void FunctionEmitter::RegisterValuesNeedingNamedOrHoistedDefinition() {
bool FunctionEmitter::RegisterLocallyDefinedValues() {
// Create a DefInfo for each value definition in this function.
size_t index = 0;
for (auto block_id : block_order_) {
@ -3087,9 +3097,72 @@ void FunctionEmitter::RegisterValuesNeedingNamedOrHoistedDefinition() {
}
def_info_[result_id] = std::make_unique<DefInfo>(inst, block_pos, index);
index++;
// Determine storage class for pointer values. Do this in order because
// we might rely on the storage class for a previously-visited definition.
// Logical pointers can't be transmitted through OpPhi, so remaining
// pointer definitions are SSA values, and their definitions must be
// visited before their uses.
auto& storage_class = def_info_[result_id]->storage_class;
const auto* type = type_mgr_->GetType(inst.type_id());
if (type && type->AsPointer()) {
const auto* ast_type = parser_impl_.ConvertType(inst.type_id());
if (ast_type && ast_type->AsPointer()) {
storage_class = ast_type->AsPointer()->storage_class();
}
switch (inst.opcode()) {
case SpvOpUndef:
case SpvOpVariable:
// Keep the default decision based on the result type.
break;
case SpvOpAccessChain:
case SpvOpCopyObject:
// Inherit from the first operand. We need this so we can pick up
// a remapped storage buffer.
storage_class =
GetStorageClassForPointerValue(inst.GetSingleWordInOperand(0));
break;
default:
return Fail() << "pointer defined in function from unknown opcode: "
<< inst.PrettyPrint();
}
}
}
}
return true;
}
ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
auto where = def_info_.find(id);
if (where != def_info_.end()) {
return where->second.get()->storage_class;
}
const auto type_id = def_use_mgr_->GetDef(id)->type_id();
if (type_id) {
auto* ast_type = parser_impl_.ConvertType(type_id);
if (ast_type && ast_type->IsPointer()) {
return ast_type->AsPointer()->storage_class();
}
}
return ast::StorageClass::kNone;
}
ast::type::Type* FunctionEmitter::RemapStorageClass(ast::type::Type* type,
uint32_t result_id) {
if (type->IsPointer()) {
// Remap an old-style storage buffer pointer to a new-style storage
// buffer pointer.
const auto* ast_ptr_type = type->AsPointer();
const auto sc = GetStorageClassForPointerValue(result_id);
if (ast_ptr_type->storage_class() != sc) {
return parser_impl_.context().type_mgr().Get(
std::make_unique<ast::type::PointerType>(ast_ptr_type->type(), sc));
}
}
return type;
}
void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
// Mark vector operands of OpVectorShuffle as needing a named definition,
// but only if they are defined in this function as well.
for (auto& id_def_info_pair : def_info_) {

View File

@ -33,6 +33,7 @@
#include "src/ast/expression.h"
#include "src/ast/module.h"
#include "src/ast/statement.h"
#include "src/ast/storage_class.h"
#include "src/reader/spirv/construct.h"
#include "src/reader/spirv/fail_stream.h"
#include "src/reader/spirv/namer.h"
@ -192,7 +193,7 @@ inline std::ostream& operator<<(std::ostream& o, const BlockInfo& bi) {
/// Bookkeeping info for a SPIR-V ID defined in the function.
/// This will be valid for result IDs for:
/// - instructions that are not OpLabel, OpVariable, and OpFunctionParameter
/// - instructions that are not OpLabel, and not OpFunctionParameter
/// - are defined in a basic block visited in the block-order for the function.
struct DefInfo {
/// Constructor.
@ -243,8 +244,15 @@ struct DefInfo {
/// If the definition is an OpPhi, then |phi_var| is the name of the
/// variable that stores the value carried from parent basic blocks into
// the basic block containing the OpPhi. Otherwise this is the empty string.
/// the basic block containing the OpPhi. Otherwise this is the empty string.
std::string phi_var;
/// The storage class to use for this value, if it is of pointer type.
/// This is required to carry a stroage class override from a storage
/// buffer expressed in the old style (with Uniform storage class)
/// that needs to be remapped to StorageBuffer storage class.
/// This is kNone for non-pointers.
ast::StorageClass storage_class = ast::StorageClass::kNone;
};
inline std::ostream& operator<<(std::ostream& o, const DefInfo& di) {
@ -254,8 +262,11 @@ inline std::ostream& operator<<(std::ostream& o, const DefInfo& di) {
<< " last_use_pos: " << di.last_use_pos << " requires_named_const_def: "
<< (di.requires_named_const_def ? "true" : "false")
<< " requires_hoisted_def: " << (di.requires_hoisted_def ? "true" : "false")
<< " phi_var: '" << di.phi_var << "'"
<< "}";
<< " phi_var: '" << di.phi_var << "'";
if (di.storage_class != ast::StorageClass::kNone) {
o << " sc:" << int(di.storage_class);
}
o << "}";
return o;
}
@ -367,7 +378,24 @@ class FunctionEmitter {
/// @returns false if bad nesting has been detected.
bool FindIfSelectionInternalHeaders();
/// Record the SPIR-V IDs of non-constants that should get a 'const'
/// Creates a DefInfo record for each locally defined SPIR-V ID.
/// Populates the |def_info_| mapping with basic results.
/// @returns false on failure
bool RegisterLocallyDefinedValues();
/// Returns the Tint storage class for the given SPIR-V ID that is a
/// pointer value.
/// @returns the storage class
ast::StorageClass GetStorageClassForPointerValue(uint32_t id);
/// Remaps the storage class for the type of a locally-defined value,
/// if necessary. If it's not a pointer type, or if its storage class
/// already matches, then the result is a copy of the |type| argument.
/// @param type the AST type
/// @param result_id the SPIR-V ID for the locally defined value
ast::type::Type* RemapStorageClass(ast::type::Type* type, uint32_t result_id);
/// Marks locally defined values when they should get a 'const'
/// definition in WGSL, or a 'var' definition at an outer scope.
/// This occurs in several cases:
/// - When a SPIR-V instruction might use the dynamically computed value
@ -382,8 +410,8 @@ class FunctionEmitter {
/// - When a definition is in a construct that does not enclose all the
/// uses. In this case the definition's |requires_hoisted_def| property
/// is set to true.
/// Populates the |def_info_| mapping.
void RegisterValuesNeedingNamedOrHoistedDefinition();
/// Updates the |def_info_| mapping.
void FindValuesNeedingNamedOrHoistedDefinition();
/// Emits declarations of function variables.
/// @returns false if emission failed.

View File

@ -287,10 +287,10 @@ TEST_F(SpvParserTest_CompositeExtract, Vector_IndexTooBigError) {
TEST_F(SpvParserTest_CompositeExtract, Matrix) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %m3v2float
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %m3v2float %var
%2 = OpCompositeExtract %v2float %1 2
OpReturn
@ -318,10 +318,10 @@ TEST_F(SpvParserTest_CompositeExtract, Matrix) {
TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %m3v2float
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %m3v2float %var
%2 = OpCompositeExtract %v2float %1 3
OpReturn
@ -338,10 +338,10 @@ TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) {
TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %m3v2float
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %m3v2float %var
%2 = OpCompositeExtract %float %1 2 1
OpReturn
@ -372,10 +372,10 @@ TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) {
TEST_F(SpvParserTest_CompositeExtract, Array) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %a_u_5
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %a_u_5 %var
%2 = OpCompositeExtract %uint %1 3
OpReturn
@ -404,10 +404,10 @@ TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) {
const auto assembly = Preamble() + R"(
%rtarr = OpTypeRuntimeArray %uint
%ptr = OpTypePointer Function %rtarr
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %rtarr %var
%2 = OpCompositeExtract %uint %1 3
OpReturn
@ -423,10 +423,10 @@ TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) {
TEST_F(SpvParserTest_CompositeExtract, Struct) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %s_v2f_u_i
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %s_v2f_u_i %var
%2 = OpCompositeExtract %int %1 2
OpReturn
@ -454,10 +454,10 @@ TEST_F(SpvParserTest_CompositeExtract, Struct) {
TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %s_v2f_u_i
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %s_v2f_u_i %var
%2 = OpCompositeExtract %int %1 40
OpReturn
@ -476,10 +476,10 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) {
%a_mat = OpTypeArray %m3v2float %uint_3
%s = OpTypeStruct %uint %a_mat
%ptr = OpTypePointer Function %s
%var = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%var = OpVariable %ptr Function
%1 = OpLoad %s %var
%2 = OpCompositeExtract %float %1 1 2 0 1
OpReturn
@ -553,10 +553,10 @@ VariableDeclStatement{
TEST_F(SpvParserTest_CopyObject, Pointer) {
const auto assembly = Preamble() + R"(
%ptr = OpTypePointer Function %uint
%10 = OpVariable %ptr Function
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%10 = OpVariable %ptr Function
%1 = OpCopyObject %ptr %10
%2 = OpCopyObject %ptr %1
OpReturn

View File

@ -708,6 +708,245 @@ TEST_F(SpvParserTest, EmitStatement_AccessChain_InvalidPointeeType) {
HasSubstr("Access chain with unknown pointee type %60 void"));
}
std::string OldStorageBufferPreamble() {
return R"(
OpName %myvar "myvar"
OpDecorate %struct BufferBlock
OpMemberDecorate %struct 0 Offset 0
OpMemberDecorate %struct 1 Offset 4
OpDecorate %arr ArrayStride 4
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%arr = OpTypeRuntimeArray %uint
%struct = OpTypeStruct %uint %arr
%ptr_struct = OpTypePointer Uniform %struct
%ptr_uint = OpTypePointer Uniform %uint
%myvar = OpVariable %ptr_struct Uniform
)";
}
TEST_F(SpvParserTest, RemapStorageBuffer_TypesAndVarDeclarations) {
// Enusure we get the right module-scope declaration. This tests translation
// of the structure type, arrays of the structure, pointers to them, and
// OpVariable of these.
const auto assembly = OldStorageBufferPreamble();
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< assembly << p->error();
const auto module_str = p->module().to_str();
EXPECT_THAT(module_str, HasSubstr(R"(
Variable{
myvar
storage_buffer
__alias_S__struct_S
}
RTArr -> __array__u32_stride_4
S -> __struct_S)"));
}
TEST_F(SpvParserTest,
RemapStorageBuffer_ThroughAccessChain_NonCascaded) {
const auto assembly = OldStorageBufferPreamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
; the scalar element
%1 = OpAccessChain %ptr_uint %myvar %uint_0
OpStore %1 %uint_0
; element in the runtime array
%2 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1
OpStore %2 %uint_0
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
MemberAccessor{
Identifier{myvar}
Identifier{field0}
}
ScalarConstructor{0}
}
Assignment{
ArrayAccessor{
MemberAccessor{
Identifier{myvar}
Identifier{field1}
}
ScalarConstructor{1}
}
ScalarConstructor{0}
})")) << ToString(fe.ast_body())
<< p->error();
}
TEST_F(SpvParserTest, RemapStorageBuffer_ThroughAccessChain_Cascaded) {
const auto assembly = OldStorageBufferPreamble() + R"(
%ptr_rtarr = OpTypePointer Uniform %arr
%100 = OpFunction %void None %voidfn
%entry = OpLabel
; get the runtime array
%1 = OpAccessChain %ptr_rtarr %myvar %uint_1
; now an element in it
%2 = OpAccessChain %ptr_uint %1 %uint_1
OpStore %2 %uint_0
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
ArrayAccessor{
MemberAccessor{
Identifier{myvar}
Identifier{field1}
}
ScalarConstructor{1}
}
ScalarConstructor{0}
})")) << ToString(fe.ast_body())
<< p->error();
}
TEST_F(SpvParserTest,
RemapStorageBuffer_ThroughCopyObject_WithoutHoisting) {
// Generates a const declaration directly.
// We have to do a bunch of storage class tracking for locally
// defined values in order to get the right pointer-to-storage-buffer
// value type for the const declration.
const auto assembly = OldStorageBufferPreamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1
%2 = OpCopyObject %ptr_uint %1
OpStore %2 %uint_0
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
Variable{
x_2
none
__ptr_storage_buffer__u32
{
ArrayAccessor{
MemberAccessor{
Identifier{myvar}
Identifier{field1}
}
ScalarConstructor{1}
}
}
}
}
Assignment{
Identifier{x_2}
ScalarConstructor{0}
})")) << ToString(fe.ast_body())
<< p->error();
}
TEST_F(SpvParserTest, RemapStorageBuffer_ThroughCopyObject_WithHoisting) {
// Like the previous test, but the declaration for the copy-object
// has its declaration hoisted.
const auto assembly = OldStorageBufferPreamble() + R"(
%bool = OpTypeBool
%cond = OpConstantTrue %bool
%100 = OpFunction %void None %voidfn
%entry = OpLabel
OpSelectionMerge %99 None
OpBranchConditional %cond %20 %30
%20 = OpLabel
%1 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1
; this definintion dominates the use in %99
%2 = OpCopyObject %ptr_uint %1
OpBranch %99
%30 = OpLabel
OpReturn
%99 = OpLabel
OpStore %2 %uint_0
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{
Variable{
x_2
function
__ptr_storage_buffer__u32
}
}
If{
(
ScalarConstructor{true}
)
{
Assignment{
Identifier{x_2}
ArrayAccessor{
MemberAccessor{
Identifier{myvar}
Identifier{field1}
}
ScalarConstructor{1}
}
}
}
}
Else{
{
Return{}
}
}
Assignment{
Identifier{x_2}
ScalarConstructor{0}
}
Return{}
)")) << ToString(fe.ast_body())
<< p->error();
}
TEST_F(SpvParserTest, DISABLED_RemapStorageBuffer_ThroughFunctionCall) {
// TODO(dneto): Blocked on OpFunctionCall support.
// We might need this for passing pointers into atomic builtins.
}
TEST_F(SpvParserTest, DISABLED_RemapStorageBuffer_ThroughFunctionParameter) {
// TODO(dneto): Blocked on OpFunctionCall support.
}
} // namespace
} // namespace spirv
} // namespace reader

View File

@ -65,6 +65,7 @@
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/ast/variable_decoration.h"
#include "src/reader/spirv/enum_converter.h"
#include "src/reader/spirv/function.h"
#include "src/type_manager.h"
@ -612,7 +613,8 @@ ast::type::Type* ParserImpl::ConvertType(
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Array* arr_ty) {
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type()));
const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type());
auto* ast_elem_ty = ConvertType(elem_type_id);
if (ast_elem_ty == nullptr) {
return nullptr;
}
@ -648,6 +650,9 @@ ast::type::Type* ParserImpl::ConvertType(
if (!ApplyArrayDecorations(arr_ty, ast_type.get())) {
return nullptr;
}
if (remap_buffer_block_type_.count(elem_type_id)) {
remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
}
return ctx_.type_mgr().Get(std::move(ast_type));
}
@ -684,9 +689,17 @@ ast::type::Type* ParserImpl::ConvertType(
// Compute the struct decoration.
auto struct_decorations = this->GetDecorationsFor(type_id);
auto ast_struct_decoration = ast::StructDecoration::kNone;
if (struct_decorations.size() == 1 &&
struct_decorations[0][0] == SpvDecorationBlock) {
ast_struct_decoration = ast::StructDecoration::kBlock;
if (struct_decorations.size() == 1) {
const auto decoration = struct_decorations[0][0];
if (decoration == SpvDecorationBlock) {
ast_struct_decoration = ast::StructDecoration::kBlock;
} else if (decoration == SpvDecorationBufferBlock) {
ast_struct_decoration = ast::StructDecoration::kBlock;
remap_buffer_block_type_.insert(type_id);
} else {
Fail() << "struct with ID " << type_id
<< " has unrecognized decoration: " << int(decoration);
}
} else if (struct_decorations.size() > 1) {
Fail() << "can't handle a struct with more than one decoration: struct "
<< type_id << " has " << struct_decorations.size();
@ -751,26 +764,28 @@ ast::type::Type* ParserImpl::ConvertType(
// Set the struct name before registering it.
namer_.SuggestSanitizedName(type_id, "S");
ast_struct_type->set_name(namer_.GetName(type_id));
return ctx_.type_mgr().Get(std::move(ast_struct_type));
auto* result = ctx_.type_mgr().Get(std::move(ast_struct_type));
return result;
}
ast::type::Type* ParserImpl::ConvertType(
uint32_t type_id,
const spvtools::opt::analysis::Pointer*) {
const auto* inst = def_use_mgr_->GetDef(type_id);
const auto pointee_ty_id = inst->GetSingleWordInOperand(1);
const auto pointee_type_id = inst->GetSingleWordInOperand(1);
const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0));
if (pointee_ty_id == builtin_position_.struct_type_id) {
if (pointee_type_id == builtin_position_.struct_type_id) {
builtin_position_.pointer_type_id = type_id;
builtin_position_.storage_class = storage_class;
return nullptr;
}
auto* ast_elem_ty = ConvertType(pointee_ty_id);
auto* ast_elem_ty = ConvertType(pointee_type_id);
if (ast_elem_ty == nullptr) {
Fail() << "SPIR-V pointer type with ID " << type_id
<< " has invalid pointee type " << pointee_ty_id;
<< " has invalid pointee type " << pointee_type_id;
return nullptr;
}
auto ast_storage_class = enum_converter_.ToStorageClass(storage_class);
if (ast_storage_class == ast::StorageClass::kNone) {
Fail() << "SPIR-V pointer type with ID " << type_id
@ -778,6 +793,11 @@ ast::type::Type* ParserImpl::ConvertType(
<< static_cast<uint32_t>(storage_class);
return nullptr;
}
if (ast_storage_class == ast::StorageClass::kUniform &&
remap_buffer_block_type_.count(pointee_type_id)) {
ast_storage_class = ast::StorageClass::kStorageBuffer;
remap_buffer_block_type_.insert(type_id);
}
return ctx_.type_mgr().Get(
std::make_unique<ast::type::PointerType>(ast_elem_ty, ast_storage_class));
}
@ -854,7 +874,8 @@ bool ParserImpl::EmitModuleScopeVariables() {
continue;
}
const auto& var = type_or_value;
const auto spirv_storage_class = var.GetSingleWordInOperand(0);
const auto spirv_storage_class =
SpvStorageClass(var.GetSingleWordInOperand(0));
uint32_t type_id = var.type_id();
if ((type_id == builtin_position_.pointer_type_id) &&
@ -864,9 +885,21 @@ bool ParserImpl::EmitModuleScopeVariables() {
builtin_position_.per_vertex_var_id = var.result_id();
continue;
}
auto ast_storage_class = enum_converter_.ToStorageClass(
static_cast<SpvStorageClass>(spirv_storage_class));
switch (enum_converter_.ToStorageClass(spirv_storage_class)) {
case ast::StorageClass::kInput:
case ast::StorageClass::kOutput:
case ast::StorageClass::kUniform:
case ast::StorageClass::kUniformConstant:
case ast::StorageClass::kStorageBuffer:
case ast::StorageClass::kImage:
case ast::StorageClass::kWorkgroup:
case ast::StorageClass::kPrivate:
break;
default:
return Fail() << "invalid SPIR-V storage class "
<< int(spirv_storage_class)
<< " for module scope variable: " << var.PrettyPrint();
}
if (!success_) {
return false;
}
@ -881,6 +914,7 @@ bool ParserImpl::EmitModuleScopeVariables() {
<< " has non-pointer type " << var.type_id();
}
auto* ast_store_type = ast_type->AsPointer()->type();
auto ast_storage_class = ast_type->AsPointer()->storage_class();
auto ast_var =
MakeVariable(var.result_id(), ast_storage_class, ast_store_type);
if (var.NumInOperands() > 1) {

View File

@ -89,6 +89,11 @@ class ParserImpl : Reader {
/// @returns true if the parse was successful, false otherwise.
bool Parse() override;
/// @returns the Tint context.
Context& context() {
return ctx_; // Inherited from Reader
}
/// @returns the module. The module in the parser will be reset after this.
ast::Module module() override;
@ -439,6 +444,16 @@ class ParserImpl : Reader {
// [[position]] var<in> gl_Position : vec4<f32>;
// The builtin variable was detected if and only if the struct_id is non-zero.
BuiltInPositionInfo builtin_position_;
// SPIR-V type IDs that are either:
// - a struct type decorated by BufferBlock
// - an array, runtime array containing one of these
// - a pointer type to one of these
// These are the types "enclosing" a buffer block with the old style
// representation: using Uniform storage class and BufferBlock decoration
// on the struct. The new style is to use the StorageBuffer storage class
// and Block decoration.
std::unordered_set<uint32_t> remap_buffer_block_type_;
};
} // namespace spirv

View File

@ -70,7 +70,7 @@ TEST_F(SpvParserTest, ModuleScopeVar_NoVar) {
EXPECT_THAT(module_ast, Not(HasSubstr("Variable")));
}
TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass) {
TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass_NotAWebGPUStorageClass) {
auto* p = parser(test::Assemble(R"(
%float = OpTypeFloat 32
%ptr = OpTypePointer CrossWorkgroup %float
@ -84,6 +84,22 @@ TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass) {
EXPECT_THAT(p->error(), HasSubstr("unknown SPIR-V storage class: 5"));
}
TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass_Function) {
auto* p = parser(test::Assemble(R"(
%float = OpTypeFloat 32
%ptr = OpTypePointer Function %float
%52 = OpVariable %ptr Function
)"));
EXPECT_TRUE(p->BuildInternalModule());
// Normally we should run ParserImpl::RegisterTypes before emitting
// variables. But defensive coding in EmitModuleScopeVariables lets
// us catch this error.
EXPECT_FALSE(p->EmitModuleScopeVariables()) << p->error();
EXPECT_THAT(p->error(),
HasSubstr("invalid SPIR-V storage class 7 for module scope "
"variable: %52 = OpVariable %2 Function"));
}
TEST_F(SpvParserTest, ModuleScopeVar_BadPointerType) {
auto* p = parser(test::Assemble(R"(
%float = OpTypeFloat 32