diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 519baec8da..d80973dd64 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -699,7 +699,10 @@ sem::Function* Resolver::Function(const ast::Function* decl) { return nullptr; } - if (!ValidateFunction(func)) { + auto stage = current_function_ + ? current_function_->Declaration()->PipelineStage() + : ast::PipelineStage::kNone; + if (!ValidateFunction(func, stage)) { return nullptr; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 04b493a04c..21dc853a8d 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -260,7 +260,7 @@ class Resolver { bool ValidateEntryPoint(const sem::Function* func, ast::PipelineStage stage); bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); bool ValidateFallthroughStatement(const sem::Statement* stmt); - bool ValidateFunction(const sem::Function* func); + bool ValidateFunction(const sem::Function* func, ast::PipelineStage stage); bool ValidateFunctionCall(const sem::Call* call); bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateIfStatement(const sem::IfStatement* stmt); diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc index f689108e87..8fa86de169 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/resolver_validation.cc @@ -832,7 +832,8 @@ bool Resolver::ValidateInterpolateAttribute( return true; } -bool Resolver::ValidateFunction(const sem::Function* func) { +bool Resolver::ValidateFunction(const sem::Function* func, + ast::PipelineStage stage) { auto* decl = func->Declaration(); auto name = builder_->Symbols().NameFor(decl->symbol); @@ -917,9 +918,6 @@ bool Resolver::ValidateFunction(const sem::Function* func) { } if (decl->IsEntryPoint()) { - auto stage = current_function_ - ? current_function_->Declaration()->PipelineStage() - : ast::PipelineStage::kNone; if (!ValidateEntryPoint(func, stage)) { return false; }