From 333cea405c5e41a7ee2cca852eee914fd724113a Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Thu, 27 Apr 2023 17:21:58 +0000 Subject: [PATCH] tint/resolver: Clean up attribute resolving Attributes resolving was done ad-hoc throughout the resolver, with the validator ensuring that attributes were only applied to the correct nodes. The ad-hoc nature meant that attributes were inconsistently marked and resolved, and the attribute arguments were not always validated (especially when used internally). This change inlines the attribute processing into the appropriate places in the resolver, and uses a standardized error message for attributes that cannot be applied. Change-Id: Ic084820949bbf8276fb2d33c103fa29b77824a69 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/129620 Commit-Queue: Ben Clayton Kokoro: Kokoro Reviewed-by: Dan Sinclair --- .../resolver/attribute_validation_test.cc | 143 ++- .../resolver/entry_point_validation_test.cc | 8 +- src/tint/resolver/resolver.cc | 932 +++++++++++------- src/tint/resolver/resolver.h | 56 +- .../resolver/unresolved_identifier_test.cc | 2 +- src/tint/resolver/validator.cc | 102 +- src/tint/resolver/validator.h | 3 +- src/tint/resolver/variable_test.cc | 2 +- 8 files changed, 753 insertions(+), 495 deletions(-) diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc index 098ab1738b..b9160cea3d 100644 --- a/src/tint/resolver/attribute_validation_test.cc +++ b/src/tint/resolver/attribute_validation_test.cc @@ -131,6 +131,43 @@ static utils::Vector createAttributes(const Source& so return {}; } +static std::string name(AttributeKind kind) { + switch (kind) { + case AttributeKind::kAlign: + return "@align"; + case AttributeKind::kBinding: + return "@binding"; + case AttributeKind::kBuiltin: + return "@builtin"; + case AttributeKind::kDiagnostic: + return "@diagnostic"; + case AttributeKind::kGroup: + return "@group"; + case AttributeKind::kId: + return "@id"; + case AttributeKind::kInterpolate: + return "@interpolate"; + case AttributeKind::kInvariant: + return "@invariant"; + case AttributeKind::kLocation: + return "@location"; + case AttributeKind::kOffset: + return "@offset"; + case AttributeKind::kMustUse: + return "@must_use"; + case AttributeKind::kSize: + return "@size"; + case AttributeKind::kStage: + return "@stage"; + case AttributeKind::kStride: + return "@stride"; + case AttributeKind::kWorkgroup: + return "@workgroup_size"; + case AttributeKind::kBindingAndGroup: + return "@binding"; + } + return ""; +} namespace FunctionInputAndOutputTests { using FunctionParameterAttributeTest = TestWithParams; TEST_P(FunctionParameterAttributeTest, IsValid) { @@ -144,11 +181,16 @@ TEST_P(FunctionParameterAttributeTest, IsValid) { if (params.should_pass) { EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else if (params.kind == AttributeKind::kLocation || params.kind == AttributeKind::kBuiltin || + params.kind == AttributeKind::kInvariant || + params.kind == AttributeKind::kInterpolate) { + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "error: " + name(params.kind) + + " is not valid for non-entry point function parameters"); } else { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "error: attribute is not valid for non-entry point function " - "parameters"); + "error: " + name(params.kind) + " is not valid for function parameters"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -184,9 +226,9 @@ TEST_P(FunctionReturnTypeAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "error: attribute is not valid for non-entry point function " - "return types"); + EXPECT_EQ(r()->error(), "error: " + name(params.kind) + + " is not valid for non-entry point function " + "return types"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -234,10 +276,11 @@ TEST_P(ComputeShaderParameterAttributeTest, IsValid) { } else if (params.kind == AttributeKind::kInterpolate || params.kind == AttributeKind::kLocation || params.kind == AttributeKind::kInvariant) { - EXPECT_EQ(r()->error(), - "12:34 error: attribute is not valid for compute shader inputs"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for compute shader inputs"); } else { - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for function parameters"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for function parameters"); } } } @@ -277,7 +320,8 @@ TEST_P(FragmentShaderParameterAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for function parameters"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for function parameters"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -331,7 +375,8 @@ TEST_P(VertexShaderParameterAttributeTest, IsValid) { "12:34 error: invariant attribute must only be applied to a " "position builtin"); } else { - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for function parameters"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for function parameters"); } } } @@ -378,12 +423,12 @@ TEST_P(ComputeShaderReturnTypeAttributeTest, IsValid) { } else if (params.kind == AttributeKind::kInterpolate || params.kind == AttributeKind::kLocation || params.kind == AttributeKind::kInvariant) { - EXPECT_EQ(r()->error(), - "12:34 error: attribute is not valid for compute shader output"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for compute shader output"); } else { - EXPECT_EQ(r()->error(), - "12:34 error: attribute is not valid for entry point return " - "types"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for entry point return " + "types"); } } } @@ -434,8 +479,8 @@ TEST_P(FragmentShaderReturnTypeAttributeTest, IsValid) { R"(34:56 error: duplicate location attribute 12:34 note: first attribute declared here)"); } else { - EXPECT_EQ(r()->error(), - R"(12:34 error: attribute is not valid for entry point return types)"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for entry point return types"); } } } @@ -484,8 +529,8 @@ TEST_P(VertexShaderReturnTypeAttributeTest, IsValid) { R"(34:56 error: multiple entry point IO attributes 12:34 note: previously consumed @location)"); } else { - EXPECT_EQ(r()->error(), - R"(12:34 error: attribute is not valid for entry point return types)"); + EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) + + " is not valid for entry point return types"); } } } @@ -591,7 +636,8 @@ TEST_P(StructAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for struct declarations"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for struct declarations"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -628,7 +674,8 @@ TEST_P(StructMemberAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for structure members"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for struct members"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -871,7 +918,8 @@ TEST_P(ArrayAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for array types"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for array types"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -898,7 +946,6 @@ TEST_P(VariableAttributeTest, IsValid) { auto& params = GetParam(); auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind); - auto* attr = attrs[0]; if (IsBindingAttribute(params.kind)) { GlobalVar("a", ty.sampler(type::SamplerKind::kSampler), attrs); } else { @@ -910,8 +957,8 @@ TEST_P(VariableAttributeTest, IsValid) { } else { EXPECT_FALSE(r()->Resolve()); if (!IsBindingAttribute(params.kind)) { - EXPECT_EQ(r()->error(), "12:34 error: attribute '" + attr->Name() + - "' is not valid for module-scope 'var'"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for module-scope 'var'"); } } } @@ -944,13 +991,22 @@ TEST_F(VariableAttributeTest, DuplicateAttribute) { 12:34 note: first attribute declared here)"); } -TEST_F(VariableAttributeTest, LocalVariable) { +TEST_F(VariableAttributeTest, LocalVar) { auto* v = Var("a", ty.f32(), utils::Vector{Binding(Source{{12, 34}}, 2_a)}); WrapInFunction(v); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attributes are not valid on local variables"); + EXPECT_EQ(r()->error(), "12:34 error: @binding is not valid for function-scope 'var'"); +} + +TEST_F(VariableAttributeTest, LocalLet) { + auto* v = Let("a", utils::Vector{Binding(Source{{12, 34}}, 2_a)}, Expr(1_a)); + + WrapInFunction(v); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: @binding is not valid for 'let' declaration"); } using ConstantAttributeTest = TestWithParams; @@ -965,7 +1021,7 @@ TEST_P(ConstantAttributeTest, IsValid) { } else { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: attribute is not valid for module-scope 'const' declaration"); + "12:34 error: " + name(params.kind) + " is not valid for 'const' declaration"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -987,17 +1043,14 @@ INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, TestParams{AttributeKind::kWorkgroup, false}, TestParams{AttributeKind::kBindingAndGroup, false})); -TEST_F(ConstantAttributeTest, DuplicateAttribute) { +TEST_F(ConstantAttributeTest, InvalidAttribute) { GlobalConst("a", ty.f32(), Expr(1.23_f), utils::Vector{ Id(Source{{12, 34}}, 0_a), - Id(Source{{56, 78}}, 1_a), }); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - R"(56:78 error: duplicate id attribute -12:34 note: first attribute declared here)"); + EXPECT_EQ(r()->error(), "12:34 error: @id is not valid for 'const' declaration"); } using OverrideAttributeTest = TestWithParams; @@ -1010,7 +1063,8 @@ TEST_P(OverrideAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for 'override' declaration"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for 'override' declaration"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1056,7 +1110,8 @@ TEST_P(SwitchStatementAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for switch statements"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for switch statements"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1089,7 +1144,8 @@ TEST_P(SwitchBodyAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for switch body"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for switch body"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1122,7 +1178,8 @@ TEST_P(IfStatementAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for if statements"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for if statements"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1155,7 +1212,8 @@ TEST_P(ForStatementAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for for statements"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for for statements"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1188,7 +1246,8 @@ TEST_P(LoopStatementAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for loop statements"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for loop statements"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1221,7 +1280,8 @@ TEST_P(WhileStatementAttributeTest, IsValid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for while statements"); + EXPECT_EQ(r()->error(), + "12:34 error: " + name(params.kind) + " is not valid for while statements"); } } INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest, @@ -1251,7 +1311,8 @@ class BlockStatementTest : public TestWithParams { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "error: attribute is not valid for block statements"); + EXPECT_EQ(r()->error(), + "error: " + name(GetParam().kind) + " is not valid for block statements"); } } }; diff --git a/src/tint/resolver/entry_point_validation_test.cc b/src/tint/resolver/entry_point_validation_test.cc index dac6cd2e84..eb2d779012 100644 --- a/src/tint/resolver/entry_point_validation_test.cc +++ b/src/tint/resolver/entry_point_validation_test.cc @@ -1084,7 +1084,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Input) { }); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: attribute is not valid for compute shader output)"); + EXPECT_EQ(r()->error(), R"(12:34 error: @location is not valid for compute shader output)"); } TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) { @@ -1099,7 +1099,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) { }); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: attribute is not valid for compute shader inputs)"); + EXPECT_EQ(r()->error(), R"(12:34 error: @location is not valid for compute shader inputs)"); } TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) { @@ -1119,7 +1119,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: attribute is not valid for compute shader output\n" + "12:34 error: @location is not valid for compute shader output\n" "56:78 note: while analyzing entry point 'main'"); } @@ -1138,7 +1138,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Input) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: attribute is not valid for compute shader inputs\n" + "12:34 error: @location is not valid for compute shader inputs\n" "56:78 note: while analyzing entry point 'main'"); } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 99e68d0cbf..638cac0cbd 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -247,6 +247,20 @@ sem::Variable* Resolver::Let(const ast::Let* v, bool is_global) { } } + for (auto* attribute : v->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, // + [&](const ast::InternalAttribute* attr) -> bool { return InternalAttribute(attr); }, + [&](Default) { + ErrorInvalidAttribute(attribute, "'let' declaration"); + return false; + }); + if (!ok) { + return nullptr; + } + } + if (!v->initializer) { AddError("'let' declaration must have an initializer", v->source); return nullptr; @@ -340,37 +354,51 @@ sem::Variable* Resolver::Override(const ast::Override* v) { /* constant_value */ nullptr, std::nullopt, std::nullopt); sem->SetInitializer(rhs); - if (auto* id_attr = ast::GetAttribute(v->attributes)) { - ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@id"}; - TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); + for (auto* attribute : v->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, // + [&](const ast::IdAttribute* attr) { + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@id"}; + TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - auto* materialized = Materialize(ValueExpression(id_attr->expr)); - if (!materialized) { + auto* materialized = Materialize(ValueExpression(attr->expr)); + if (!materialized) { + return false; + } + if (!materialized->Type()->IsAnyOf()) { + AddError("@id must be an i32 or u32 value", attr->source); + return false; + } + + auto const_value = materialized->ConstantValue(); + auto value = const_value->ValueAs(); + if (value < 0) { + AddError("@id value must be non-negative", attr->source); + return false; + } + if (value > std::numeric_limits::max()) { + AddError( + "@id value must be between 0 and " + + std::to_string(std::numeric_limits::max()), + attr->source); + return false; + } + + auto o = OverrideId{static_cast(value)}; + sem->SetOverrideId(o); + + // Track the constant IDs that are specified in the shader. + override_ids_.Add(o, sem); + return true; + }, + [&](Default) { + ErrorInvalidAttribute(attribute, "'override' declaration"); + return false; + }); + if (!ok) { return nullptr; } - if (!materialized->Type()->IsAnyOf()) { - AddError("@id must be an i32 or u32 value", id_attr->source); - return nullptr; - } - - auto const_value = materialized->ConstantValue(); - auto value = const_value->ValueAs(); - if (value < 0) { - AddError("@id value must be non-negative", id_attr->source); - return nullptr; - } - if (value > std::numeric_limits::max()) { - AddError("@id value must be between 0 and " + - std::to_string(std::numeric_limits::max()), - id_attr->source); - return nullptr; - } - - auto o = OverrideId{static_cast(value)}; - sem->SetOverrideId(o); - - // Track the constant IDs that are specified in the shader. - override_ids_.Add(o, sem); } builder_->Sem().Add(v, sem); @@ -393,6 +421,18 @@ sem::Variable* Resolver::Const(const ast::Const* c, bool is_global) { return nullptr; } + for (auto* attribute : c->attributes) { + Mark(attribute); + bool ok = Switch(attribute, // + [&](Default) { + ErrorInvalidAttribute(attribute, "'const' declaration"); + return false; + }); + if (!ok) { + return nullptr; + } + } + const sem::ValueExpression* rhs = nullptr; { ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "const initializer"}; @@ -529,72 +569,98 @@ sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) { sem::Variable* sem = nullptr; if (is_global) { + bool has_io_address_space = address_space == builtin::AddressSpace::kIn || + address_space == builtin::AddressSpace::kOut; + + std::optional group, binding, location; + for (auto* attribute : var->attributes) { + Mark(attribute); + enum Status { kSuccess, kErrored, kInvalid }; + auto res = Switch( + attribute, // + [&](const ast::BindingAttribute* attr) { + auto value = BindingAttribute(attr); + if (!value) { + return kErrored; + } + binding = value.Get(); + return kSuccess; + }, + [&](const ast::GroupAttribute* attr) { + auto value = GroupAttribute(attr); + if (!value) { + return kErrored; + } + group = value.Get(); + return kSuccess; + }, + [&](const ast::LocationAttribute* attr) { + if (!has_io_address_space) { + return kInvalid; + } + auto value = LocationAttribute(attr); + if (!value) { + return kErrored; + } + location = value.Get(); + return kSuccess; + }, + [&](const ast::BuiltinAttribute* attr) { + if (!has_io_address_space) { + return kInvalid; + } + return BuiltinAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::InterpolateAttribute* attr) { + if (!has_io_address_space) { + return kInvalid; + } + return InterpolateAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::InvariantAttribute* attr) { + if (!has_io_address_space) { + return kInvalid; + } + return InvariantAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::InternalAttribute* attr) { + return InternalAttribute(attr) ? kSuccess : kErrored; + }, + [&](Default) { return kInvalid; }); + + switch (res) { + case kSuccess: + break; + case kErrored: + return nullptr; + case kInvalid: + ErrorInvalidAttribute(attribute, "module-scope 'var'"); + return nullptr; + } + } + std::optional binding_point; - if (var->HasBindingPoint()) { - uint32_t binding = 0; - { - ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding"}; - TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - - auto* attr = ast::GetAttribute(var->attributes); - auto* materialized = Materialize(ValueExpression(attr->expr)); - if (!materialized) { - return nullptr; - } - if (!materialized->Type()->IsAnyOf()) { - AddError("@binding must be an i32 or u32 value", attr->source); - return nullptr; - } - - auto const_value = materialized->ConstantValue(); - auto value = const_value->ValueAs(); - if (value < 0) { - AddError("@binding value must be non-negative", attr->source); - return nullptr; - } - binding = u32(value); - } - - uint32_t group = 0; - { - ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group"}; - TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - - auto* attr = ast::GetAttribute(var->attributes); - auto* materialized = Materialize(ValueExpression(attr->expr)); - if (!materialized) { - return nullptr; - } - if (!materialized->Type()->IsAnyOf()) { - AddError("@group must be an i32 or u32 value", attr->source); - return nullptr; - } - - auto const_value = materialized->ConstantValue(); - auto value = const_value->ValueAs(); - if (value < 0) { - AddError("@group value must be non-negative", attr->source); - return nullptr; - } - group = u32(value); - } - binding_point = {group, binding}; + if (group && binding) { + binding_point = sem::BindingPoint{group.value(), binding.value()}; } - - std::optional location; - if (auto* attr = ast::GetAttribute(var->attributes)) { - auto value = LocationAttribute(attr); - if (!value) { - return nullptr; - } - location = value.Get(); - } - sem = builder_->create( var, var_ty, sem::EvaluationStage::kRuntime, address_space, access, /* constant_value */ nullptr, binding_point, location); } else { + for (auto* attribute : var->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, + [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); }, + [&](Default) { + ErrorInvalidAttribute(attribute, "function-scope 'var'"); + return false; + }); + if (!ok) { + return nullptr; + } + } sem = builder_->create(var, var_ty, sem::EvaluationStage::kRuntime, address_space, access, current_statement_, /* constant_value */ nullptr); @@ -605,18 +671,93 @@ sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) { return sem; } -sem::Parameter* Resolver::Parameter(const ast::Parameter* param, uint32_t index) { +sem::Parameter* Resolver::Parameter(const ast::Parameter* param, + const ast::Function* func, + uint32_t index) { Mark(param->name); auto add_note = [&] { AddNote("while instantiating parameter " + param->name->symbol.Name(), param->source); }; - for (auto* attr : param->attributes) { - if (!Attribute(attr)) { - return nullptr; + std::optional location, group, binding; + + if (func->IsEntryPoint()) { + for (auto* attribute : param->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, // + [&](const ast::LocationAttribute* attr) { + auto value = LocationAttribute(attr); + if (!value) { + return false; + } + location = value.Get(); + return true; + }, + [&](const ast::BuiltinAttribute* attr) -> bool { return BuiltinAttribute(attr); }, + [&](const ast::InvariantAttribute* attr) -> bool { + return InvariantAttribute(attr); + }, + [&](const ast::InterpolateAttribute* attr) -> bool { + return InterpolateAttribute(attr); + }, + [&](const ast::InternalAttribute* attr) -> bool { return InternalAttribute(attr); }, + [&](const ast::GroupAttribute* attr) -> bool { + if (validator_.IsValidationEnabled( + param->attributes, ast::DisabledValidation::kEntryPointParameter)) { + ErrorInvalidAttribute(attribute, "function parameters"); + return false; + } + auto value = GroupAttribute(attr); + if (!value) { + return false; + } + group = value.Get(); + return true; + }, + [&](const ast::BindingAttribute* attr) -> bool { + if (validator_.IsValidationEnabled( + param->attributes, ast::DisabledValidation::kEntryPointParameter)) { + ErrorInvalidAttribute(attribute, "function parameters"); + return false; + } + auto value = BindingAttribute(attr); + if (!value) { + return false; + } + binding = value.Get(); + return true; + }, + [&](Default) { + ErrorInvalidAttribute(attribute, "function parameters"); + return false; + }); + if (!ok) { + return nullptr; + } + } + } else { + for (auto* attribute : param->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, // + [&](const ast::InternalAttribute* attr) -> bool { return InternalAttribute(attr); }, + [&](Default) { + if (attribute->IsAnyOf()) { + ErrorInvalidAttribute(attribute, "non-entry point function parameters"); + } else { + ErrorInvalidAttribute(attribute, "function parameters"); + } + return false; + }); + if (!ok) { + return nullptr; + } } } + if (!validator_.NoDuplicateAttributes(param->attributes)) { return nullptr; } @@ -642,72 +783,22 @@ sem::Parameter* Resolver::Parameter(const ast::Parameter* param, uint32_t index) } std::optional binding_point; - if (param->HasBindingPoint()) { - binding_point = sem::BindingPoint{}; - { - ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding value"}; - TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - - auto* attr = ast::GetAttribute(param->attributes); - auto* materialized = Materialize(ValueExpression(attr->expr)); - if (!materialized) { - return nullptr; - } - binding_point->binding = materialized->ConstantValue()->ValueAs(); - } - { - ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group value"}; - TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - - auto* attr = ast::GetAttribute(param->attributes); - auto* materialized = Materialize(ValueExpression(attr->expr)); - if (!materialized) { - return nullptr; - } - binding_point->group = materialized->ConstantValue()->ValueAs(); - } - } - - std::optional location; - if (auto* attr = ast::GetAttribute(param->attributes)) { - auto value = LocationAttribute(attr); - if (!value) { - return nullptr; - } - location = value.Get(); + if (group && binding) { + binding_point = sem::BindingPoint{group.value(), binding.value()}; } auto* sem = builder_->create( param, index, ty, builtin::AddressSpace::kUndefined, builtin::Access::kUndefined, sem::ParameterUsage::kNone, binding_point, location); builder_->Sem().Add(param, sem); + + if (!validator_.Parameter(sem)) { + return nullptr; + } + return sem; } -utils::Result Resolver::LocationAttribute(const ast::LocationAttribute* attr) { - ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@location value"}; - TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - - auto* materialized = Materialize(ValueExpression(attr->expr)); - if (!materialized) { - return utils::Failure; - } - - if (!materialized->Type()->IsAnyOf()) { - AddError("@location must be an i32 or u32 value", attr->source); - return utils::Failure; - } - - auto const_value = materialized->ConstantValue(); - auto value = const_value->ValueAs(); - if (value < 0) { - AddError("@location value must be non-negative", attr->source); - return utils::Failure; - } - - return static_cast(value); -} - builtin::Access Resolver::DefaultAccessForAddressSpace(builtin::AddressSpace address_space) { // https://gpuweb.github.io/gpuweb/wgsl/#storage-class switch (address_space) { @@ -796,12 +887,6 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { return nullptr; } - for (auto* attr : v->attributes) { - if (!Attribute(attr)) { - return nullptr; - } - } - if (!validator_.NoDuplicateAttributes(v->attributes)) { return nullptr; } @@ -860,8 +945,28 @@ sem::Function* Resolver::Function(const ast::Function* decl) { validator_.DiagnosticFilters().Push(); TINT_DEFER(validator_.DiagnosticFilters().Pop()); - for (auto* attr : decl->attributes) { - if (!Attribute(attr)) { + + for (auto* attribute : decl->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, + [&](const ast::DiagnosticAttribute* attr) { return DiagnosticAttribute(attr); }, + [&](const ast::StageAttribute* attr) { return StageAttribute(attr); }, + [&](const ast::MustUseAttribute* attr) { return MustUseAttribute(attr); }, + [&](const ast::WorkgroupAttribute* attr) { + auto value = WorkgroupAttribute(attr); + if (!value) { + return false; + } + func->SetWorkgroupSize(value.Get()); + return true; + }, + [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); }, + [&](Default) { + ErrorInvalidAttribute(attribute, "functions"); + return false; + }); + if (!ok) { return nullptr; } } @@ -884,15 +989,11 @@ sem::Function* Resolver::Function(const ast::Function* decl) { } } - auto* p = Parameter(param, parameter_index++); + auto* p = Parameter(param, decl, parameter_index++); if (!p) { return nullptr; } - if (!validator_.Parameter(decl, p)) { - return nullptr; - } - func->AddParameter(p); auto* p_ty = const_cast(p->Type()); @@ -925,18 +1026,73 @@ sem::Function* Resolver::Function(const ast::Function* decl) { } func->SetReturnType(return_type); - // Determine if the return type has a location - for (auto* attr : decl->return_type_attributes) { - if (!Attribute(attr)) { - return nullptr; - } + if (decl->IsEntryPoint()) { + // Determine if the return type has a location + bool permissive = validator_.IsValidationDisabled( + decl->attributes, ast::DisabledValidation::kEntryPointParameter) || + validator_.IsValidationDisabled( + decl->attributes, ast::DisabledValidation::kFunctionParameter); + for (auto* attribute : decl->return_type_attributes) { + Mark(attribute); + enum Status { kSuccess, kErrored, kInvalid }; + auto res = Switch( + attribute, // + [&](const ast::LocationAttribute* attr) { + auto value = LocationAttribute(attr); + if (!value) { + return kErrored; + } + func->SetReturnLocation(value.Get()); + return kSuccess; + }, + [&](const ast::BuiltinAttribute* attr) { + return BuiltinAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::InternalAttribute* attr) { + return InternalAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::InterpolateAttribute* attr) { + return InterpolateAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::InvariantAttribute* attr) { + return InvariantAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::BindingAttribute* attr) { + if (!permissive) { + return kInvalid; + } + return BindingAttribute(attr) ? kSuccess : kErrored; + }, + [&](const ast::GroupAttribute* attr) { + if (!permissive) { + return kInvalid; + } + return GroupAttribute(attr) ? kSuccess : kErrored; + }, + [&](Default) { return kInvalid; }); - if (auto* loc_attr = attr->As()) { - auto value = LocationAttribute(loc_attr); - if (!value) { + switch (res) { + case kSuccess: + break; + case kErrored: + return nullptr; + case kInvalid: + ErrorInvalidAttribute(attribute, "entry point return types"); + return nullptr; + } + } + } else { + for (auto* attribute : decl->return_type_attributes) { + Mark(attribute); + bool ok = Switch(attribute, // + [&](Default) { + ErrorInvalidAttribute(attribute, + "non-entry point function return types"); + return false; + }); + if (!ok) { return nullptr; } - func->SetReturnLocation(value.Get()); } } @@ -964,10 +1120,6 @@ sem::Function* Resolver::Function(const ast::Function* decl) { ApplyDiagnosticSeverities(func); - if (!WorkgroupSize(decl)) { - return nullptr; - } - if (decl->IsEntryPoint()) { entry_points_.Push(func); } @@ -1016,94 +1168,6 @@ sem::Function* Resolver::Function(const ast::Function* decl) { return func; } -bool Resolver::WorkgroupSize(const ast::Function* func) { - // Set work-group size defaults. - sem::WorkgroupSize ws; - for (size_t i = 0; i < 3; i++) { - ws[i] = 1; - } - - auto* attr = ast::GetAttribute(func->attributes); - if (!attr) { - return true; - } - - auto values = attr->Values(); - utils::Vector args; - utils::Vector arg_tys; - - constexpr const char* kErrBadExpr = - "workgroup_size argument must be a constant or override-expression of type " - "abstract-integer, i32 or u32"; - - for (size_t i = 0; i < 3; i++) { - // Each argument to this attribute can either be a literal, an identifier for a - // module-scope constants, a const-expression, or nullptr if not specified. - auto* value = values[i]; - if (!value) { - break; - } - const auto* expr = ValueExpression(value); - if (!expr) { - return false; - } - auto* ty = expr->Type(); - if (!ty->IsAnyOf()) { - AddError(kErrBadExpr, value->source); - return false; - } - - if (expr->Stage() != sem::EvaluationStage::kConstant && - expr->Stage() != sem::EvaluationStage::kOverride) { - AddError(kErrBadExpr, value->source); - return false; - } - - args.Push(expr); - arg_tys.Push(ty); - } - - auto* common_ty = type::Type::Common(arg_tys); - if (!common_ty) { - AddError("workgroup_size arguments must be of the same type, either i32 or u32", - attr->source); - return false; - } - - // If all arguments are abstract-integers, then materialize to i32. - if (common_ty->Is()) { - common_ty = builder_->create(); - } - - for (size_t i = 0; i < args.Length(); i++) { - auto* materialized = Materialize(args[i], common_ty); - if (!materialized) { - return false; - } - if (auto* value = materialized->ConstantValue()) { - if (value->ValueAs() < 1) { - AddError("workgroup_size argument must be at least 1", values[i]->source); - return false; - } - ws[i] = value->ValueAs(); - } else { - ws[i] = std::nullopt; - } - } - - uint64_t total_size = static_cast(ws[0].value_or(1)); - for (size_t i = 1; i < 3; i++) { - total_size *= static_cast(ws[i].value_or(1)); - if (total_size > 0xffffffff) { - AddError("total workgroup grid size cannot exceed 0xffffffff", values[i]->source); - return false; - } - } - - current_function_->SetWorkgroupSize(std::move(ws)); - return true; -} - bool Resolver::Statements(utils::VectorRef stmts) { sem::Behaviors behaviors{sem::Behavior::kNext}; @@ -3474,25 +3538,186 @@ sem::ValueExpression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { return sem; } -bool Resolver::Attribute(const ast::Attribute* attr) { - Mark(attr); - return Switch( - attr, // - [&](const ast::BuiltinAttribute* b) { return BuiltinAttribute(b); }, - [&](const ast::DiagnosticAttribute* d) { return DiagnosticControl(d->control); }, - [&](const ast::InterpolateAttribute* i) { return InterpolateAttribute(i); }, - [&](const ast::InternalAttribute* i) { return InternalAttribute(i); }, - [&](Default) { return true; }); +utils::Result Resolver::LocationAttribute(const ast::LocationAttribute* attr) { + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@location value"}; + TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); + + auto* materialized = Materialize(ValueExpression(attr->expr)); + if (!materialized) { + return utils::Failure; + } + + if (!materialized->Type()->IsAnyOf()) { + AddError("@location must be an i32 or u32 value", attr->source); + return utils::Failure; + } + + auto const_value = materialized->ConstantValue(); + auto value = const_value->ValueAs(); + if (value < 0) { + AddError("@location value must be non-negative", attr->source); + return utils::Failure; + } + + return static_cast(value); } -bool Resolver::BuiltinAttribute(const ast::BuiltinAttribute* attr) { +utils::Result Resolver::BindingAttribute(const ast::BindingAttribute* attr) { + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding"}; + TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); + + auto* materialized = Materialize(ValueExpression(attr->expr)); + if (!materialized) { + return utils::Failure; + } + if (!materialized->Type()->IsAnyOf()) { + AddError("@binding must be an i32 or u32 value", attr->source); + return utils::Failure; + } + + auto const_value = materialized->ConstantValue(); + auto value = const_value->ValueAs(); + if (value < 0) { + AddError("@binding value must be non-negative", attr->source); + return utils::Failure; + } + return static_cast(value); +} + +utils::Result Resolver::GroupAttribute(const ast::GroupAttribute* attr) { + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group"}; + TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); + + auto* materialized = Materialize(ValueExpression(attr->expr)); + if (!materialized) { + return utils::Failure; + } + if (!materialized->Type()->IsAnyOf()) { + AddError("@group must be an i32 or u32 value", attr->source); + return utils::Failure; + } + + auto const_value = materialized->ConstantValue(); + auto value = const_value->ValueAs(); + if (value < 0) { + AddError("@group value must be non-negative", attr->source); + return utils::Failure; + } + return static_cast(value); +} + +utils::Result Resolver::WorkgroupAttribute( + const ast::WorkgroupAttribute* attr) { + // Set work-group size defaults. + sem::WorkgroupSize ws; + for (size_t i = 0; i < 3; i++) { + ws[i] = 1; + } + + auto values = attr->Values(); + utils::Vector args; + utils::Vector arg_tys; + + constexpr const char* kErrBadExpr = + "workgroup_size argument must be a constant or override-expression of type " + "abstract-integer, i32 or u32"; + + for (size_t i = 0; i < 3; i++) { + // Each argument to this attribute can either be a literal, an identifier for a + // module-scope constants, a const-expression, or nullptr if not specified. + auto* value = values[i]; + if (!value) { + break; + } + const auto* expr = ValueExpression(value); + if (!expr) { + return utils::Failure; + } + auto* ty = expr->Type(); + if (!ty->IsAnyOf()) { + AddError(kErrBadExpr, value->source); + return utils::Failure; + } + + if (expr->Stage() != sem::EvaluationStage::kConstant && + expr->Stage() != sem::EvaluationStage::kOverride) { + AddError(kErrBadExpr, value->source); + return utils::Failure; + } + + args.Push(expr); + arg_tys.Push(ty); + } + + auto* common_ty = type::Type::Common(arg_tys); + if (!common_ty) { + AddError("workgroup_size arguments must be of the same type, either i32 or u32", + attr->source); + return utils::Failure; + } + + // If all arguments are abstract-integers, then materialize to i32. + if (common_ty->Is()) { + common_ty = builder_->create(); + } + + for (size_t i = 0; i < args.Length(); i++) { + auto* materialized = Materialize(args[i], common_ty); + if (!materialized) { + return utils::Failure; + } + if (auto* value = materialized->ConstantValue()) { + if (value->ValueAs() < 1) { + AddError("workgroup_size argument must be at least 1", values[i]->source); + return utils::Failure; + } + ws[i] = value->ValueAs(); + } else { + ws[i] = std::nullopt; + } + } + + uint64_t total_size = static_cast(ws[0].value_or(1)); + for (size_t i = 1; i < 3; i++) { + total_size *= static_cast(ws[i].value_or(1)); + if (total_size > 0xffffffff) { + AddError("total workgroup grid size cannot exceed 0xffffffff", values[i]->source); + return utils::Failure; + } + } + + return ws; +} + +utils::Result Resolver::BuiltinAttribute( + const ast::BuiltinAttribute* attr) { auto* builtin_expr = BuiltinValueExpression(attr->builtin); if (!builtin_expr) { - return false; + return utils::Failure; } // Apply the resolved tint::sem::BuiltinEnumExpression to the // attribute. builder_->Sem().Add(attr, builtin_expr); + return builtin_expr->Value(); +} + +bool Resolver::DiagnosticAttribute(const ast::DiagnosticAttribute* attr) { + return DiagnosticControl(attr->control); +} + +bool Resolver::StageAttribute(const ast::StageAttribute*) { + return true; +} + +bool Resolver::MustUseAttribute(const ast::MustUseAttribute*) { + return true; +} + +bool Resolver::InvariantAttribute(const ast::InvariantAttribute*) { + return true; +} + +bool Resolver::StrideAttribute(const ast::StrideAttribute*) { return true; } @@ -3626,24 +3851,30 @@ bool Resolver::ArrayAttributes(utils::VectorRef attribute return false; } - for (auto* attr : attributes) { - Mark(attr); - if (auto* sd = attr->As()) { - // If the element type is not plain, then el_ty->Align() may be 0, in which case we - // could get a DBZ in ArrayStrideAttribute(). In this case, validation will error - // about the invalid array element type (which is tested later), so this is just a - // seatbelt. - if (IsPlain(el_ty)) { - explicit_stride = sd->stride; - if (!validator_.ArrayStrideAttribute(sd, el_ty->Size(), el_ty->Align())) { - return false; + for (auto* attribute : attributes) { + Mark(attribute); + bool ok = Switch( + attribute, // + [&](const ast::StrideAttribute* attr) { + // If the element type is not plain, then el_ty->Align() may be 0, in which case we + // could get a DBZ in ArrayStrideAttribute(). In this case, validation will error + // about the invalid array element type (which is tested later), so this is just a + // seatbelt. + if (IsPlain(el_ty)) { + explicit_stride = attr->stride; + if (!validator_.ArrayStrideAttribute(attr, el_ty->Size(), el_ty->Align())) { + return false; + } } - } - continue; + return true; + }, + [&](Default) { + ErrorInvalidAttribute(attribute, "array types"); + return false; + }); + if (!ok) { + return false; } - - AddError("attribute is not valid for array types", attr->source); - return false; } return true; @@ -3727,8 +3958,18 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { if (!validator_.NoDuplicateAttributes(str->attributes)) { return nullptr; } - for (auto* attr : str->attributes) { - Mark(attr); + + for (auto* attribute : str->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); }, + [&](Default) { + ErrorInvalidAttribute(attribute, "struct declarations"); + return false; + }); + if (!ok) { + return nullptr; + } } utils::Vector sem_members; @@ -3781,88 +4022,87 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { bool has_align_attr = false; bool has_size_attr = false; std::optional location; - for (auto* attr : member->attributes) { - if (!Attribute(attr)) { - return nullptr; - } + for (auto* attribute : member->attributes) { + Mark(attribute); bool ok = Switch( - attr, // - [&](const ast::StructMemberOffsetAttribute* o) { - // Offset attributes are not part of the WGSL spec, but are emitted - // by the SPIR-V reader. + attribute, // + [&](const ast::StructMemberOffsetAttribute* attr) { + // Offset attributes are not part of the WGSL spec, but are emitted by the + // SPIR-V reader. + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@offset value"}; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - auto* materialized = Materialize(ValueExpression(o->expr)); + auto* materialized = Materialize(ValueExpression(attr->expr)); if (!materialized) { return false; } auto const_value = materialized->ConstantValue(); if (!const_value) { - AddError("@offset must be constant expression", o->expr->source); + AddError("@offset must be constant expression", attr->expr->source); return false; } offset = const_value->ValueAs(); if (offset < struct_size) { - AddError("offsets must be in ascending order", o->source); + AddError("offsets must be in ascending order", attr->source); return false; } has_offset_attr = true; return true; }, - [&](const ast::StructMemberAlignAttribute* a) { + [&](const ast::StructMemberAlignAttribute* attr) { ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@align"}; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - auto* materialized = Materialize(ValueExpression(a->expr)); + auto* materialized = Materialize(ValueExpression(attr->expr)); if (!materialized) { return false; } if (!materialized->Type()->IsAnyOf()) { - AddError("@align must be an i32 or u32 value", a->source); + AddError("@align must be an i32 or u32 value", attr->source); return false; } auto const_value = materialized->ConstantValue(); if (!const_value) { - AddError("@align must be constant expression", a->source); + AddError("@align must be constant expression", attr->source); return false; } auto value = const_value->ValueAs(); if (value <= 0 || !utils::IsPowerOfTwo(value)) { AddError("@align value must be a positive, power-of-two integer", - a->source); + attr->source); return false; } align = u32(value); has_align_attr = true; return true; }, - [&](const ast::StructMemberSizeAttribute* s) { + [&](const ast::StructMemberSizeAttribute* attr) { ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@size"}; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - auto* materialized = Materialize(ValueExpression(s->expr)); + auto* materialized = Materialize(ValueExpression(attr->expr)); if (!materialized) { return false; } if (!materialized->Type()->IsAnyOf()) { - AddError("@size must be an i32 or u32 value", s->source); + AddError("@size must be an i32 or u32 value", attr->source); return false; } auto const_value = materialized->ConstantValue(); if (!const_value) { - AddError("@size must be constant expression", s->expr->source); + AddError("@size must be constant expression", attr->expr->source); return false; } { auto value = const_value->ValueAs(); if (value <= 0) { - AddError("@size must be a positive integer", s->source); + AddError("@size must be a positive integer", attr->source); return false; } } @@ -3870,24 +4110,36 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { if (value < size) { AddError("@size must be at least as big as the type's size (" + std::to_string(size) + ")", - s->source); + attr->source); return false; } size = u32(value); has_size_attr = true; return true; }, - [&](const ast::LocationAttribute* loc_attr) { - auto value = LocationAttribute(loc_attr); + [&](const ast::LocationAttribute* attr) { + auto value = LocationAttribute(attr); if (!value) { return false; } location = value.Get(); return true; }, + [&](const ast::BuiltinAttribute* attr) -> bool { return BuiltinAttribute(attr); }, + [&](const ast::InterpolateAttribute* attr) { return InterpolateAttribute(attr); }, + [&](const ast::InvariantAttribute* attr) { return InvariantAttribute(attr); }, + [&](const ast::StrideAttribute* attr) { + if (validator_.IsValidationEnabled( + member->attributes, ast::DisabledValidation::kIgnoreStrideAttribute)) { + ErrorInvalidAttribute(attribute, "struct members"); + return false; + } + return StrideAttribute(attr); + }, + [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); }, [&](Default) { - // The validator will check attributes can be applied to the struct member. - return true; + ErrorInvalidAttribute(attribute, "struct members"); + return false; }); if (!ok) { return nullptr; @@ -4049,14 +4301,16 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt } // Handle switch body attributes. - for (auto* attr : stmt->body_attributes) { - Mark(attr); - if (auto* dc = attr->As()) { - if (!DiagnosticControl(dc->control)) { + for (auto* attribute : stmt->body_attributes) { + Mark(attribute); + bool ok = Switch( + attribute, + [&](const ast::DiagnosticAttribute* attr) { return DiagnosticAttribute(attr); }, + [&](Default) { + ErrorInvalidAttribute(attribute, "switch body"); return false; - } - } else { - AddError("attribute is not valid for switch body", attr->source); + }); + if (!ok) { return false; } } @@ -4099,14 +4353,6 @@ sem::Statement* Resolver::VariableDeclStatement(const ast::VariableDeclStatement return false; } - for (auto* attr : stmt->variable->attributes) { - Mark(attr); - if (!attr->Is()) { - AddError("attributes are not valid on local variables", attr->source); - return false; - } - } - current_compound_statement_->AddDecl(variable->As()); if (auto* ctor = variable->Initializer()) { @@ -4339,16 +4585,16 @@ SEM* Resolver::StatementScope(const ast::Statement* ast, SEM* sem, F&& callback) // Helper to handle attributes that are supported on certain types of statement. auto handle_attributes = [&](auto* stmt, sem::Statement* sem_stmt, const char* use) { - for (auto* attr : stmt->attributes) { - Mark(attr); - if (auto* dc = attr->template As()) { - if (!DiagnosticControl(dc->control)) { + for (auto* attribute : stmt->attributes) { + Mark(attribute); + bool ok = Switch( + attribute, // + [&](const ast::DiagnosticAttribute* attr) { return DiagnosticAttribute(attr); }, + [&](Default) { + ErrorInvalidAttribute(attribute, use); return false; - } - } else { - utils::StringStream ss; - ss << "attribute is not valid for " << use; - AddError(ss.str(), attr->source); + }); + if (!ok) { return false; } } @@ -4451,6 +4697,10 @@ void Resolver::ErrorMismatchedResolvedIdentifier(const Source& source, sem_.NoteDeclarationSource(resolved.Node()); } +void Resolver::ErrorInvalidAttribute(const ast::Attribute* attr, std::string_view use) { + AddError("@" + attr->Name() + " is not valid for " + std::string(use), attr->source); +} + void Resolver::AddError(const std::string& msg, const Source& source) const { diagnostics_.add_error(diag::System::Resolver, msg, source); } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index edc088a676..b26ad96fd1 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -312,13 +312,45 @@ class Resolver { /// current_function_ bool WorkgroupSize(const ast::Function*); - /// Resolves the attribute @p attr - /// @returns true on success, false on failure - bool Attribute(const ast::Attribute* attr); - /// Resolves the `@builtin` attribute @p attr + /// @returns the builtin value on success + utils::Result BuiltinAttribute(const ast::BuiltinAttribute* attr); + + /// Resolves the `@location` attribute @p attr + /// @returns the location value on success. + utils::Result LocationAttribute(const ast::LocationAttribute* attr); + + /// Resolves the `@binding` attribute @p attr + /// @returns the binding value on success. + utils::Result BindingAttribute(const ast::BindingAttribute* attr); + + /// Resolves the `@group` attribute @p attr + /// @returns the group value on success. + utils::Result GroupAttribute(const ast::GroupAttribute* attr); + + /// Resolves the `@workgroup_size` attribute @p attr + /// @returns the workgroup size on success. + utils::Result WorkgroupAttribute(const ast::WorkgroupAttribute* attr); + + /// Resolves the `@diagnostic` attribute @p attr /// @returns true on success, false on failure - bool BuiltinAttribute(const ast::BuiltinAttribute* attr); + bool DiagnosticAttribute(const ast::DiagnosticAttribute* attr); + + /// Resolves the stage attribute @p attr + /// @returns true on success, false on failure + bool StageAttribute(const ast::StageAttribute* attr); + + /// Resolves the `@must_use` attribute @p attr + /// @returns true on success, false on failure + bool MustUseAttribute(const ast::MustUseAttribute* attr); + + /// Resolves the `@invariant` attribute @p attr + /// @returns true on success, false on failure + bool InvariantAttribute(const ast::InvariantAttribute*); + + /// Resolves the `@stride` attribute @p attr + /// @returns true on success, false on failure + bool StrideAttribute(const ast::StrideAttribute*); /// Resolves the `@interpolate` attribute @p attr /// @returns true on success, false on failure @@ -427,12 +459,11 @@ class Resolver { /// nullptr is returned. /// @note the caller is expected to validate the parameter /// @param param the AST parameter + /// @param func the AST function that owns the parameter /// @param index the index of the parameter - sem::Parameter* Parameter(const ast::Parameter* param, uint32_t index); - - /// @returns the location value for a `@location` attribute, validating the value's range and - /// type. - utils::Result LocationAttribute(const ast::LocationAttribute* attr); + sem::Parameter* Parameter(const ast::Parameter* param, + const ast::Function* func, + uint32_t index); /// Records the address space usage for the given type, and any transient /// dependencies of the type. Validates that the type can be used for the @@ -497,6 +528,11 @@ class Resolver { const ResolvedIdentifier& resolved, std::string_view wanted); + /// Raises an error that the attribute is not valid for the given use. + /// @param attr the invalue attribute + /// @param use the thing that the attribute was applied to + void ErrorInvalidAttribute(const ast::Attribute* attr, std::string_view use); + /// Adds the given error message to the diagnostics void AddError(const std::string& msg, const Source& source) const; diff --git a/src/tint/resolver/unresolved_identifier_test.cc b/src/tint/resolver/unresolved_identifier_test.cc index e52b858b27..580200535b 100644 --- a/src/tint/resolver/unresolved_identifier_test.cc +++ b/src/tint/resolver/unresolved_identifier_test.cc @@ -43,7 +43,7 @@ TEST_F(ResolverUnresolvedIdentifierSuggestions, BuiltinValue) { Func("f", utils::Vector{ Param("p", ty.i32(), utils::Vector{Builtin(Expr(Source{{12, 34}}, "positon"))})}, - ty.void_(), utils::Empty); + ty.void_(), utils::Empty, utils::Vector{Stage(ast::PipelineStage::kVertex)}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: unresolved builtin value 'positon' diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 97eae8839c..c1ba335902 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -606,32 +606,10 @@ bool Validator::GlobalVariable( return false; } - for (auto* attr : decl->attributes) { - bool is_shader_io_attribute = - attr->IsAnyOf(); - bool has_io_address_space = global->AddressSpace() == builtin::AddressSpace::kIn || - global->AddressSpace() == builtin::AddressSpace::kOut; - if (!attr->IsAnyOf() && - (!is_shader_io_attribute || !has_io_address_space)) { - AddError("attribute '" + attr->Name() + "' is not valid for module-scope 'var'", - attr->source); - return false; - } - } - return Var(global); }, [&](const ast::Override*) { return Override(global, override_ids); }, - [&](const ast::Const*) { - if (!decl->attributes.IsEmpty()) { - AddError("attribute is not valid for module-scope 'const' declaration", - decl->attributes[0]->source); - return false; - } - return Const(global); - }, + [&](const ast::Const*) { return Const(global); }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "Validator::GlobalVariable() called with a unknown variable type: " @@ -773,9 +751,6 @@ bool Validator::Override( ast::GetAttribute((*var)->Declaration()->attributes)->source); return false; } - } else { - AddError("attribute is not valid for 'override' declaration", attr->source); - return false; } } @@ -792,28 +767,13 @@ bool Validator::Const(const sem::Variable*) const { return true; } -bool Validator::Parameter(const ast::Function* func, const sem::Variable* var) const { +bool Validator::Parameter(const sem::Variable* var) const { auto* decl = var->Declaration(); if (IsValidationDisabled(decl->attributes, ast::DisabledValidation::kFunctionParameter)) { return true; } - for (auto* attr : decl->attributes) { - if (!func->IsEntryPoint() && !attr->Is()) { - AddError("attribute is not valid for non-entry point function parameters", - attr->source); - return false; - } - if (!attr->IsAnyOf() && - (IsValidationEnabled(decl->attributes, - ast::DisabledValidation::kEntryPointParameter))) { - AddError("attribute is not valid for function parameters", attr->source); - return false; - } - } - if (auto* ref = var->Type()->As()) { if (IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreAddressSpace)) { bool ok = false; @@ -1028,14 +988,7 @@ bool Validator::Function(const sem::Function* func, ast::PipelineStage stage) co } return true; }, - [&](Default) { - if (!attr->IsAnyOf()) { - AddError("attribute is not valid for functions", attr->source); - return false; - } - return true; - }); + [&](Default) { return true; }); if (!ok) { return false; } @@ -1069,24 +1022,6 @@ bool Validator::Function(const sem::Function* func, ast::PipelineStage stage) co TINT_ICE(Resolver, diagnostics_) << "Function " << decl->name->symbol.Name() << " has no body"; } - - for (auto* attr : decl->return_type_attributes) { - if (!decl->IsEntryPoint()) { - AddError("attribute is not valid for non-entry point function return types", - attr->source); - return false; - } - if (!attr->IsAnyOf() && - (IsValidationEnabled(decl->attributes, - ast::DisabledValidation::kEntryPointParameter) && - IsValidationEnabled(decl->attributes, - ast::DisabledValidation::kFunctionParameter))) { - AddError("attribute is not valid for entry point return types", attr->source); - return false; - } - } } if (decl->IsEntryPoint()) { @@ -1196,7 +1131,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) if (is_invalid_compute_shader_attribute) { std::string input_or_output = param_or_ret == ParamOrRetType::kParameter ? "inputs" : "output"; - AddError("attribute is not valid for compute shader " + input_or_output, + AddError("@" + attr->Name() + " is not valid for compute shader " + input_or_output, attr->source); return false; } @@ -2205,24 +2140,7 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons } return true; }, - [&](Default) { - if (!attr->IsAnyOf()) { - if (attr->Is() && - IsValidationDisabled(member->Declaration()->attributes, - ast::DisabledValidation::kIgnoreStrideAttribute)) { - return true; - } - AddError("attribute is not valid for structure members", attr->source); - return false; - } - return true; - }); + [&](Default) { return true; }); if (!ok) { return false; } @@ -2241,13 +2159,6 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons } } - for (auto* attr : str->Declaration()->attributes) { - if (!(attr->IsAnyOf())) { - AddError("attribute is not valid for struct declarations", attr->source); - return false; - } - } - return true; } @@ -2260,7 +2171,8 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr, const bool is_input) const { std::string inputs_or_output = is_input ? "inputs" : "output"; if (stage == ast::PipelineStage::kCompute) { - AddError("attribute is not valid for compute shader " + inputs_or_output, loc_attr->source); + AddError("@" + loc_attr->Name() + " is not valid for compute shader " + inputs_or_output, + loc_attr->source); return false; } diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index e0e30517c5..1ab38bc959 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -348,10 +348,9 @@ class Validator { bool Matrix(const type::Type* el_ty, const Source& source) const; /// Validates a function parameter - /// @param func the function the variable is for /// @param var the variable to validate /// @returns true on success, false otherwise - bool Parameter(const ast::Function* func, const sem::Variable* var) const; + bool Parameter(const sem::Variable* var) const; /// Validates a return /// @param ret the return statement to validate diff --git a/src/tint/resolver/variable_test.cc b/src/tint/resolver/variable_test.cc index 7af6027b44..947f70c128 100644 --- a/src/tint/resolver/variable_test.cc +++ b/src/tint/resolver/variable_test.cc @@ -383,7 +383,7 @@ TEST_F(ResolverVariableTest, LocalVar_ShadowsParam) { } //////////////////////////////////////////////////////////////////////////////////////////////////// -// Function-scope 'let' +// 'let' declaration //////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(ResolverVariableTest, LocalLet) { // struct S { i : i32; }