diff --git a/src/reader/spirv/enum_converter.cc b/src/reader/spirv/enum_converter.cc index 26aba0ab06..b52af7bdda 100644 --- a/src/reader/spirv/enum_converter.cc +++ b/src/reader/spirv/enum_converter.cc @@ -38,7 +38,7 @@ ast::PipelineStage EnumConverter::ToPipelineStage(SpvExecutionModel model) { return ast::PipelineStage::kNone; } -ast::StorageClass EnumConverter::ToStorageClass(SpvStorageClass sc) { +ast::StorageClass EnumConverter::ToStorageClass(const SpvStorageClass sc) { switch (sc) { case SpvStorageClassInput: return ast::StorageClass::kInput; diff --git a/src/reader/spirv/enum_converter.h b/src/reader/spirv/enum_converter.h index ca7516ab09..cdf922ff69 100644 --- a/src/reader/spirv/enum_converter.h +++ b/src/reader/spirv/enum_converter.h @@ -44,7 +44,7 @@ class EnumConverter { /// On failure, logs an error and returns kNone /// @param sc the SPIR-V storage class /// @returns a Tint AST storage class - ast::StorageClass ToStorageClass(SpvStorageClass sc); + ast::StorageClass ToStorageClass(const SpvStorageClass sc); /// Converts a SPIR-V Builtin value a Tint Builtin. /// On failure, logs an error and returns kNone diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index f62c967f0b..6f8de075bc 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -631,7 +631,10 @@ bool FunctionEmitter::EmitBody() { return false; } - RegisterValuesNeedingNamedOrHoistedDefinition(); + if (!RegisterLocallyDefinedValues()) { + return false; + } + FindValuesNeedingNamedOrHoistedDefinition(); if (!EmitFunctionVariables()) { return false; @@ -2419,10 +2422,10 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, for (auto id : sorted_by_index(block_info.hoisted_ids)) { const auto* def_inst = def_use_mgr_->GetDef(id); assert(def_inst); - AddStatement( - std::make_unique(parser_impl_.MakeVariable( - id, ast::StorageClass::kFunction, - parser_impl_.ConvertType(def_inst->type_id())))); + auto* ast_type = + RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id); + AddStatement(std::make_unique( + parser_impl_.MakeVariable(id, ast::StorageClass::kFunction, ast_type))); // Save this as an already-named value. identifier_values_.insert(id); } @@ -2580,12 +2583,14 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { expr.type = expr.type->AsPointer()->type(); return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr)); } - case SpvOpCopyObject: + case SpvOpCopyObject: { // Arguably, OpCopyObject is purely combinatorial. On the other hand, // it exists to make a new name for something. So we choose to make // a new named constant definition. - return EmitConstDefOrWriteToHoistedVar( - inst, MakeExpression(inst.GetSingleWordInOperand(0))); + auto expr = MakeExpression(inst.GetSingleWordInOperand(0)); + expr.type = RemapStorageClass(expr.type, result_id); + return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr)); + } case SpvOpPhi: { // Emit a read from the associated state variable. auto expr = TypedExpression( @@ -2754,7 +2759,6 @@ TypedExpression FunctionEmitter::MakeAccessChain( // ever-deeper nested indexing expressions. Start off with an expression // for the base, and then bury that inside nested indexing expressions. TypedExpression current_expr(MakeOperand(inst, 0)); - const auto constants = constant_mgr_->GetOperandConstants(&inst); static const char* swizzles[] = {"x", "y", "z", "w"}; @@ -2803,9 +2807,10 @@ TypedExpression FunctionEmitter::MakeAccessChain( // Skip past the member index that gets us to Position. first_index = first_index + 1; // Replace the gl_PerVertex reference with the gl_Position reference + ptr_ty_id = builtin_position_info.member_pointer_type_id; current_expr.expr = std::make_unique(namer_.Name(base_id)); - ptr_ty_id = builtin_position_info.member_pointer_type_id; + current_expr.type = parser_impl_.ConvertType(ptr_ty_id); } } @@ -2815,6 +2820,7 @@ TypedExpression FunctionEmitter::MakeAccessChain( << " base pointer is not of pointer type"; return {}; } + SpvStorageClass storage_class = ptr_type->AsPointer()->storage_class(); const auto* pointee_type = ptr_type->AsPointer()->pointee_type(); for (uint32_t index = first_index; index < num_in_operands; ++index) { const auto* index_const = @@ -2904,9 +2910,13 @@ TypedExpression FunctionEmitter::MakeAccessChain( << type_mgr_->GetId(pointee_type) << " " << pointee_type->str(); return {}; } - current_expr.reset(TypedExpression( - parser_impl_.ConvertType(type_mgr_->GetId(pointee_type)), - std::move(next_expr))); + const auto pointee_type_id = type_mgr_->GetId(pointee_type); + const auto pointer_type_id = + type_mgr_->FindPointerToType(pointee_type_id, storage_class); + auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); + assert(ast_pointer_type); + assert(ast_pointer_type->IsPointer); + current_expr.reset(TypedExpression(ast_pointer_type, std::move(next_expr))); } return current_expr; } @@ -3074,7 +3084,7 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( result_type, std::move(values))}; } -void FunctionEmitter::RegisterValuesNeedingNamedOrHoistedDefinition() { +bool FunctionEmitter::RegisterLocallyDefinedValues() { // Create a DefInfo for each value definition in this function. size_t index = 0; for (auto block_id : block_order_) { @@ -3087,9 +3097,72 @@ void FunctionEmitter::RegisterValuesNeedingNamedOrHoistedDefinition() { } def_info_[result_id] = std::make_unique(inst, block_pos, index); index++; + + // Determine storage class for pointer values. Do this in order because + // we might rely on the storage class for a previously-visited definition. + // Logical pointers can't be transmitted through OpPhi, so remaining + // pointer definitions are SSA values, and their definitions must be + // visited before their uses. + auto& storage_class = def_info_[result_id]->storage_class; + const auto* type = type_mgr_->GetType(inst.type_id()); + if (type && type->AsPointer()) { + const auto* ast_type = parser_impl_.ConvertType(inst.type_id()); + if (ast_type && ast_type->AsPointer()) { + storage_class = ast_type->AsPointer()->storage_class(); + } + switch (inst.opcode()) { + case SpvOpUndef: + case SpvOpVariable: + // Keep the default decision based on the result type. + break; + case SpvOpAccessChain: + case SpvOpCopyObject: + // Inherit from the first operand. We need this so we can pick up + // a remapped storage buffer. + storage_class = + GetStorageClassForPointerValue(inst.GetSingleWordInOperand(0)); + break; + default: + return Fail() << "pointer defined in function from unknown opcode: " + << inst.PrettyPrint(); + } + } } } + return true; +} +ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) { + auto where = def_info_.find(id); + if (where != def_info_.end()) { + return where->second.get()->storage_class; + } + const auto type_id = def_use_mgr_->GetDef(id)->type_id(); + if (type_id) { + auto* ast_type = parser_impl_.ConvertType(type_id); + if (ast_type && ast_type->IsPointer()) { + return ast_type->AsPointer()->storage_class(); + } + } + return ast::StorageClass::kNone; +} + +ast::type::Type* FunctionEmitter::RemapStorageClass(ast::type::Type* type, + uint32_t result_id) { + if (type->IsPointer()) { + // Remap an old-style storage buffer pointer to a new-style storage + // buffer pointer. + const auto* ast_ptr_type = type->AsPointer(); + const auto sc = GetStorageClassForPointerValue(result_id); + if (ast_ptr_type->storage_class() != sc) { + return parser_impl_.context().type_mgr().Get( + std::make_unique(ast_ptr_type->type(), sc)); + } + } + return type; +} + +void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() { // Mark vector operands of OpVectorShuffle as needing a named definition, // but only if they are defined in this function as well. for (auto& id_def_info_pair : def_info_) { diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 476818fe1b..e3bde3a465 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -33,6 +33,7 @@ #include "src/ast/expression.h" #include "src/ast/module.h" #include "src/ast/statement.h" +#include "src/ast/storage_class.h" #include "src/reader/spirv/construct.h" #include "src/reader/spirv/fail_stream.h" #include "src/reader/spirv/namer.h" @@ -192,7 +193,7 @@ inline std::ostream& operator<<(std::ostream& o, const BlockInfo& bi) { /// Bookkeeping info for a SPIR-V ID defined in the function. /// This will be valid for result IDs for: -/// - instructions that are not OpLabel, OpVariable, and OpFunctionParameter +/// - instructions that are not OpLabel, and not OpFunctionParameter /// - are defined in a basic block visited in the block-order for the function. struct DefInfo { /// Constructor. @@ -243,8 +244,15 @@ struct DefInfo { /// If the definition is an OpPhi, then |phi_var| is the name of the /// variable that stores the value carried from parent basic blocks into - // the basic block containing the OpPhi. Otherwise this is the empty string. + /// the basic block containing the OpPhi. Otherwise this is the empty string. std::string phi_var; + + /// The storage class to use for this value, if it is of pointer type. + /// This is required to carry a stroage class override from a storage + /// buffer expressed in the old style (with Uniform storage class) + /// that needs to be remapped to StorageBuffer storage class. + /// This is kNone for non-pointers. + ast::StorageClass storage_class = ast::StorageClass::kNone; }; inline std::ostream& operator<<(std::ostream& o, const DefInfo& di) { @@ -254,8 +262,11 @@ inline std::ostream& operator<<(std::ostream& o, const DefInfo& di) { << " last_use_pos: " << di.last_use_pos << " requires_named_const_def: " << (di.requires_named_const_def ? "true" : "false") << " requires_hoisted_def: " << (di.requires_hoisted_def ? "true" : "false") - << " phi_var: '" << di.phi_var << "'" - << "}"; + << " phi_var: '" << di.phi_var << "'"; + if (di.storage_class != ast::StorageClass::kNone) { + o << " sc:" << int(di.storage_class); + } + o << "}"; return o; } @@ -367,7 +378,24 @@ class FunctionEmitter { /// @returns false if bad nesting has been detected. bool FindIfSelectionInternalHeaders(); - /// Record the SPIR-V IDs of non-constants that should get a 'const' + /// Creates a DefInfo record for each locally defined SPIR-V ID. + /// Populates the |def_info_| mapping with basic results. + /// @returns false on failure + bool RegisterLocallyDefinedValues(); + + /// Returns the Tint storage class for the given SPIR-V ID that is a + /// pointer value. + /// @returns the storage class + ast::StorageClass GetStorageClassForPointerValue(uint32_t id); + + /// Remaps the storage class for the type of a locally-defined value, + /// if necessary. If it's not a pointer type, or if its storage class + /// already matches, then the result is a copy of the |type| argument. + /// @param type the AST type + /// @param result_id the SPIR-V ID for the locally defined value + ast::type::Type* RemapStorageClass(ast::type::Type* type, uint32_t result_id); + + /// Marks locally defined values when they should get a 'const' /// definition in WGSL, or a 'var' definition at an outer scope. /// This occurs in several cases: /// - When a SPIR-V instruction might use the dynamically computed value @@ -382,8 +410,8 @@ class FunctionEmitter { /// - When a definition is in a construct that does not enclose all the /// uses. In this case the definition's |requires_hoisted_def| property /// is set to true. - /// Populates the |def_info_| mapping. - void RegisterValuesNeedingNamedOrHoistedDefinition(); + /// Updates the |def_info_| mapping. + void FindValuesNeedingNamedOrHoistedDefinition(); /// Emits declarations of function variables. /// @returns false if emission failed. diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc index f3caa245d7..b3997018aa 100644 --- a/src/reader/spirv/function_composite_test.cc +++ b/src/reader/spirv/function_composite_test.cc @@ -287,10 +287,10 @@ TEST_F(SpvParserTest_CompositeExtract, Vector_IndexTooBigError) { TEST_F(SpvParserTest_CompositeExtract, Matrix) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %m3v2float - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %m3v2float %var %2 = OpCompositeExtract %v2float %1 2 OpReturn @@ -318,10 +318,10 @@ TEST_F(SpvParserTest_CompositeExtract, Matrix) { TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %m3v2float - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %m3v2float %var %2 = OpCompositeExtract %v2float %1 3 OpReturn @@ -338,10 +338,10 @@ TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) { TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %m3v2float - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %m3v2float %var %2 = OpCompositeExtract %float %1 2 1 OpReturn @@ -372,10 +372,10 @@ TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) { TEST_F(SpvParserTest_CompositeExtract, Array) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %a_u_5 - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %a_u_5 %var %2 = OpCompositeExtract %uint %1 3 OpReturn @@ -404,10 +404,10 @@ TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) { const auto assembly = Preamble() + R"( %rtarr = OpTypeRuntimeArray %uint %ptr = OpTypePointer Function %rtarr - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %rtarr %var %2 = OpCompositeExtract %uint %1 3 OpReturn @@ -423,10 +423,10 @@ TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) { TEST_F(SpvParserTest_CompositeExtract, Struct) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %s_v2f_u_i - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %s_v2f_u_i %var %2 = OpCompositeExtract %int %1 2 OpReturn @@ -454,10 +454,10 @@ TEST_F(SpvParserTest_CompositeExtract, Struct) { TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %s_v2f_u_i - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %s_v2f_u_i %var %2 = OpCompositeExtract %int %1 40 OpReturn @@ -476,10 +476,10 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) { %a_mat = OpTypeArray %m3v2float %uint_3 %s = OpTypeStruct %uint %a_mat %ptr = OpTypePointer Function %s - %var = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %var = OpVariable %ptr Function %1 = OpLoad %s %var %2 = OpCompositeExtract %float %1 1 2 0 1 OpReturn @@ -553,10 +553,10 @@ VariableDeclStatement{ TEST_F(SpvParserTest_CopyObject, Pointer) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %uint - %10 = OpVariable %ptr Function %100 = OpFunction %void None %voidfn %entry = OpLabel + %10 = OpVariable %ptr Function %1 = OpCopyObject %ptr %10 %2 = OpCopyObject %ptr %1 OpReturn diff --git a/src/reader/spirv/function_memory_test.cc b/src/reader/spirv/function_memory_test.cc index 5bdfb2c96b..662e9cb109 100644 --- a/src/reader/spirv/function_memory_test.cc +++ b/src/reader/spirv/function_memory_test.cc @@ -708,6 +708,245 @@ TEST_F(SpvParserTest, EmitStatement_AccessChain_InvalidPointeeType) { HasSubstr("Access chain with unknown pointee type %60 void")); } +std::string OldStorageBufferPreamble() { + return R"( + OpName %myvar "myvar" + + OpDecorate %struct BufferBlock + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 1 Offset 4 + OpDecorate %arr ArrayStride 4 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %uint = OpTypeInt 32 0 + + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + + %arr = OpTypeRuntimeArray %uint + %struct = OpTypeStruct %uint %arr + %ptr_struct = OpTypePointer Uniform %struct + %ptr_uint = OpTypePointer Uniform %uint + + %myvar = OpVariable %ptr_struct Uniform + )"; +} + +TEST_F(SpvParserTest, RemapStorageBuffer_TypesAndVarDeclarations) { + // Enusure we get the right module-scope declaration. This tests translation + // of the structure type, arrays of the structure, pointers to them, and + // OpVariable of these. + const auto assembly = OldStorageBufferPreamble(); + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) + << assembly << p->error(); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + Variable{ + myvar + storage_buffer + __alias_S__struct_S + } +RTArr -> __array__u32_stride_4 +S -> __struct_S)")); +} + +TEST_F(SpvParserTest, + RemapStorageBuffer_ThroughAccessChain_NonCascaded) { + const auto assembly = OldStorageBufferPreamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + + ; the scalar element + %1 = OpAccessChain %ptr_uint %myvar %uint_0 + OpStore %1 %uint_0 + + ; element in the runtime array + %2 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1 + OpStore %2 %uint_0 + + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{ + MemberAccessor{ + Identifier{myvar} + Identifier{field0} + } + ScalarConstructor{0} +} +Assignment{ + ArrayAccessor{ + MemberAccessor{ + Identifier{myvar} + Identifier{field1} + } + ScalarConstructor{1} + } + ScalarConstructor{0} +})")) << ToString(fe.ast_body()) + << p->error(); +} + +TEST_F(SpvParserTest, RemapStorageBuffer_ThroughAccessChain_Cascaded) { + const auto assembly = OldStorageBufferPreamble() + R"( + %ptr_rtarr = OpTypePointer Uniform %arr + %100 = OpFunction %void None %voidfn + %entry = OpLabel + + ; get the runtime array + %1 = OpAccessChain %ptr_rtarr %myvar %uint_1 + ; now an element in it + %2 = OpAccessChain %ptr_uint %1 %uint_1 + OpStore %2 %uint_0 + + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{ + ArrayAccessor{ + MemberAccessor{ + Identifier{myvar} + Identifier{field1} + } + ScalarConstructor{1} + } + ScalarConstructor{0} +})")) << ToString(fe.ast_body()) + << p->error(); +} + +TEST_F(SpvParserTest, + RemapStorageBuffer_ThroughCopyObject_WithoutHoisting) { + // Generates a const declaration directly. + // We have to do a bunch of storage class tracking for locally + // defined values in order to get the right pointer-to-storage-buffer + // value type for the const declration. + const auto assembly = OldStorageBufferPreamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + + %1 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1 + %2 = OpCopyObject %ptr_uint %1 + OpStore %2 %uint_0 + + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{ + Variable{ + x_2 + none + __ptr_storage_buffer__u32 + { + ArrayAccessor{ + MemberAccessor{ + Identifier{myvar} + Identifier{field1} + } + ScalarConstructor{1} + } + } + } +} +Assignment{ + Identifier{x_2} + ScalarConstructor{0} +})")) << ToString(fe.ast_body()) + << p->error(); +} + +TEST_F(SpvParserTest, RemapStorageBuffer_ThroughCopyObject_WithHoisting) { + // Like the previous test, but the declaration for the copy-object + // has its declaration hoisted. + const auto assembly = OldStorageBufferPreamble() + R"( + %bool = OpTypeBool + %cond = OpConstantTrue %bool + + %100 = OpFunction %void None %voidfn + + %entry = OpLabel + OpSelectionMerge %99 None + OpBranchConditional %cond %20 %30 + + %20 = OpLabel + %1 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1 + ; this definintion dominates the use in %99 + %2 = OpCopyObject %ptr_uint %1 + OpBranch %99 + + %30 = OpLabel + OpReturn + + %99 = OpLabel + OpStore %2 %uint_0 + OpReturn + + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{ + Variable{ + x_2 + function + __ptr_storage_buffer__u32 + } +} +If{ + ( + ScalarConstructor{true} + ) + { + Assignment{ + Identifier{x_2} + ArrayAccessor{ + MemberAccessor{ + Identifier{myvar} + Identifier{field1} + } + ScalarConstructor{1} + } + } + } +} +Else{ + { + Return{} + } +} +Assignment{ + Identifier{x_2} + ScalarConstructor{0} +} +Return{} +)")) << ToString(fe.ast_body()) + << p->error(); +} + +TEST_F(SpvParserTest, DISABLED_RemapStorageBuffer_ThroughFunctionCall) { + // TODO(dneto): Blocked on OpFunctionCall support. + // We might need this for passing pointers into atomic builtins. +} +TEST_F(SpvParserTest, DISABLED_RemapStorageBuffer_ThroughFunctionParameter) { + // TODO(dneto): Blocked on OpFunctionCall support. +} + } // namespace } // namespace spirv } // namespace reader diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 69fe861240..9c1a158b8b 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -65,6 +65,7 @@ #include "src/ast/variable.h" #include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decoration.h" +#include "src/reader/spirv/enum_converter.h" #include "src/reader/spirv/function.h" #include "src/type_manager.h" @@ -612,7 +613,8 @@ ast::type::Type* ParserImpl::ConvertType( ast::type::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Array* arr_ty) { - auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type())); + const auto elem_type_id = type_mgr_->GetId(arr_ty->element_type()); + auto* ast_elem_ty = ConvertType(elem_type_id); if (ast_elem_ty == nullptr) { return nullptr; } @@ -648,6 +650,9 @@ ast::type::Type* ParserImpl::ConvertType( if (!ApplyArrayDecorations(arr_ty, ast_type.get())) { return nullptr; } + if (remap_buffer_block_type_.count(elem_type_id)) { + remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty)); + } return ctx_.type_mgr().Get(std::move(ast_type)); } @@ -684,9 +689,17 @@ ast::type::Type* ParserImpl::ConvertType( // Compute the struct decoration. auto struct_decorations = this->GetDecorationsFor(type_id); auto ast_struct_decoration = ast::StructDecoration::kNone; - if (struct_decorations.size() == 1 && - struct_decorations[0][0] == SpvDecorationBlock) { - ast_struct_decoration = ast::StructDecoration::kBlock; + if (struct_decorations.size() == 1) { + const auto decoration = struct_decorations[0][0]; + if (decoration == SpvDecorationBlock) { + ast_struct_decoration = ast::StructDecoration::kBlock; + } else if (decoration == SpvDecorationBufferBlock) { + ast_struct_decoration = ast::StructDecoration::kBlock; + remap_buffer_block_type_.insert(type_id); + } else { + Fail() << "struct with ID " << type_id + << " has unrecognized decoration: " << int(decoration); + } } else if (struct_decorations.size() > 1) { Fail() << "can't handle a struct with more than one decoration: struct " << type_id << " has " << struct_decorations.size(); @@ -751,26 +764,28 @@ ast::type::Type* ParserImpl::ConvertType( // Set the struct name before registering it. namer_.SuggestSanitizedName(type_id, "S"); ast_struct_type->set_name(namer_.GetName(type_id)); - return ctx_.type_mgr().Get(std::move(ast_struct_type)); + auto* result = ctx_.type_mgr().Get(std::move(ast_struct_type)); + return result; } ast::type::Type* ParserImpl::ConvertType( uint32_t type_id, const spvtools::opt::analysis::Pointer*) { const auto* inst = def_use_mgr_->GetDef(type_id); - const auto pointee_ty_id = inst->GetSingleWordInOperand(1); + const auto pointee_type_id = inst->GetSingleWordInOperand(1); const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0)); - if (pointee_ty_id == builtin_position_.struct_type_id) { + if (pointee_type_id == builtin_position_.struct_type_id) { builtin_position_.pointer_type_id = type_id; builtin_position_.storage_class = storage_class; return nullptr; } - auto* ast_elem_ty = ConvertType(pointee_ty_id); + auto* ast_elem_ty = ConvertType(pointee_type_id); if (ast_elem_ty == nullptr) { Fail() << "SPIR-V pointer type with ID " << type_id - << " has invalid pointee type " << pointee_ty_id; + << " has invalid pointee type " << pointee_type_id; return nullptr; } + auto ast_storage_class = enum_converter_.ToStorageClass(storage_class); if (ast_storage_class == ast::StorageClass::kNone) { Fail() << "SPIR-V pointer type with ID " << type_id @@ -778,6 +793,11 @@ ast::type::Type* ParserImpl::ConvertType( << static_cast(storage_class); return nullptr; } + if (ast_storage_class == ast::StorageClass::kUniform && + remap_buffer_block_type_.count(pointee_type_id)) { + ast_storage_class = ast::StorageClass::kStorageBuffer; + remap_buffer_block_type_.insert(type_id); + } return ctx_.type_mgr().Get( std::make_unique(ast_elem_ty, ast_storage_class)); } @@ -854,7 +874,8 @@ bool ParserImpl::EmitModuleScopeVariables() { continue; } const auto& var = type_or_value; - const auto spirv_storage_class = var.GetSingleWordInOperand(0); + const auto spirv_storage_class = + SpvStorageClass(var.GetSingleWordInOperand(0)); uint32_t type_id = var.type_id(); if ((type_id == builtin_position_.pointer_type_id) && @@ -864,9 +885,21 @@ bool ParserImpl::EmitModuleScopeVariables() { builtin_position_.per_vertex_var_id = var.result_id(); continue; } - - auto ast_storage_class = enum_converter_.ToStorageClass( - static_cast(spirv_storage_class)); + switch (enum_converter_.ToStorageClass(spirv_storage_class)) { + case ast::StorageClass::kInput: + case ast::StorageClass::kOutput: + case ast::StorageClass::kUniform: + case ast::StorageClass::kUniformConstant: + case ast::StorageClass::kStorageBuffer: + case ast::StorageClass::kImage: + case ast::StorageClass::kWorkgroup: + case ast::StorageClass::kPrivate: + break; + default: + return Fail() << "invalid SPIR-V storage class " + << int(spirv_storage_class) + << " for module scope variable: " << var.PrettyPrint(); + } if (!success_) { return false; } @@ -881,6 +914,7 @@ bool ParserImpl::EmitModuleScopeVariables() { << " has non-pointer type " << var.type_id(); } auto* ast_store_type = ast_type->AsPointer()->type(); + auto ast_storage_class = ast_type->AsPointer()->storage_class(); auto ast_var = MakeVariable(var.result_id(), ast_storage_class, ast_store_type); if (var.NumInOperands() > 1) { diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 6fc60a5bff..d304aad57f 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -89,6 +89,11 @@ class ParserImpl : Reader { /// @returns true if the parse was successful, false otherwise. bool Parse() override; + /// @returns the Tint context. + Context& context() { + return ctx_; // Inherited from Reader + } + /// @returns the module. The module in the parser will be reset after this. ast::Module module() override; @@ -439,6 +444,16 @@ class ParserImpl : Reader { // [[position]] var gl_Position : vec4; // The builtin variable was detected if and only if the struct_id is non-zero. BuiltInPositionInfo builtin_position_; + + // SPIR-V type IDs that are either: + // - a struct type decorated by BufferBlock + // - an array, runtime array containing one of these + // - a pointer type to one of these + // These are the types "enclosing" a buffer block with the old style + // representation: using Uniform storage class and BufferBlock decoration + // on the struct. The new style is to use the StorageBuffer storage class + // and Block decoration. + std::unordered_set remap_buffer_block_type_; }; } // namespace spirv diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index da57fd8a70..83895af7f8 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -70,7 +70,7 @@ TEST_F(SpvParserTest, ModuleScopeVar_NoVar) { EXPECT_THAT(module_ast, Not(HasSubstr("Variable"))); } -TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass) { +TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass_NotAWebGPUStorageClass) { auto* p = parser(test::Assemble(R"( %float = OpTypeFloat 32 %ptr = OpTypePointer CrossWorkgroup %float @@ -84,6 +84,22 @@ TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass) { EXPECT_THAT(p->error(), HasSubstr("unknown SPIR-V storage class: 5")); } +TEST_F(SpvParserTest, ModuleScopeVar_BadStorageClass_Function) { + auto* p = parser(test::Assemble(R"( + %float = OpTypeFloat 32 + %ptr = OpTypePointer Function %float + %52 = OpVariable %ptr Function + )")); + EXPECT_TRUE(p->BuildInternalModule()); + // Normally we should run ParserImpl::RegisterTypes before emitting + // variables. But defensive coding in EmitModuleScopeVariables lets + // us catch this error. + EXPECT_FALSE(p->EmitModuleScopeVariables()) << p->error(); + EXPECT_THAT(p->error(), + HasSubstr("invalid SPIR-V storage class 7 for module scope " + "variable: %52 = OpVariable %2 Function")); +} + TEST_F(SpvParserTest, ModuleScopeVar_BadPointerType) { auto* p = parser(test::Assemble(R"( %float = OpTypeFloat 32