diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 76f1e86a76..bbe4bc09b6 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -252,6 +252,7 @@ class Resolver { bool ValidateBreakStatement(const sem::Statement* stmt); bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, const sem::Type* storage_type, + ast::PipelineStage stage, const bool is_input); bool ValidateContinueStatement(const sem::Statement* stmt); bool ValidateDiscardStatement(const sem::Statement* stmt); diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc index e12d2e8397..476b008fd1 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/resolver_validation.cc @@ -692,11 +692,9 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, const sem::Type* storage_ty, + ast::PipelineStage stage, const bool is_input) { auto* type = storage_ty->UnwrapRef(); - const auto stage = current_function_ - ? current_function_->Declaration()->PipelineStage() - : ast::PipelineStage::kNone; std::stringstream stage_name; stage_name << stage; bool is_stage_mismatch = false; @@ -961,6 +959,9 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func) { Source source, ParamOrRetType param_or_ret, bool is_struct_member) { + auto stage = current_function_ + ? current_function_->Declaration()->PipelineStage() + : ast::PipelineStage::kNone; // Scan attributes for pipeline IO attributes. // Check for overlap with attributes that have been seen previously. const ast::Attribute* pipeline_io_attribute = nullptr; @@ -968,6 +969,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func) { const ast::InvariantAttribute* invariant_attribute = nullptr; for (auto* attr : attrs) { auto is_invalid_compute_shader_attribute = false; + if (auto* builtin = attr->As()) { if (pipeline_io_attribute) { AddError("multiple entry point IO attributes", attr->source); @@ -987,7 +989,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func) { } if (!ValidateBuiltinAttribute( - builtin, ty, + builtin, ty, stage, /* is_input */ param_or_ret == ParamOrRetType::kParameter)) { return false; } @@ -1003,9 +1005,6 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func) { bool is_input = param_or_ret == ParamOrRetType::kParameter; - auto stage = current_function_ - ? current_function_->Declaration()->PipelineStage() - : ast::PipelineStage::kNone; if (!ValidateLocationAttribute(location, ty, locations, stage, source, is_input)) { return false; @@ -2099,19 +2098,19 @@ bool Resolver::ValidateStructure(const sem::Struct* str) { return false; } + auto stage = current_function_ + ? current_function_->Declaration()->PipelineStage() + : ast::PipelineStage::kNone; if (auto* invariant = attr->As()) { invariant_attribute = invariant; } else if (auto* location = attr->As()) { has_location = true; - auto stage = current_function_ - ? current_function_->Declaration()->PipelineStage() - : ast::PipelineStage::kNone; if (!ValidateLocationAttribute(location, member->Type(), locations, stage, member->Declaration()->source)) { return false; } } else if (auto* builtin = attr->As()) { - if (!ValidateBuiltinAttribute(builtin, member->Type(), + if (!ValidateBuiltinAttribute(builtin, member->Type(), stage, /* is_input */ false)) { return false; }