diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index bbe4bc09b6..defa5dac57 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -257,7 +257,7 @@ class Resolver { bool ValidateContinueStatement(const sem::Statement* stmt); bool ValidateDiscardStatement(const sem::Statement* stmt); bool ValidateElseStatement(const sem::ElseStatement* stmt); - bool ValidateEntryPoint(const sem::Function* func); + bool ValidateEntryPoint(const sem::Function* func, ast::PipelineStage stage); bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); bool ValidateFallthroughStatement(const sem::Statement* stmt); bool ValidateFunction(const sem::Function* func); diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc index 476b008fd1..2119433605 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/resolver_validation.cc @@ -917,7 +917,10 @@ bool Resolver::ValidateFunction(const sem::Function* func) { } if (decl->IsEntryPoint()) { - if (!ValidateEntryPoint(func)) { + auto stage = current_function_ + ? current_function_->Declaration()->PipelineStage() + : ast::PipelineStage::kNone; + if (!ValidateEntryPoint(func, stage)) { return false; } } @@ -937,7 +940,8 @@ bool Resolver::ValidateFunction(const sem::Function* func) { return true; } -bool Resolver::ValidateEntryPoint(const sem::Function* func) { +bool Resolver::ValidateEntryPoint(const sem::Function* func, + ast::PipelineStage stage) { auto* decl = func->Declaration(); // Use a lambda to validate the entry point attributes for a type. @@ -959,9 +963,6 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func) { Source source, ParamOrRetType param_or_ret, bool is_struct_member) { - auto stage = current_function_ - ? current_function_->Declaration()->PipelineStage() - : ast::PipelineStage::kNone; // Scan attributes for pipeline IO attributes. // Check for overlap with attributes that have been seen previously. const ast::Attribute* pipeline_io_attribute = nullptr;