diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 22010f1806..8761eaa5dd 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2925,6 +2925,60 @@ void Resolver::SetType(const ast::Expression* expr, } bool Resolver::ValidatePipelineStages() { + auto check_workgroup_storage = [&](FunctionInfo* func, + FunctionInfo* entry_point) { + auto stage = entry_point->declaration->pipeline_stage(); + if (stage != ast::PipelineStage::kCompute) { + for (auto* var : func->local_referenced_module_vars) { + if (var->storage_class == ast::StorageClass::kWorkgroup) { + std::stringstream stage_name; + stage_name << stage; + for (auto* user : var->users) { + auto it = expr_info_.find(user->As()); + if (it != expr_info_.end()) { + if (func->declaration->symbol() == + it->second.statement->Function()->symbol()) { + diagnostics_.add_error("workgroup memory cannot be used by " + + stage_name.str() + " pipeline stage", + user->source()); + break; + } + } + } + diagnostics_.add_note("variable is declared here", + var->declaration->source()); + if (func != entry_point) { + TraverseCallChain(entry_point, func, [&](FunctionInfo* f) { + diagnostics_.add_note( + "called by function '" + + builder_->Symbols().NameFor(f->declaration->symbol()) + + "'", + f->declaration->source()); + }); + diagnostics_.add_note("called by entry point '" + + builder_->Symbols().NameFor( + entry_point->declaration->symbol()) + + "'", + entry_point->declaration->source()); + } + return false; + } + } + } + return true; + }; + + for (auto* entry_point : entry_points_) { + if (!check_workgroup_storage(entry_point, entry_point)) { + return false; + } + for (auto* func : entry_point->transitive_calls) { + if (!check_workgroup_storage(func, entry_point)) { + return false; + } + } + } + auto check_intrinsic_calls = [&](FunctionInfo* func, FunctionInfo* entry_point) { auto stage = entry_point->declaration->pipeline_stage(); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 600a588bed..dca2f71f7d 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -63,6 +63,55 @@ class FakeExpr : public ast::Expression { void to_str(const sem::Info&, std::ostream&, size_t) const override {} }; +TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInVertexStage) { + Global(Source{{1, 2}}, "wg", ty.vec4(), ast::StorageClass::kWorkgroup); + Global("dst", ty.vec4(), ast::StorageClass::kPrivate); + auto* stmt = Assign(Expr("dst"), Expr(Source{{3, 4}}, "wg")); + + Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4(), + {stmt, Return(Expr("dst"))}, + ast::DecorationList{Stage(ast::PipelineStage::kVertex)}, + ast::DecorationList{Builtin(ast::Builtin::kPosition)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "3:4 error: workgroup memory cannot be used by vertex pipeline " + "stage\n1:2 note: variable is declared here"); +} + +TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) { + // var wg : vec4; + // var dst : vec4; + // fn f2(){ dst = wg; } + // fn f1() { f2(); } + // [[stage(fragment)]] + // fn f0() -> [[builtin(position)]] vec4 { + // f1(); + // return dst; + //} + + Global(Source{{1, 2}}, "wg", ty.vec4(), ast::StorageClass::kWorkgroup); + Global("dst", ty.vec4(), ast::StorageClass::kPrivate); + auto* stmt = Assign(Expr("dst"), Expr(Source{{3, 4}}, "wg")); + + Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt}); + Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(), + {Ignore(Call("f2"))}); + Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4(), + {Ignore(Call("f1")), Return(Expr("dst"))}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, + ast::DecorationList{Builtin(ast::Builtin::kPosition)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(3:4 error: workgroup memory cannot be used by fragment pipeline stage +1:2 note: variable is declared here +5:6 note: called by function 'f2' +7:8 note: called by function 'f1' +9:10 note: called by entry point 'f0')"); +} + TEST_F(ResolverValidationTest, Error_WithEmptySource) { auto* s = create(); WrapInFunction(s); diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc index c2f8b5ede7..a642443a01 100644 --- a/src/writer/spirv/builder_function_decoration_test.cc +++ b/src/writer/spirv/builder_function_decoration_test.cc @@ -129,7 +129,7 @@ OpName %11 "main" TEST_F(BuilderTest, Decoration_Stage_WithUsedInterfaceIds) { auto* v_in = Global("my_in", ty.f32(), ast::StorageClass::kInput); auto* v_out = Global("my_out", ty.f32(), ast::StorageClass::kOutput); - auto* v_wg = Global("my_wg", ty.f32(), ast::StorageClass::kWorkgroup); + auto* v_wg = Global("my_wg", ty.f32(), ast::StorageClass::kPrivate); auto* func = Func( "main", {}, ty.void_(), @@ -159,8 +159,8 @@ OpName %11 "main" %5 = OpTypePointer Output %3 %6 = OpConstantNull %3 %4 = OpVariable %5 Output %6 -%8 = OpTypePointer Workgroup %3 -%7 = OpVariable %8 Workgroup +%8 = OpTypePointer Private %3 +%7 = OpVariable %8 Private %6 %10 = OpTypeVoid %9 = OpTypeFunction %10 )");