diff --git a/src/tint/ast/location_attribute.cc b/src/tint/ast/location_attribute.cc index 2ea2d5d24d..e1d101b6f2 100644 --- a/src/tint/ast/location_attribute.cc +++ b/src/tint/ast/location_attribute.cc @@ -22,7 +22,10 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::LocationAttribute); namespace tint::ast { -LocationAttribute::LocationAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t val) +LocationAttribute::LocationAttribute(ProgramID pid, + NodeID nid, + const Source& src, + const ast::Expression* val) : Base(pid, nid, src), value(val) {} LocationAttribute::~LocationAttribute() = default; @@ -34,7 +37,8 @@ std::string LocationAttribute::Name() const { const LocationAttribute* LocationAttribute::Clone(CloneContext* ctx) const { // Clone arguments outside of create() call to have deterministic ordering auto src = ctx->Clone(source); - return ctx->dst->create(src, value); + auto value_ = ctx->Clone(value); + return ctx->dst->create(src, value_); } } // namespace tint::ast diff --git a/src/tint/ast/location_attribute.h b/src/tint/ast/location_attribute.h index 97c6feaf84..48c623f1a3 100644 --- a/src/tint/ast/location_attribute.h +++ b/src/tint/ast/location_attribute.h @@ -18,6 +18,7 @@ #include #include "src/tint/ast/attribute.h" +#include "src/tint/ast/expression.h" namespace tint::ast { @@ -28,8 +29,8 @@ class LocationAttribute final : public Castable { /// @param pid the identifier of the program that owns this node /// @param nid the unique node identifier /// @param src the source of this node - /// @param value the location value - LocationAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t value); + /// @param value the location value expression + LocationAttribute(ProgramID pid, NodeID nid, const Source& src, const ast::Expression* value); ~LocationAttribute() override; /// @returns the WGSL name for the attribute @@ -42,7 +43,7 @@ class LocationAttribute final : public Castable { const LocationAttribute* Clone(CloneContext* ctx) const override; /// The location value - const uint32_t value; + const ast::Expression* const value; }; } // namespace tint::ast diff --git a/src/tint/ast/location_attribute_test.cc b/src/tint/ast/location_attribute_test.cc index e0bcb39c7f..d921131863 100644 --- a/src/tint/ast/location_attribute_test.cc +++ b/src/tint/ast/location_attribute_test.cc @@ -17,11 +17,12 @@ namespace tint::ast { namespace { +using namespace tint::number_suffixes; // NOLINT using LocationAttributeTest = TestHelper; TEST_F(LocationAttributeTest, Creation) { - auto* d = create(2u); - EXPECT_EQ(2u, d->value); + auto* d = Location(2_a); + EXPECT_TRUE(d->value->Is()); } } // namespace diff --git a/src/tint/ast/variable_test.cc b/src/tint/ast/variable_test.cc index 40dd68d3de..5b3ccb055a 100644 --- a/src/tint/ast/variable_test.cc +++ b/src/tint/ast/variable_test.cc @@ -92,7 +92,7 @@ TEST_F(VariableTest, Assert_DifferentProgramID_Constructor) { } TEST_F(VariableTest, WithAttributes) { - auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Location(1u), + auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Location(1_u), Builtin(BuiltinValue::kPosition), Id(1200_u)); auto& attributes = var->attributes; @@ -102,7 +102,8 @@ TEST_F(VariableTest, WithAttributes) { auto* location = ast::GetAttribute(attributes); ASSERT_NE(nullptr, location); - EXPECT_EQ(1u, location->value); + ASSERT_NE(nullptr, location->value); + EXPECT_TRUE(location->value->Is()); } TEST_F(VariableTest, HasBindingPoint_BothProvided) { diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 4dd3c75d50..087e786b88 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -172,7 +172,7 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) { for (auto* param : sem->Parameters()) { AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol), param->Type(), param->Declaration()->attributes, - entry_point.input_variables); + param->Location(), entry_point.input_variables); entry_point.input_position_used |= ContainsBuiltin( ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes); @@ -188,7 +188,7 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) { if (!sem->ReturnType()->Is()) { AddEntryPointInOutVariables("", sem->ReturnType(), func->return_type_attributes, - entry_point.output_variables); + sem->ReturnLocation(), entry_point.output_variables); entry_point.output_sample_mask_used = ContainsBuiltin( ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes); @@ -623,6 +623,7 @@ const ast::Function* Inspector::FindEntryPointByName(const std::string& name) { void Inspector::AddEntryPointInOutVariables(std::string name, const sem::Type* type, utils::VectorRef attributes, + std::optional location, std::vector& variables) const { // Skip builtins. if (ast::HasAttribute(attributes)) { @@ -636,7 +637,7 @@ void Inspector::AddEntryPointInOutVariables(std::string name, for (auto* member : struct_ty->Members()) { AddEntryPointInOutVariables( name + "." + program_->Symbols().NameFor(member->Declaration()->symbol), - member->Type(), member->Declaration()->attributes, variables); + member->Type(), member->Declaration()->attributes, member->Location(), variables); } return; } @@ -648,10 +649,9 @@ void Inspector::AddEntryPointInOutVariables(std::string name, std::tie(stage_variable.component_type, stage_variable.composition_type) = CalculateComponentAndComposition(type); - auto* location = ast::GetAttribute(attributes); - TINT_ASSERT(Inspector, location != nullptr); + TINT_ASSERT(Inspector, location.has_value()); stage_variable.has_location_attribute = true; - stage_variable.location_attribute = location->value; + stage_variable.location_attribute = location.value(); std::tie(stage_variable.interpolation_type, stage_variable.interpolation_sampling) = CalculateInterpolationData(type, attributes); diff --git a/src/tint/inspector/inspector.h b/src/tint/inspector/inspector.h index 684852ce28..49e4bdf20b 100644 --- a/src/tint/inspector/inspector.h +++ b/src/tint/inspector/inspector.h @@ -172,10 +172,12 @@ class Inspector { /// @param name the name of the variable being added /// @param type the type of the variable /// @param attributes the variable attributes + /// @param location the location value if provided /// @param variables the list to add the variables to void AddEntryPointInOutVariables(std::string name, const sem::Type* type, utils::VectorRef attributes, + std::optional location, std::vector& variables) const; /// Recursively determine if the type contains builtin. diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc index 5bfd08eeb4..5a190f6080 100644 --- a/src/tint/inspector/inspector_test.cc +++ b/src/tint/inspector/inspector_test.cc @@ -291,7 +291,7 @@ TEST_P(InspectorGetEntryPointComponentAndCompositionTest, Test) { auto* in_var = Param("in_var", tint_type(), utils::Vector{ - Location(0u), + Location(0_u), Flat(), }); Func("foo", utils::Vector{in_var}, tint_type(), @@ -302,7 +302,7 @@ TEST_P(InspectorGetEntryPointComponentAndCompositionTest, Test) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0u), + Location(0_u), }); Inspector& inspector = Build(); @@ -336,17 +336,17 @@ INSTANTIATE_TEST_SUITE_P(InspectorGetEntryPointTest, TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) { auto* in_var0 = Param("in_var0", ty.u32(), utils::Vector{ - Location(0u), + Location(0_u), Flat(), }); auto* in_var1 = Param("in_var1", ty.u32(), utils::Vector{ - Location(1u), + Location(1_u), Flat(), }); auto* in_var4 = Param("in_var4", ty.u32(), utils::Vector{ - Location(4u), + Location(4_u), Flat(), }); Func("foo", utils::Vector{in_var0, in_var1, in_var4}, ty.u32(), @@ -357,7 +357,7 @@ TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0u), + Location(0_u), }); Inspector& inspector = Build(); @@ -393,7 +393,7 @@ TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) { TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) { auto* in_var_foo = Param("in_var_foo", ty.u32(), utils::Vector{ - Location(0u), + Location(0_u), Flat(), }); Func("foo", utils::Vector{in_var_foo}, ty.u32(), @@ -404,12 +404,12 @@ TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0u), + Location(0_u), }); auto* in_var_bar = Param("in_var_bar", ty.u32(), utils::Vector{ - Location(0u), + Location(0_u), Flat(), }); Func("bar", utils::Vector{in_var_bar}, ty.u32(), @@ -420,7 +420,7 @@ TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(1u), + Location(1_u), }); Inspector& inspector = Build(); @@ -464,7 +464,7 @@ TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) { }); auto* in_var1 = Param("in_var1", ty.f32(), utils::Vector{ - Location(0u), + Location(0_u), }); Func("foo", utils::Vector{in_var0, in_var1}, ty.f32(), utils::Vector{ @@ -596,8 +596,8 @@ TEST_F(InspectorGetEntryPointTest, MixInOutVariablesAndStruct) { utils::Vector{ Param("param_a", ty.Of(struct_a)), Param("param_b", ty.Of(struct_b)), - Param("param_c", ty.f32(), utils::Vector{Location(3u)}), - Param("param_d", ty.f32(), utils::Vector{Location(4u)}), + Param("param_c", ty.f32(), utils::Vector{Location(3_u)}), + Param("param_d", ty.f32(), utils::Vector{Location(4_u)}), }, ty.Of(struct_a), utils::Vector{ @@ -1136,7 +1136,7 @@ TEST_F(InspectorGetEntryPointTest, NumWorkgroupsStructReferenced) { TEST_F(InspectorGetEntryPointTest, ImplicitInterpolate) { Structure("in_struct", utils::Vector{ - Member("struct_inner", ty.f32(), utils::Vector{Location(0)}), + Member("struct_inner", ty.f32(), utils::Vector{Location(0_a)}), }); Func("ep_func", @@ -1167,7 +1167,7 @@ TEST_P(InspectorGetEntryPointInterpolateTest, Test) { "in_struct", utils::Vector{ Member("struct_inner", ty.f32(), - utils::Vector{Interpolate(params.in_type, params.in_sampling), Location(0)}), + utils::Vector{Interpolate(params.in_type, params.in_sampling), Location(0_a)}), }); Func("ep_func", diff --git a/src/tint/inspector/test_inspector_builder.cc b/src/tint/inspector/test_inspector_builder.cc index 342167c769..ce341a6c0e 100644 --- a/src/tint/inspector/test_inspector_builder.cc +++ b/src/tint/inspector/test_inspector_builder.cc @@ -54,7 +54,7 @@ const ast::Struct* InspectorBuilder::MakeInOutStruct(std::string name, std::tie(member_name, location) = var; members.Push(Member(member_name, ty.u32(), utils::Vector{ - Location(location), + Location(AInt(location)), Flat(), })); } diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index bdd2807fb4..ead13e9c9f 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -2928,17 +2928,19 @@ class ProgramBuilder { /// Creates an ast::LocationAttribute /// @param source the source information - /// @param location the location value + /// @param location the location value expression /// @returns the location attribute pointer - const ast::LocationAttribute* Location(const Source& source, uint32_t location) { - return create(source, location); + template + const ast::LocationAttribute* Location(const Source& source, EXPR&& location) { + return create(source, Expr(std::forward(location))); } /// Creates an ast::LocationAttribute - /// @param location the location value + /// @param location the location value expression /// @returns the location attribute pointer - const ast::LocationAttribute* Location(uint32_t location) { - return create(source_, location); + template + const ast::LocationAttribute* Location(EXPR&& location) { + return create(source_, Expr(std::forward(location))); } /// Creates an ast::IdAttribute diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index 0acfa1d777..661b42c051 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -1109,7 +1109,9 @@ void FunctionEmitter::IncrementLocation(AttributeList* attributes) { // Replace this location attribute with a new one with one higher index. // The old one doesn't leak because it's kept in the builder's AST node // list. - attr = builder_.Location(loc_attr->source, loc_attr->value + 1); + attr = builder_.Location( + loc_attr->source, + AInt(loc_attr->value->As()->value + 1)); } } } diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc index 2942591e11..37ea5ca3f3 100644 --- a/src/tint/reader/spirv/parser_impl.cc +++ b/src/tint/reader/spirv/parser_impl.cc @@ -1723,25 +1723,22 @@ DecorationList ParserImpl::GetMemberPipelineDecorations(const Struct& struct_typ return result; } -const ast::Attribute* ParserImpl::SetLocation(AttributeList* attributes, - const ast::Attribute* replacement) { +void ParserImpl::SetLocation(AttributeList* attributes, const ast::Attribute* replacement) { if (!replacement) { - return nullptr; + return; } for (auto*& attribute : *attributes) { if (attribute->Is()) { // Replace this location attribute with the replacement. // The old one doesn't leak because it's kept in the builder's AST node // list. - const ast::Attribute* result = nullptr; - result = attribute; attribute = replacement; - return result; // Assume there is only one such decoration. + return; // Assume there is only one such decoration. } } // The list didn't have a location. Add it. attributes->Push(replacement); - return nullptr; + return; } bool ParserImpl::ConvertPipelineDecorations(const Type* store_type, @@ -1759,7 +1756,7 @@ bool ParserImpl::ConvertPipelineDecorations(const Type* store_type, return Fail() << "malformed Location decoration on ID requires one " "literal operand"; } - SetLocation(attributes, create(Source{}, deco[1])); + SetLocation(attributes, builder_.Location(AInt(deco[1]))); if (store_type->IsIntegerScalarOrVector()) { // Default to flat interpolation for integral user-defined IO types. type = ast::InterpolationType::kFlat; diff --git a/src/tint/reader/spirv/parser_impl.h b/src/tint/reader/spirv/parser_impl.h index 12d62266a8..948a9a8879 100644 --- a/src/tint/reader/spirv/parser_impl.h +++ b/src/tint/reader/spirv/parser_impl.h @@ -280,9 +280,7 @@ class ParserImpl : Reader { /// Assumes the list contains at most one Location decoration. /// @param decos the attribute list to modify /// @param replacement the location decoration to place into the list - /// @returns the location decoration that was replaced, if one was replaced, - /// or null otherwise. - const ast::Attribute* SetLocation(AttributeList* decos, const ast::Attribute* replacement); + void SetLocation(AttributeList* decos, const ast::Attribute* replacement); /// Converts a SPIR-V struct member decoration into a number of AST /// decorations. If the decoration is recognized but deliberately dropped, diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index d57ad78495..f3b424e71b 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc @@ -3551,7 +3551,9 @@ Maybe ParserImpl::attribute() { } match(Token::Type::kComma); - return create(t.source(), val.value); + return builder_.Location(t.source(), + create( + val.value, ast::IntLiteralExpression::Suffix::kNone)); }); } diff --git a/src/tint/reader/wgsl/parser_impl_function_decl_test.cc b/src/tint/reader/wgsl/parser_impl_function_decl_test.cc index cba4a7d6c5..96669dcdc2 100644 --- a/src/tint/reader/wgsl/parser_impl_function_decl_test.cc +++ b/src/tint/reader/wgsl/parser_impl_function_decl_test.cc @@ -256,7 +256,10 @@ TEST_F(ParserImplTest, FunctionDecl_ReturnTypeAttributeList) { ASSERT_EQ(ret_type_attributes.Length(), 1u); auto* loc = ret_type_attributes[0]->As(); ASSERT_TRUE(loc != nullptr); - EXPECT_EQ(loc->value, 1u); + EXPECT_TRUE(loc->value->Is()); + + auto* exp = loc->value->As(); + EXPECT_EQ(1u, exp->value); auto* body = f->body; ASSERT_EQ(body->statements.Length(), 1u); diff --git a/src/tint/reader/wgsl/parser_impl_function_header_test.cc b/src/tint/reader/wgsl/parser_impl_function_header_test.cc index 1a8704e523..fe81317c49 100644 --- a/src/tint/reader/wgsl/parser_impl_function_header_test.cc +++ b/src/tint/reader/wgsl/parser_impl_function_header_test.cc @@ -54,9 +54,12 @@ TEST_F(ParserImplTest, FunctionHeader_AttributeReturnType) { EXPECT_EQ(f->params.Length(), 0u); EXPECT_TRUE(f->return_type->Is()); ASSERT_EQ(f->return_type_attributes.Length(), 1u); + auto* loc = f->return_type_attributes[0]->As(); ASSERT_TRUE(loc != nullptr); - EXPECT_EQ(loc->value, 1u); + ASSERT_TRUE(loc->value->Is()); + auto* exp = loc->value->As(); + EXPECT_EQ(exp->value, 1u); } TEST_F(ParserImplTest, FunctionHeader_InvariantReturnType) { diff --git a/src/tint/reader/wgsl/parser_impl_param_list_test.cc b/src/tint/reader/wgsl/parser_impl_param_list_test.cc index ce542e4c23..46c8c807ba 100644 --- a/src/tint/reader/wgsl/parser_impl_param_list_test.cc +++ b/src/tint/reader/wgsl/parser_impl_param_list_test.cc @@ -117,8 +117,12 @@ TEST_F(ParserImplTest, ParamList_Attributes) { EXPECT_TRUE(e.value[1]->Is()); auto attrs_1 = e.value[1]->attributes; ASSERT_EQ(attrs_1.Length(), 1u); - EXPECT_TRUE(attrs_1[0]->Is()); - EXPECT_EQ(attrs_1[0]->As()->value, 1u); + + ASSERT_TRUE(attrs_1[0]->Is()); + auto* attr = attrs_1[0]->As(); + ASSERT_TRUE(attr->value->Is()); + auto* loc = attr->value->As(); + EXPECT_EQ(loc->value, 1u); EXPECT_EQ(e.value[1]->source.range.begin.line, 1u); EXPECT_EQ(e.value[1]->source.range.begin.column, 52u); diff --git a/src/tint/reader/wgsl/parser_impl_variable_attribute_list_test.cc b/src/tint/reader/wgsl/parser_impl_variable_attribute_list_test.cc index 2745e5fc7b..03cd016490 100644 --- a/src/tint/reader/wgsl/parser_impl_variable_attribute_list_test.cc +++ b/src/tint/reader/wgsl/parser_impl_variable_attribute_list_test.cc @@ -31,7 +31,13 @@ TEST_F(ParserImplTest, AttributeList_Parses) { ASSERT_NE(attr_1, nullptr); ASSERT_TRUE(attr_0->Is()); - EXPECT_EQ(attr_0->As()->value, 4u); + + auto* loc = attr_0->As(); + ASSERT_TRUE(loc->value->Is()); + + auto* exp = loc->value->As(); + EXPECT_EQ(exp->value, 4u); + ASSERT_TRUE(attr_1->Is()); EXPECT_EQ(attr_1->As()->builtin, ast::BuiltinValue::kPosition); } diff --git a/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc index e48cc85c48..299f49ba9d 100644 --- a/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc +++ b/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc @@ -29,7 +29,9 @@ TEST_F(ParserImplTest, Attribute_Location) { ASSERT_TRUE(var_attr->Is()); auto* loc = var_attr->As(); - EXPECT_EQ(loc->value, 4u); + ASSERT_TRUE(loc->value->Is()); + auto* exp = loc->value->As(); + EXPECT_EQ(exp->value, 4u); } TEST_F(ParserImplTest, Attribute_Location_TrailingComma) { @@ -44,7 +46,9 @@ TEST_F(ParserImplTest, Attribute_Location_TrailingComma) { ASSERT_TRUE(var_attr->Is()); auto* loc = var_attr->As(); - EXPECT_EQ(loc->value, 4u); + ASSERT_TRUE(loc->value->Is()); + auto* exp = loc->value->As(); + EXPECT_EQ(exp->value, 4u); } TEST_F(ParserImplTest, Attribute_Location_MissingLeftParen) { diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc index 6203728ae0..b182f6d85e 100644 --- a/src/tint/resolver/attribute_validation_test.cc +++ b/src/tint/resolver/attribute_validation_test.cc @@ -104,7 +104,7 @@ static utils::Vector createAttributes(const Source& so case AttributeKind::kInvariant: return {builder.Invariant(source)}; case AttributeKind::kLocation: - return {builder.Location(source, 1)}; + return {builder.Location(source, 1_a)}; case AttributeKind::kOffset: return {builder.create(source, 4u)}; case AttributeKind::kSize: @@ -286,7 +286,7 @@ TEST_P(VertexShaderParameterAttributeTest, IsValid) { auto& params = GetParam(); auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind); if (params.kind != AttributeKind::kLocation) { - attrs.Push(Location(Source{{34, 56}}, 2)); + attrs.Push(Location(Source{{34, 56}}, 2_a)); } auto* p = Param("a", ty.vec4(), attrs); Func("vertex_main", utils::Vector{p}, ty.vec4(), @@ -388,7 +388,7 @@ using FragmentShaderReturnTypeAttributeTest = TestWithParams; TEST_P(FragmentShaderReturnTypeAttributeTest, IsValid) { auto& params = GetParam(); auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind); - attrs.Push(Location(Source{{34, 56}}, 2)); + attrs.Push(Location(Source{{34, 56}}, 2_a)); Func("frag_main", utils::Empty, ty.vec4(), utils::Vector{Return(Construct(ty.vec4()))}, utils::Vector{ @@ -495,8 +495,8 @@ TEST_F(EntryPointParameterAttributeTest, DuplicateAttribute) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(Source{{12, 34}}, 2), - Location(Source{{56, 78}}, 3), + Location(Source{{12, 34}}, 2_a), + Location(Source{{56, 78}}, 3_a), }); EXPECT_FALSE(r()->Resolve()); @@ -531,8 +531,8 @@ TEST_F(EntryPointReturnTypeAttributeTest, DuplicateAttribute) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(Source{{12, 34}}, 2), - Location(Source{{56, 78}}, 3), + Location(Source{{12, 34}}, 2_a), + Location(Source{{56, 78}}, 3_a), }); EXPECT_FALSE(r()->Resolve()); @@ -1101,7 +1101,7 @@ TEST_F(InvariantAttributeTests, InvariantWithPosition) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); } @@ -1110,7 +1110,7 @@ TEST_F(InvariantAttributeTests, InvariantWithoutPosition) { auto* param = Param("p", ty.vec4(), utils::Vector{ Invariant(Source{{12, 34}}), - Location(0), + Location(0_a), }); Func("main", utils::Vector{param}, ty.vec4(), utils::Vector{ @@ -1120,7 +1120,7 @@ TEST_F(InvariantAttributeTests, InvariantWithoutPosition) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -1219,7 +1219,7 @@ TEST_P(InterpolateParameterTest, All) { utils::Vector{ Param("a", ty.f32(), utils::Vector{ - Location(0), + Location(0_a), Interpolate(Source{{12, 34}}, params.type, params.sampling), }), }, @@ -1245,7 +1245,7 @@ TEST_P(InterpolateParameterTest, IntegerScalar) { utils::Vector{ Param("a", ty.i32(), utils::Vector{ - Location(0), + Location(0_a), Interpolate(Source{{12, 34}}, params.type, params.sampling), }), }, @@ -1276,7 +1276,7 @@ TEST_P(InterpolateParameterTest, IntegerVector) { utils::Vector{ Param("a", ty.vec4(), utils::Vector{ - Location(0), + Location(0_a), Interpolate(Source{{12, 34}}, params.type, params.sampling), }), }, @@ -1319,7 +1319,8 @@ INSTANTIATE_TEST_SUITE_P( Params{ast::InterpolationType::kFlat, ast::InterpolationSampling::kSample, false})); TEST_F(InterpolateTest, FragmentInput_Integer_MissingFlatInterpolation) { - Func("main", utils::Vector{Param(Source{{12, 34}}, "a", ty.i32(), utils::Vector{Location(0)})}, + Func("main", + utils::Vector{Param(Source{{12, 34}}, "a", ty.i32(), utils::Vector{Location(0_a)})}, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kFragment), @@ -1336,7 +1337,7 @@ TEST_F(InterpolateTest, VertexOutput_Integer_MissingFlatInterpolation) { "S", utils::Vector{ Member("pos", ty.vec4(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), - Member(Source{{12, 34}}, "u", ty.u32(), utils::Vector{Location(0)}), + Member(Source{{12, 34}}, "u", ty.u32(), utils::Vector{Location(0_a)}), }); Func("main", utils::Empty, ty.Of(s), utils::Vector{ diff --git a/src/tint/resolver/builtins_validation_test.cc b/src/tint/resolver/builtins_validation_test.cc index 631890119c..a27052cdd3 100644 --- a/src/tint/resolver/builtins_validation_test.cc +++ b/src/tint/resolver/builtins_validation_test.cc @@ -163,7 +163,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -198,7 +198,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -256,7 +256,7 @@ TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -301,7 +301,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthNotF32_Struct_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -330,7 +330,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_Struct_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -372,7 +372,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleMaskIsNotU32_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(sample_mask) must be 'u32'"); @@ -400,7 +400,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Struct_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -427,7 +427,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(sample_index) must be 'u32'"); @@ -453,7 +453,7 @@ TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(position) must be 'vec4'"); @@ -745,7 +745,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); } @@ -768,7 +768,7 @@ TEST_F(ResolverBuiltinsValidationTest, FrontFacingParamIsNotBool_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -797,7 +797,7 @@ TEST_F(ResolverBuiltinsValidationTest, FrontFacingMemberIsNotBool_Fail) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_FALSE(r()->Resolve()); diff --git a/src/tint/resolver/entry_point_validation_test.cc b/src/tint/resolver/entry_point_validation_test.cc index d997912576..2179f68034 100644 --- a/src/tint/resolver/entry_point_validation_test.cc +++ b/src/tint/resolver/entry_point_validation_test.cc @@ -57,7 +57,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Location) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -110,7 +110,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Multiple) { Stage(ast::PipelineStage::kVertex), }, utils::Vector{ - Location(Source{{13, 43}}, 0), + Location(Source{{13, 43}}, 0_a), Builtin(Source{{14, 52}}, ast::BuiltinValue::kPosition), }); @@ -130,7 +130,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_Valid) { // } auto* output = Structure( "Output", utils::Vector{ - Member("a", ty.f32(), utils::Vector{Location(0)}), + Member("a", ty.f32(), utils::Vector{Location(0_a)}), Member("b", ty.f32(), utils::Vector{Builtin(ast::BuiltinValue::kFragDepth)}), }); Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), @@ -156,7 +156,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_MemberMultipleAttribu "Output", utils::Vector{ Member("a", ty.f32(), - utils::Vector{Location(Source{{13, 43}}, 0), + utils::Vector{Location(Source{{13, 43}}, 0_a), Builtin(Source{{14, 52}}, ast::BuiltinValue::kFragDepth)}), }); Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), @@ -182,11 +182,11 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_MemberMissingAttribut // fn main() -> Output { // return Output(); // } - auto* output = - Structure("Output", utils::Vector{ - Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), - Member(Source{{14, 52}}, "b", ty.f32(), {}), - }); + auto* output = Structure( + "Output", utils::Vector{ + Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}), + Member(Source{{14, 52}}, "b", ty.f32(), {}), + }); Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), utils::Vector{ Return(Construct(ty.Of(output))), @@ -235,7 +235,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) { // fn main(@location(0) param : f32) {} auto* param = Param("param", ty.f32(), utils::Vector{ - Location(0), + Location(0_a), }); Func(Source{{12, 34}}, "main", utils::Vector{ @@ -271,7 +271,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) { // fn main(@location(0) @builtin(sample_index) param : u32) {} auto* param = Param("param", ty.u32(), utils::Vector{ - Location(Source{{13, 43}}, 0), + Location(Source{{13, 43}}, 0_a), Builtin(Source{{14, 52}}, ast::BuiltinValue::kSampleIndex), }); Func(Source{{12, 34}}, "main", @@ -297,7 +297,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_Valid) { // fn main(param : Input) {} auto* input = Structure( "Input", utils::Vector{ - Member("a", ty.f32(), utils::Vector{Location(0)}), + Member("a", ty.f32(), utils::Vector{Location(0_a)}), Member("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kSampleIndex)}), }); auto* param = Param("param", ty.Of(input)); @@ -323,7 +323,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_MemberMultipleAttribut "Input", utils::Vector{ Member("a", ty.u32(), - utils::Vector{Location(Source{{13, 43}}, 0), + utils::Vector{Location(Source{{13, 43}}, 0_a), Builtin(Source{{14, 52}}, ast::BuiltinValue::kSampleIndex)}), }); auto* param = Param("param", ty.Of(input)); @@ -349,11 +349,11 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_MemberMissingAttribute // }; // @fragment // fn main(param : Input) {} - auto* input = - Structure("Input", utils::Vector{ - Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), - Member(Source{{14, 52}}, "b", ty.f32(), {}), - }); + auto* input = Structure( + "Input", utils::Vector{ + Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}), + Member(Source{{14, 52}}, "b", ty.f32(), {}), + }); auto* param = Param("param", ty.Of(input)); Func(Source{{12, 34}}, "main", utils::Vector{ @@ -628,7 +628,7 @@ TEST_P(TypeValidationTest, BareInputs) { auto* a = Param("a", params.create_ast_type(*this), utils::Vector{ - Location(0), + Location(0_a), Flat(), }); Func(Source{{12, 34}}, "main", @@ -657,10 +657,10 @@ TEST_P(TypeValidationTest, StructInputs) { Enable(ast::Extension::kF16); - auto* input = Structure( - "Input", utils::Vector{ - Member("a", params.create_ast_type(*this), utils::Vector{Location(0), Flat()}), - }); + auto* input = Structure("Input", utils::Vector{ + Member("a", params.create_ast_type(*this), + utils::Vector{Location(0_a), Flat()}), + }); auto* a = Param("a", ty.Of(input), {}); Func(Source{{12, 34}}, "main", utils::Vector{ @@ -695,7 +695,7 @@ TEST_P(TypeValidationTest, BareOutputs) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); if (params.is_valid) { @@ -719,7 +719,7 @@ TEST_P(TypeValidationTest, StructOutputs) { auto* output = Structure( "Output", utils::Vector{ - Member("a", params.create_ast_type(*this), utils::Vector{Location(0)}), + Member("a", params.create_ast_type(*this), utils::Vector{Location(0_a)}), }); Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), utils::Vector{ @@ -751,7 +751,7 @@ TEST_F(LocationAttributeTests, Pass) { auto* p = Param(Source{{12, 34}}, "a", ty.i32(), utils::Vector{ - Location(0), + Location(0_a), Flat(), }); Func("frag_main", @@ -772,7 +772,7 @@ TEST_F(LocationAttributeTests, BadType_Input_bool) { auto* p = Param(Source{{12, 34}}, "a", ty.bool_(), utils::Vector{ - Location(Source{{34, 56}}, 0), + Location(Source{{34, 56}}, 0_a), }); Func("frag_main", utils::Vector{ @@ -803,7 +803,7 @@ TEST_F(LocationAttributeTests, BadType_Output_Array) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(Source{{34, 56}}, 0), + Location(Source{{34, 56}}, 0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -825,7 +825,7 @@ TEST_F(LocationAttributeTests, BadType_Input_Struct) { }); auto* param = Param(Source{{12, 34}}, "param", ty.Of(input), utils::Vector{ - Location(Source{{13, 43}}, 0), + Location(Source{{13, 43}}, 0_a), }); Func(Source{{12, 34}}, "main", utils::Vector{ @@ -853,10 +853,10 @@ TEST_F(LocationAttributeTests, BadType_Input_Struct_NestedStruct) { // }; // @fragment // fn main(param : Input) {} - auto* inner = - Structure("Inner", utils::Vector{ - Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), - }); + auto* inner = Structure( + "Inner", utils::Vector{ + Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}), + }); auto* input = Structure("Input", utils::Vector{ Member(Source{{14, 52}}, "a", ty.Of(inner)), }); @@ -884,7 +884,7 @@ TEST_F(LocationAttributeTests, BadType_Input_Struct_RuntimeArray) { // fn main(param : Input) {} auto* input = Structure( "Input", utils::Vector{ - Member(Source{{13, 43}}, "a", ty.array(), utils::Vector{Location(0)}), + Member(Source{{13, 43}}, "a", ty.array(), utils::Vector{Location(0_a)}), }); auto* param = Param("param", ty.Of(input)); Func(Source{{12, 34}}, "main", @@ -911,7 +911,7 @@ TEST_F(LocationAttributeTests, BadMemberType_Input) { auto* m = Member(Source{{34, 56}}, "m", ty.array(), utils::Vector{ - Location(Source{{12, 34}}, 0u), + Location(Source{{12, 34}}, 0_u), }); auto* s = Structure("S", utils::Vector{m}); auto* p = Param("a", ty.Of(s)); @@ -939,7 +939,7 @@ TEST_F(LocationAttributeTests, BadMemberType_Output) { // fn frag_main() -> S {} auto* m = Member(Source{{34, 56}}, "m", ty.atomic(), utils::Vector{ - Location(Source{{12, 34}}, 0u), + Location(Source{{12, 34}}, 0_u), }); auto* s = Structure("S", utils::Vector{m}); @@ -965,7 +965,7 @@ TEST_F(LocationAttributeTests, BadMemberType_Unused) { auto* m = Member(Source{{34, 56}}, "m", ty.mat3x2(), utils::Vector{ - Location(Source{{12, 34}}, 0u), + Location(Source{{12, 34}}, 0_u), }); Structure("S", utils::Vector{m}); @@ -988,7 +988,7 @@ TEST_F(LocationAttributeTests, ReturnType_Struct_Valid) { // } auto* output = Structure( "Output", utils::Vector{ - Member("a", ty.f32(), utils::Vector{Location(0)}), + Member("a", ty.f32(), utils::Vector{Location(0_a)}), Member("b", ty.f32(), utils::Vector{Builtin(ast::BuiltinValue::kFragDepth)}), }); Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), @@ -1021,7 +1021,7 @@ TEST_F(LocationAttributeTests, ReturnType_Struct) { Stage(ast::PipelineStage::kVertex), }, utils::Vector{ - Location(Source{{13, 43}}, 0), + Location(Source{{13, 43}}, 0_a), }); EXPECT_FALSE(r()->Resolve()); @@ -1041,10 +1041,10 @@ TEST_F(LocationAttributeTests, ReturnType_Struct_NestedStruct) { // }; // @fragment // fn main() -> Output { return Output(); } - auto* inner = - Structure("Inner", utils::Vector{ - Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), - }); + auto* inner = Structure( + "Inner", utils::Vector{ + Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}), + }); auto* output = Structure("Output", utils::Vector{ Member(Source{{14, 52}}, "a", ty.Of(inner)), }); @@ -1072,7 +1072,7 @@ TEST_F(LocationAttributeTests, ReturnType_Struct_RuntimeArray) { // } auto* output = Structure("Output", utils::Vector{ Member(Source{{13, 43}}, "a", ty.array(), - utils::Vector{Location(Source{{12, 34}}, 0)}), + utils::Vector{Location(Source{{12, 34}}, 0_a)}), }); Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), utils::Vector{ @@ -1100,7 +1100,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Input) { create(Source{{12, 34}}, Expr(1_i)), }, utils::Vector{ - Location(Source{{12, 34}}, 1), + Location(Source{{12, 34}}, 1_a), }); EXPECT_FALSE(r()->Resolve()); @@ -1110,7 +1110,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Input) { TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) { auto* input = Param("input", ty.i32(), utils::Vector{ - Location(Source{{12, 34}}, 0u), + Location(Source{{12, 34}}, 0_u), }); Func("main", utils::Vector{input}, ty.void_(), utils::Empty, utils::Vector{ @@ -1125,7 +1125,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) { TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) { auto* m = Member("m", ty.i32(), utils::Vector{ - Location(Source{{12, 34}}, 0u), + Location(Source{{12, 34}}, 0_u), }); auto* s = Structure("S", utils::Vector{m}); Func(Source{{56, 78}}, "main", utils::Empty, ty.Of(s), @@ -1146,7 +1146,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) { TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Input) { auto* m = Member("m", ty.i32(), utils::Vector{ - Location(Source{{12, 34}}, 0u), + Location(Source{{12, 34}}, 0_u), }); auto* s = Structure("S", utils::Vector{m}); auto* input = Param("input", ty.Of(s)); @@ -1168,11 +1168,11 @@ TEST_F(LocationAttributeTests, Duplicate_input) { // @location(1) param_b : f32) {} auto* param_a = Param("param_a", ty.f32(), utils::Vector{ - Location(1), + Location(1_a), }); auto* param_b = Param("param_b", ty.f32(), utils::Vector{ - Location(Source{{12, 34}}, 1), + Location(Source{{12, 34}}, 1_a), }); Func(Source{{12, 34}}, "main", utils::Vector{ @@ -1198,12 +1198,12 @@ TEST_F(LocationAttributeTests, Duplicate_struct) { // @fragment // fn main(param_a : InputA, param_b : InputB) {} auto* input_a = Structure("InputA", utils::Vector{ - Member("a", ty.f32(), utils::Vector{Location(1)}), + Member("a", ty.f32(), utils::Vector{Location(1_a)}), }); - auto* input_b = - Structure("InputB", utils::Vector{ - Member("a", ty.f32(), utils::Vector{Location(Source{{34, 56}}, 1)}), - }); + auto* input_b = Structure( + "InputB", utils::Vector{ + Member("a", ty.f32(), utils::Vector{Location(Source{{34, 56}}, 1_a)}), + }); auto* param_a = Param("param_a", ty.Of(input_a)); auto* param_b = Param("param_b", ty.Of(input_b)); Func(Source{{12, 34}}, "main", diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 8775168b87..82a047caa9 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -640,7 +640,17 @@ sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) { std::optional location; if (auto* attr = ast::GetAttribute(var->attributes)) { - location = attr->value; + auto* materialize = Materialize(Expression(attr->value)); + if (!materialize) { + return nullptr; + } + auto* c = materialize->ConstantValue(); + if (!c) { + // TODO(crbug.com/tint/1633): Add error message about invalid materialization + // when location can be an expression. + return nullptr; + } + location = c->As(); } sem = builder_->create( @@ -725,7 +735,17 @@ sem::Parameter* Resolver::Parameter(const ast::Parameter* param, uint32_t index) std::optional location; if (auto* l = ast::GetAttribute(param->attributes)) { - location = l->value; + auto* materialize = Materialize(Expression(l->value)); + if (!materialize) { + return nullptr; + } + auto* c = materialize->ConstantValue(); + if (!c) { + // TODO(crbug.com/tint/1633): Add error message about invalid materialization when + // location can be an expression. + return nullptr; + } + location = c->As(); } auto* sem = builder_->create( @@ -924,7 +944,17 @@ sem::Function* Resolver::Function(const ast::Function* decl) { Mark(attr); if (auto* a = attr->As()) { - return_location = a->value; + auto* materialize = Materialize(Expression(a->value)); + if (!materialize) { + return nullptr; + } + auto* c = materialize->ConstantValue(); + if (!c) { + // TODO(crbug.com/tint/1633): Add error message about invalid materialization when + // location can be an expression. + return nullptr; + } + return_location = c->As(); } } if (!validator_.NoDuplicateAttributes(decl->attributes)) { @@ -2808,7 +2838,17 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { size = s->size; has_size_attr = true; } else if (auto* l = attr->As()) { - location = l->value; + auto* materialize = Materialize(Expression(l->value)); + if (!materialize) { + return nullptr; + } + auto* c = materialize->ConstantValue(); + if (!c) { + // TODO(crbug.com/tint/1633): Add error message about invalid materialization + // when location can be an expression. + return nullptr; + } + location = c->As(); } } diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index 3708774d4e..1a3c623366 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -774,9 +774,9 @@ TEST_F(ResolverTest, Function_Parameters) { } TEST_F(ResolverTest, Function_Parameters_Locations) { - auto* param_a = Param("a", ty.f32(), utils::Vector{Location(3)}); + auto* param_a = Param("a", ty.f32(), utils::Vector{Location(3_a)}); auto* param_b = Param("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)}); - auto* param_c = Param("c", ty.u32(), utils::Vector{Location(1)}); + auto* param_c = Param("c", ty.u32(), utils::Vector{Location(1_a)}); GlobalVar("my_vec", ty.vec4(), ast::StorageClass::kPrivate); auto* func = Func("my_func", @@ -809,7 +809,7 @@ TEST_F(ResolverTest, Function_Parameters_Locations) { TEST_F(ResolverTest, Function_GlobalVariable_Location) { auto* var = GlobalVar( "my_vec", ty.vec4(), ast::StorageClass::kIn, - utils::Vector{Location(3), Disable(ast::DisabledValidation::kIgnoreStorageClass)}); + utils::Vector{Location(3_a), Disable(ast::DisabledValidation::kIgnoreStorageClass)}); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -856,7 +856,7 @@ TEST_F(ResolverTest, Function_ReturnType_Location) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(2), + Location(2_a), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/tint/resolver/struct_pipeline_stage_use_test.cc b/src/tint/resolver/struct_pipeline_stage_use_test.cc index c8e77ea904..107b241cc0 100644 --- a/src/tint/resolver/struct_pipeline_stage_use_test.cc +++ b/src/tint/resolver/struct_pipeline_stage_use_test.cc @@ -29,7 +29,7 @@ namespace { using ResolverPipelineStageUseTest = ResolverTest; TEST_F(ResolverPipelineStageUseTest, UnusedStruct) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -39,7 +39,7 @@ TEST_F(ResolverPipelineStageUseTest, UnusedStruct) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); Func("foo", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, utils::Empty); @@ -51,7 +51,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointReturnType) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); Func("foo", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))}, utils::Empty); @@ -64,7 +64,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointReturnType) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderParam) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.vec4(), utils::Vector{Return(Construct(ty.vec4()))}, @@ -96,7 +96,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderReturnType) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -110,7 +110,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderReturnType) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); Func("main", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))}, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -160,7 +160,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); auto* s_alias = Alias("S_alias", ty.Of(s)); Func("main", utils::Vector{Param("param", ty.Of(s_alias))}, ty.void_(), utils::Empty, @@ -175,7 +175,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamLocationSet) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3_a)})}); Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -189,7 +189,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamLocationSet) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})}); auto* s_alias = Alias("S_alias", ty.Of(s)); Func("main", utils::Empty, ty.Of(s_alias), @@ -205,7 +205,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) { } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeLocationSet) { - auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3)})}); + auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3_a)})}); Func("main", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))}, utils::Vector{Stage(ast::PipelineStage::kFragment)}); diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 022f0db50c..1812789432 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -121,12 +121,13 @@ bool IsValidStorageTextureTexelFormat(ast::TexelFormat format) { } // Helper to stringify a pipeline IO attribute. -std::string attr_to_str(const ast::Attribute* attr) { +std::string attr_to_str(const ast::Attribute* attr, + std::optional location = std::nullopt) { std::stringstream str; if (auto* builtin = attr->As()) { str << "builtin(" << builtin->builtin << ")"; - } else if (auto* location = attr->As()) { - str << "location(" << location->value << ")"; + } else if (attr->Is()) { + str << "location(" << location.value() << ")"; } return str.str(); } @@ -1123,7 +1124,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) auto validate_entry_point_attributes_inner = [&](utils::VectorRef attrs, const sem::Type* ty, Source source, ParamOrRetType param_or_ret, - bool is_struct_member) { + bool is_struct_member, + std::optional location) { // Temporally forbid using f16 types in entry point IO. // TODO(tint:1473, tint:1502): Remove this error after f16 is supported in entry point // IO. @@ -1143,7 +1145,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) if (auto* builtin = attr->As()) { if (pipeline_io_attribute) { AddError("multiple entry point IO attributes", attr->source); - AddNote("previously consumed " + attr_to_str(pipeline_io_attribute), + AddNote("previously consumed " + attr_to_str(pipeline_io_attribute, location), pipeline_io_attribute->source); return false; } @@ -1162,7 +1164,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) return false; } builtins.emplace(builtin->builtin); - } else if (auto* location = attr->As()) { + } else if (auto* loc_attr = attr->As()) { if (pipeline_io_attribute) { AddError("multiple entry point IO attributes", attr->source); AddNote("previously consumed " + attr_to_str(pipeline_io_attribute), @@ -1173,7 +1175,13 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) bool is_input = param_or_ret == ParamOrRetType::kParameter; - if (!LocationAttribute(location, ty, locations, stage, source, is_input)) { + if (!location.has_value()) { + TINT_ICE(Resolver, diagnostics_) << "Location has no value"; + return false; + } + + if (!LocationAttribute(loc_attr, location.value(), ty, locations, stage, source, + is_input)) { return false; } } else if (auto* interpolate = attr->As()) { @@ -1266,9 +1274,10 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) // Outer lambda for validating the entry point attributes for a type. auto validate_entry_point_attributes = [&](utils::VectorRef attrs, const sem::Type* ty, Source source, - ParamOrRetType param_or_ret) { + ParamOrRetType param_or_ret, + std::optional location) { if (!validate_entry_point_attributes_inner(attrs, ty, source, param_or_ret, - /*is_struct_member*/ false)) { + /*is_struct_member*/ false, location)) { return false; } @@ -1277,7 +1286,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) if (!validate_entry_point_attributes_inner( member->Declaration()->attributes, member->Type(), member->Declaration()->source, param_or_ret, - /*is_struct_member*/ true)) { + /*is_struct_member*/ true, member->Location())) { AddNote("while analysing entry point '" + symbols_.NameFor(decl->symbol) + "'", decl->source); return false; @@ -1291,7 +1300,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) for (auto* param : func->Parameters()) { auto* param_decl = param->Declaration(); if (!validate_entry_point_attributes(param_decl->attributes, param->Type(), - param_decl->source, ParamOrRetType::kParameter)) { + param_decl->source, ParamOrRetType::kParameter, + param->Location())) { return false; } } @@ -1304,7 +1314,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) if (!func->ReturnType()->Is()) { if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(), - decl->source, ParamOrRetType::kReturnType)) { + decl->source, ParamOrRetType::kReturnType, + func->ReturnLocation())) { return false; } } @@ -2177,8 +2188,9 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons invariant_attribute = invariant; } else if (auto* location = attr->As()) { has_location = true; - if (!LocationAttribute(location, member->Type(), locations, stage, - member->Declaration()->source)) { + TINT_ASSERT(Resolver, member->Location().has_value()); + if (!LocationAttribute(location, member->Location().value(), member->Type(), + locations, stage, member->Declaration()->source)) { return false; } } else if (auto* builtin = attr->As()) { @@ -2220,7 +2232,8 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons return true; } -bool Validator::LocationAttribute(const ast::LocationAttribute* location, +bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr, + uint32_t location, const sem::Type* type, std::unordered_set& locations, ast::PipelineStage stage, @@ -2228,7 +2241,7 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* location, const bool is_input) const { std::string inputs_or_output = is_input ? "inputs" : "output"; if (stage == ast::PipelineStage::kCompute) { - AddError("attribute is not valid for compute shader " + inputs_or_output, location->source); + AddError("attribute is not valid for compute shader " + inputs_or_output, loc_attr->source); return false; } @@ -2239,15 +2252,16 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* location, AddNote( "'location' attribute must only be applied to declarations of " "numeric scalar or numeric vector type", - location->source); + loc_attr->source); return false; } - if (locations.count(location->value)) { - AddError(attr_to_str(location) + " attribute appears multiple times", location->source); + if (locations.count(location)) { + AddError(attr_to_str(loc_attr, location) + " attribute appears multiple times", + loc_attr->source); return false; } - locations.emplace(location->value); + locations.emplace(location); return true; } diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index 8bec86fa8d..a00f6ab94f 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -273,14 +273,16 @@ class Validator { bool LocalVariable(const sem::Variable* v) const; /// Validates a location attribute - /// @param location the location attribute to validate + /// @param loc_attr the location attribute to validate + /// @param location the location value /// @param type the variable type /// @param locations the set of locations in the module /// @param stage the current pipeline stage /// @param source the source of the attribute /// @param is_input true if this is an input variable /// @returns true on success, false otherwise. - bool LocationAttribute(const ast::LocationAttribute* location, + bool LocationAttribute(const ast::LocationAttribute* loc_attr, + uint32_t location, const sem::Type* type, std::unordered_set& locations, ast::PipelineStage stage, diff --git a/src/tint/sem/sem_struct_test.cc b/src/tint/sem/sem_struct_test.cc index 453b8f7ddc..97465899e7 100644 --- a/src/tint/sem/sem_struct_test.cc +++ b/src/tint/sem/sem_struct_test.cc @@ -19,6 +19,7 @@ namespace tint::sem { namespace { +using namespace tint::number_suffixes; // NOLINT using StructTest = TestHelper; TEST_F(StructTest, Creation) { @@ -107,7 +108,7 @@ TEST_F(StructTest, Layout) { TEST_F(StructTest, Location) { auto* st = Structure("st", utils::Vector{ - Member("a", ty.i32(), utils::Vector{Location(1u)}), + Member("a", ty.i32(), utils::Vector{Location(1_u)}), Member("b", ty.u32()), }); diff --git a/src/tint/transform/canonicalize_entry_point_io.cc b/src/tint/transform/canonicalize_entry_point_io.cc index 31feb86b6f..b08d44d1be 100644 --- a/src/tint/transform/canonicalize_entry_point_io.cc +++ b/src/tint/transform/canonicalize_entry_point_io.cc @@ -37,21 +37,32 @@ CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default; namespace { -// Comparison function used to reorder struct members such that all members with -// location attributes appear first (ordered by location slot), followed by -// those with builtin attributes. -bool StructMemberComparator(const ast::StructMember* a, const ast::StructMember* b) { - auto* a_loc = ast::GetAttribute(a->attributes); - auto* b_loc = ast::GetAttribute(b->attributes); - auto* a_blt = ast::GetAttribute(a->attributes); - auto* b_blt = ast::GetAttribute(b->attributes); +/// Info for a struct member +struct MemberInfo { + /// The struct member item + const ast::StructMember* member; + /// The struct member location if provided + std::optional location; +}; + +/// Comparison function used to reorder struct members such that all members with +/// location attributes appear first (ordered by location slot), followed by +/// those with builtin attributes. +/// @param a a struct member +/// @param b another struct member +/// @returns true if a comes before b +bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) { + auto* a_loc = ast::GetAttribute(a.member->attributes); + auto* b_loc = ast::GetAttribute(b.member->attributes); + auto* a_blt = ast::GetAttribute(a.member->attributes); + auto* b_blt = ast::GetAttribute(b.member->attributes); if (a_loc) { if (!b_loc) { // `a` has location attribute and `b` does not: `a` goes first. return true; } // Both have location attributes: smallest goes first. - return a_loc->value < b_loc->value; + return a.location < b.location; } else { if (b_loc) { // `b` has location attribute and `a` does not: `b` goes first. @@ -88,6 +99,8 @@ struct CanonicalizeEntryPointIO::State { utils::Vector attributes; /// The value itself. const ast::Expression* value; + /// The output location. + std::optional location; }; /// The clone context. @@ -101,14 +114,15 @@ struct CanonicalizeEntryPointIO::State { /// The new entry point wrapper function's parameters. utils::Vector wrapper_ep_parameters; + /// The members of the wrapper function's struct parameter. - utils::Vector wrapper_struct_param_members; + utils::Vector wrapper_struct_param_members; /// The name of the wrapper function's struct parameter. Symbol wrapper_struct_param_name; /// The parameters that will be passed to the original function. utils::Vector inner_call_parameters; /// The members of the wrapper function's struct return type. - utils::Vector wrapper_struct_output_members; + utils::Vector wrapper_struct_output_members; /// The wrapper function output values. utils::Vector wrapper_output_values; /// The body of the wrapper function. @@ -153,10 +167,12 @@ struct CanonicalizeEntryPointIO::State { /// Add a shader input to the entry point. /// @param name the name of the shader input /// @param type the type of the shader input + /// @param location the location if provided /// @param attributes the attributes to apply to the shader input /// @returns an expression which evaluates to the value of the shader input const ast::Expression* AddInput(std::string name, const sem::Type* type, + std::optional location, utils::Vector attributes) { auto* ast_type = CreateASTTypeFor(ctx, type); if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) { @@ -214,7 +230,7 @@ struct CanonicalizeEntryPointIO::State { Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name) : ctx.dst->Symbols().New(name); wrapper_struct_param_members.Push( - ctx.dst->Member(symbol, ast_type, std::move(attributes))); + {ctx.dst->Member(symbol, ast_type, std::move(attributes)), location}); return ctx.dst->MemberAccessor(InputStructSymbol(), symbol); } } @@ -222,10 +238,12 @@ struct CanonicalizeEntryPointIO::State { /// Add a shader output to the entry point. /// @param name the name of the shader output /// @param type the type of the shader output + /// @param location the location if provided /// @param attributes the attributes to apply to the shader output /// @param value the value of the shader output void AddOutput(std::string name, const sem::Type* type, + std::optional location, utils::Vector attributes, const ast::Expression* value) { // Vulkan requires that integer user-defined vertex outputs are always decorated with @@ -256,6 +274,7 @@ struct CanonicalizeEntryPointIO::State { output.type = CreateASTTypeFor(ctx, type); output.attributes = std::move(attributes); output.value = value; + output.location = location; wrapper_output_values.Push(output); } @@ -280,7 +299,7 @@ struct CanonicalizeEntryPointIO::State { } auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol); - auto* input_expr = AddInput(name, param->Type(), std::move(attributes)); + auto* input_expr = AddInput(name, param->Type(), param->Location(), std::move(attributes)); inner_call_parameters.Push(input_expr); } @@ -308,7 +327,8 @@ struct CanonicalizeEntryPointIO::State { auto name = ctx.src->Symbols().NameFor(member_ast->symbol); auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate); - auto* input_expr = AddInput(name, member->Type(), std::move(attributes)); + auto* input_expr = + AddInput(name, member->Type(), member->Location(), std::move(attributes)); inner_struct_values.Push(input_expr); } @@ -337,7 +357,7 @@ struct CanonicalizeEntryPointIO::State { auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate); // Extract the original structure member. - AddOutput(name, member->Type(), std::move(attributes), + AddOutput(name, member->Type(), member->Location(), std::move(attributes), ctx.dst->MemberAccessor(original_result, name)); } } else if (!inner_ret_type->Is()) { @@ -345,8 +365,8 @@ struct CanonicalizeEntryPointIO::State { CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate); // Propagate the non-struct return value as is. - AddOutput("value", func_sem->ReturnType(), std::move(attributes), - ctx.dst->Expr(original_result)); + AddOutput("value", func_sem->ReturnType(), func_sem->ReturnLocation(), + std::move(attributes), ctx.dst->Expr(original_result)); } } @@ -365,7 +385,7 @@ struct CanonicalizeEntryPointIO::State { // No existing sample mask builtin was found, so create a new output value // using the fixed sample mask. - AddOutput("fixed_sample_mask", ctx.dst->create(), + AddOutput("fixed_sample_mask", ctx.dst->create(), std::nullopt, {ctx.dst->Builtin(ast::BuiltinValue::kSampleMask)}, ctx.dst->Expr(u32(cfg.fixed_sample_mask))); } @@ -373,7 +393,7 @@ struct CanonicalizeEntryPointIO::State { /// Add a point size builtin to the wrapper function output. void AddVertexPointSize() { // Create a new output value and assign it a literal 1.0 value. - AddOutput("vertex_point_size", ctx.dst->create(), + AddOutput("vertex_point_size", ctx.dst->create(), std::nullopt, {ctx.dst->Builtin(ast::BuiltinValue::kPointSize)}, ctx.dst->Expr(1_f)); } @@ -392,10 +412,14 @@ struct CanonicalizeEntryPointIO::State { std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(), StructMemberComparator); + utils::Vector members; + for (auto& mem : wrapper_struct_param_members) { + members.Push(mem.member); + } + // Create the new struct type. auto struct_name = ctx.dst->Sym(); - auto* in_struct = - ctx.dst->create(struct_name, wrapper_struct_param_members, utils::Empty); + auto* in_struct = ctx.dst->create(struct_name, members, utils::Empty); ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct); // Create a new function parameter using this struct type. @@ -423,7 +447,8 @@ struct CanonicalizeEntryPointIO::State { member_names.insert(ctx.dst->Symbols().NameFor(name)); wrapper_struct_output_members.Push( - ctx.dst->Member(name, outval.type, std::move(outval.attributes))); + {ctx.dst->Member(name, outval.type, std::move(outval.attributes)), + outval.location}); assignments.Push( ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value)); } @@ -432,9 +457,13 @@ struct CanonicalizeEntryPointIO::State { std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(), StructMemberComparator); + utils::Vector members; + for (auto& mem : wrapper_struct_output_members) { + members.Push(mem.member); + } + // Create the new struct type. - auto* out_struct = ctx.dst->create( - ctx.dst->Sym(), wrapper_struct_output_members, utils::Empty); + auto* out_struct = ctx.dst->create(ctx.dst->Sym(), members, utils::Empty); ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct); // Create the output struct object, assign its members, and return it. diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc index 40d8d30900..3c0dce9029 100644 --- a/src/tint/transform/vertex_pulling.cc +++ b/src/tint/transform/vertex_pulling.cc @@ -692,7 +692,7 @@ struct State { /// @param func the entry point function /// @param param the parameter to process void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) { - if (auto* location = ast::GetAttribute(param->attributes)) { + if (ast::HasAttribute(param->attributes)) { // Create a function-scope variable to replace the parameter. auto func_var_sym = ctx.Clone(param->symbol); auto* func_var_type = ctx.Clone(param->type); @@ -701,8 +701,15 @@ struct State { // Capture mapping from location to the new variable. LocationInfo info; info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); }; - info.type = ctx.src->Sem().Get(param)->Type(); - location_info[location->value] = info; + + auto* sem = ctx.src->Sem().Get(param); + info.type = sem->Type(); + + if (!sem->Location().has_value()) { + TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value"; + return; + } + location_info[sem->Location().value()] = info; } else if (auto* builtin = ast::GetAttribute(param->attributes)) { // Check for existing vertex_index and instance_index builtins. if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { @@ -742,12 +749,16 @@ struct State { return ctx.dst->MemberAccessor(param_sym, member_sym); }; - if (auto* location = ast::GetAttribute(member->attributes)) { + if (ast::HasAttribute(member->attributes)) { // Capture mapping from location to struct member. LocationInfo info; info.expr = member_expr; - info.type = ctx.src->Sem().Get(member)->Type(); - location_info[location->value] = info; + + auto* sem = ctx.src->Sem().Get(member); + info.type = sem->Type(); + + TINT_ASSERT(Transform, sem->Location().has_value()); + location_info[sem->Location().value()] = info; has_locations = true; } else if (auto* builtin = ast::GetAttribute(member->attributes)) { diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index 8b2d89835d..132f542239 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -1856,7 +1856,7 @@ bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { return Switch( global, // [&](const ast::Var* var) { - auto* sem = builder_.Sem().Get(global); + auto* sem = builder_.Sem().Get(global); switch (sem->StorageClass()) { case ast::StorageClass::kUniform: return EmitUniformVariable(var, sem); @@ -2005,7 +2005,7 @@ bool GeneratorImpl::EmitWorkgroupVariable(const sem::Variable* var) { return true; } -bool GeneratorImpl::EmitIOVariable(const sem::Variable* var) { +bool GeneratorImpl::EmitIOVariable(const sem::GlobalVariable* var) { auto* decl = var->Declaration(); if (auto* b = ast::GetAttribute(decl->attributes)) { @@ -2018,7 +2018,7 @@ bool GeneratorImpl::EmitIOVariable(const sem::Variable* var) { } auto out = line(); - EmitAttributes(out, decl->attributes); + EmitAttributes(out, var, decl->attributes); EmitInterpolationQualifiers(out, decl->attributes); auto name = builder_.Symbols().NameFor(decl->symbol); @@ -2065,15 +2065,16 @@ void GeneratorImpl::EmitInterpolationQualifiers( } bool GeneratorImpl::EmitAttributes(std::ostream& out, + const sem::GlobalVariable* var, utils::VectorRef attributes) { if (attributes.IsEmpty()) { return true; } bool first = true; for (auto* attr : attributes) { - if (auto* location = attr->As()) { + if (attr->As()) { out << (first ? "layout(" : ", "); - out << "location = " << std::to_string(location->value); + out << "location = " << std::to_string(var->Location().value()); first = false; } } diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h index 502df8b245..e70bdc244f 100644 --- a/src/tint/writer/glsl/generator_impl.h +++ b/src/tint/writer/glsl/generator_impl.h @@ -324,7 +324,7 @@ class GeneratorImpl : public TextGenerator { /// Handles emitting a global variable with the input or output storage class /// @param var the global variable /// @returns true on success - bool EmitIOVariable(const sem::Variable* var); + bool EmitIOVariable(const sem::GlobalVariable* var); /// Handles emitting interpolation qualifiers /// @param out the output of the expression stream @@ -333,9 +333,12 @@ class GeneratorImpl : public TextGenerator { utils::VectorRef attrs); /// Handles emitting attributes /// @param out the output of the expression stream + /// @param var the global variable semantics /// @param attrs the attributes /// @returns true if the attributes were emitted - bool EmitAttributes(std::ostream& out, utils::VectorRef attrs); + bool EmitAttributes(std::ostream& out, + const sem::GlobalVariable* var, + utils::VectorRef attrs); /// Handles emitting the entry point function /// @param func the entry point /// @returns true if the entry point function was emitted diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc index c388005727..fd74e2d815 100644 --- a/src/tint/writer/glsl/generator_impl_function_test.cc +++ b/src/tint/writer/glsl/generator_impl_function_test.cc @@ -128,7 +128,7 @@ TEST_F(GlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars) // } Func("frag_main", utils::Vector{ - Param("foo", ty.f32(), utils::Vector{Location(0)}), + Param("foo", ty.f32(), utils::Vector{Location(0_a)}), }, ty.f32(), utils::Vector{ @@ -138,7 +138,7 @@ TEST_F(GlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars) Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(1), + Location(1_a), }); GeneratorImpl& gen = SanitizeAndBuild(); @@ -218,8 +218,8 @@ TEST_F(GlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_SharedStruct_Di "Interface", utils::Vector{ Member("pos", ty.vec4(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), - Member("col1", ty.f32(), utils::Vector{Location(1)}), - Member("col2", ty.f32(), utils::Vector{Location(2)}), + Member("col1", ty.f32(), utils::Vector{Location(1_a)}), + Member("col2", ty.f32(), utils::Vector{Location(2_a)}), }); Func("vert_main", utils::Empty, ty.Of(interface_struct), diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 3625db668c..b05db3e5f3 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -3947,23 +3947,24 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) { std::string pre, post; if (auto* decl = mem->Declaration()) { for (auto* attr : decl->attributes) { - if (auto* location = attr->As()) { + if (attr->Is()) { auto& pipeline_stage_uses = str->PipelineStageUses(); if (pipeline_stage_uses.size() != 1) { TINT_ICE(Writer, diagnostics_) << "invalid entry point IO struct uses"; } + auto loc = mem->Location().value(); if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) { - post += " : TEXCOORD" + std::to_string(location->value); + post += " : TEXCOORD" + std::to_string(loc); } else if (pipeline_stage_uses.count( sem::PipelineStageUsage::kVertexOutput)) { - post += " : TEXCOORD" + std::to_string(location->value); + post += " : TEXCOORD" + std::to_string(loc); } else if (pipeline_stage_uses.count( sem::PipelineStageUsage::kFragmentInput)) { - post += " : TEXCOORD" + std::to_string(location->value); + post += " : TEXCOORD" + std::to_string(loc); } else if (pipeline_stage_uses.count( sem::PipelineStageUsage::kFragmentOutput)) { - post += " : SV_Target" + std::to_string(location->value); + post += " : SV_Target" + std::to_string(loc); } else { TINT_ICE(Writer, diagnostics_) << "invalid use of location attribute"; } diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc index 14e8a7072d..bcd1891d3b 100644 --- a/src/tint/writer/hlsl/generator_impl_function_test.cc +++ b/src/tint/writer/hlsl/generator_impl_function_test.cc @@ -117,7 +117,7 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars) // fn frag_main(@location(0) foo : f32) -> @location(1) f32 { // return foo; // } - auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0)}); + auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0_a)}); Func("frag_main", utils::Vector{foo_in}, ty.f32(), utils::Vector{ Return("foo"), @@ -126,7 +126,7 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars) Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(1), + Location(1_a), }); GeneratorImpl& gen = SanitizeAndBuild(); @@ -210,8 +210,8 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_SharedStruct_Di "Interface", utils::Vector{ Member("pos", ty.vec4(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), - Member("col1", ty.f32(), utils::Vector{Location(1)}), - Member("col2", ty.f32(), utils::Vector{Location(2)}), + Member("col1", ty.f32(), utils::Vector{Location(1_a)}), + Member("col2", ty.f32(), utils::Vector{Location(2_a)}), }); Func("vert_main", utils::Empty, ty.Of(interface_struct), diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 76176786ff..70bdb61927 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -2785,24 +2785,25 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) { out << " [[" << name << "]]"; return true; }, - [&](const ast::LocationAttribute* loc) { + [&](const ast::LocationAttribute*) { auto& pipeline_stage_uses = str->PipelineStageUses(); if (pipeline_stage_uses.size() != 1) { TINT_ICE(Writer, diagnostics_) << "invalid entry point IO struct uses"; return false; } + uint32_t loc = mem->Location().value(); if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) { - out << " [[attribute(" + std::to_string(loc->value) + ")]]"; + out << " [[attribute(" + std::to_string(loc) + ")]]"; } else if (pipeline_stage_uses.count( sem::PipelineStageUsage::kVertexOutput)) { - out << " [[user(locn" + std::to_string(loc->value) + ")]]"; + out << " [[user(locn" + std::to_string(loc) + ")]]"; } else if (pipeline_stage_uses.count( sem::PipelineStageUsage::kFragmentInput)) { - out << " [[user(locn" + std::to_string(loc->value) + ")]]"; + out << " [[user(locn" + std::to_string(loc) + ")]]"; } else if (pipeline_stage_uses.count( sem::PipelineStageUsage::kFragmentOutput)) { - out << " [[color(" + std::to_string(loc->value) + ")]]"; + out << " [[color(" + std::to_string(loc) + ")]]"; } else { TINT_ICE(Writer, diagnostics_) << "invalid use of location decoration"; return false; diff --git a/src/tint/writer/msl/generator_impl_function_test.cc b/src/tint/writer/msl/generator_impl_function_test.cc index fd612b3615..addd25539c 100644 --- a/src/tint/writer/msl/generator_impl_function_test.cc +++ b/src/tint/writer/msl/generator_impl_function_test.cc @@ -91,7 +91,7 @@ TEST_F(MslGeneratorImplTest, Emit_Attribute_EntryPoint_WithInOutVars) { // fn frag_main(@location(0) foo : f32) -> @location(1) f32 { // return foo; // } - auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0)}); + auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0_a)}); Func("frag_main", utils::Vector{foo_in}, ty.f32(), utils::Vector{ Return("foo"), @@ -100,7 +100,7 @@ TEST_F(MslGeneratorImplTest, Emit_Attribute_EntryPoint_WithInOutVars) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(1), + Location(1_a), }); GeneratorImpl& gen = SanitizeAndBuild(); @@ -188,8 +188,8 @@ TEST_F(MslGeneratorImplTest, Emit_Attribute_EntryPoint_SharedStruct_DifferentSta auto* interface_struct = Structure( "Interface", utils::Vector{ - Member("col1", ty.f32(), utils::Vector{Location(1)}), - Member("col2", ty.f32(), utils::Vector{Location(2)}), + Member("col1", ty.f32(), utils::Vector{Location(1_a)}), + Member("col2", ty.f32(), utils::Vector{Location(2_a)}), Member("pos", ty.vec4(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), }); diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index 04cbaacb99..3ee241da5a 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -884,9 +884,9 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* v) { U32Operand(ConvertBuiltin(builtin->builtin, sem->StorageClass()))}); return true; }, - [&](const ast::LocationAttribute* location) { + [&](const ast::LocationAttribute*) { push_annot(spv::Op::OpDecorate, {Operand(var_id), U32Operand(SpvDecorationLocation), - Operand(location->value)}); + Operand(sem->Location().value())}); return true; }, [&](const ast::InterpolateAttribute* interpolate) { diff --git a/src/tint/writer/spirv/builder_entry_point_test.cc b/src/tint/writer/spirv/builder_entry_point_test.cc index 424e89b5c2..a4128c1402 100644 --- a/src/tint/writer/spirv/builder_entry_point_test.cc +++ b/src/tint/writer/spirv/builder_entry_point_test.cc @@ -48,7 +48,7 @@ TEST_F(BuilderTest, EntryPoint_Parameters) { }); auto* loc1 = Param("loc1", ty.f32(), utils::Vector{ - Location(1u), + Location(1_u), }); auto* mul = Mul(Expr(MemberAccessor("coord", "x")), Expr("loc1")); auto* col = Var("col", ty.f32(), mul); @@ -120,7 +120,7 @@ TEST_F(BuilderTest, EntryPoint_ReturnValue) { // } auto* loc_in = Param("loc_in", ty.u32(), utils::Vector{ - Location(0), + Location(0_a), Flat(), }); auto* cond = @@ -134,7 +134,7 @@ TEST_F(BuilderTest, EntryPoint_ReturnValue) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(0), + Location(0_a), }); spirv::Builder& b = SanitizeAndBuild(); @@ -211,7 +211,7 @@ TEST_F(BuilderTest, EntryPoint_SharedStruct) { auto* interface = Structure( "Interface", utils::Vector{ - Member("value", ty.f32(), utils::Vector{Location(1u)}), + Member("value", ty.f32(), utils::Vector{Location(1_u)}), Member("pos", ty.vec4(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), }); diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index bd75fb8f74..804d9ead88 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc @@ -756,7 +756,11 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out, return true; }, [&](const ast::LocationAttribute* location) { - out << "location(" << location->value << ")"; + out << "location("; + if (!EmitExpression(out, location->value)) { + return false; + } + out << ")"; return true; }, [&](const ast::BuiltinAttribute* builtin) { diff --git a/src/tint/writer/wgsl/generator_impl_function_test.cc b/src/tint/writer/wgsl/generator_impl_function_test.cc index 3b80e6990b..af09a137f6 100644 --- a/src/tint/writer/wgsl/generator_impl_function_test.cc +++ b/src/tint/writer/wgsl/generator_impl_function_test.cc @@ -116,7 +116,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_Parameters) { }); auto* loc1 = Param("loc1", ty.f32(), utils::Vector{ - Location(1u), + Location(1_a), }); auto* func = Func("frag_main", utils::Vector{coord, loc1}, ty.void_(), utils::Empty, utils::Vector{ @@ -143,7 +143,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_ReturnValue) { Stage(ast::PipelineStage::kFragment), }, utils::Vector{ - Location(1u), + Location(1_a), }); GeneratorImpl& gen = Build(); diff --git a/src/tint/writer/wgsl/generator_impl_type_test.cc b/src/tint/writer/wgsl/generator_impl_type_test.cc index 53900572ed..ef90579c55 100644 --- a/src/tint/writer/wgsl/generator_impl_type_test.cc +++ b/src/tint/writer/wgsl/generator_impl_type_test.cc @@ -274,7 +274,7 @@ TEST_F(WgslGeneratorImplTest, EmitType_Struct_WithEntryPointAttributes) { auto* s = Structure( "S", utils::Vector{ Member("a", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)}), - Member("b", ty.f32(), utils::Vector{Location(2u)}), + Member("b", ty.f32(), utils::Vector{Location(2_a)}), }); GeneratorImpl& gen = Build();