validation: workgroup memory cannot be used in vertex/fragment stage

Bug: tint:907
Change-Id: I62c528437fabcd1a1f42c46ab6279ef9a14d96d2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55640
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Sarah 2021-06-24 08:56:16 +00:00 committed by Tint LUCI CQ
parent c8434889d8
commit 9548531887
3 changed files with 106 additions and 3 deletions

View File

@ -2925,6 +2925,60 @@ void Resolver::SetType(const ast::Expression* expr,
}
bool Resolver::ValidatePipelineStages() {
auto check_workgroup_storage = [&](FunctionInfo* func,
FunctionInfo* entry_point) {
auto stage = entry_point->declaration->pipeline_stage();
if (stage != ast::PipelineStage::kCompute) {
for (auto* var : func->local_referenced_module_vars) {
if (var->storage_class == ast::StorageClass::kWorkgroup) {
std::stringstream stage_name;
stage_name << stage;
for (auto* user : var->users) {
auto it = expr_info_.find(user->As<ast::Expression>());
if (it != expr_info_.end()) {
if (func->declaration->symbol() ==
it->second.statement->Function()->symbol()) {
diagnostics_.add_error("workgroup memory cannot be used by " +
stage_name.str() + " pipeline stage",
user->source());
break;
}
}
}
diagnostics_.add_note("variable is declared here",
var->declaration->source());
if (func != entry_point) {
TraverseCallChain(entry_point, func, [&](FunctionInfo* f) {
diagnostics_.add_note(
"called by function '" +
builder_->Symbols().NameFor(f->declaration->symbol()) +
"'",
f->declaration->source());
});
diagnostics_.add_note("called by entry point '" +
builder_->Symbols().NameFor(
entry_point->declaration->symbol()) +
"'",
entry_point->declaration->source());
}
return false;
}
}
}
return true;
};
for (auto* entry_point : entry_points_) {
if (!check_workgroup_storage(entry_point, entry_point)) {
return false;
}
for (auto* func : entry_point->transitive_calls) {
if (!check_workgroup_storage(func, entry_point)) {
return false;
}
}
}
auto check_intrinsic_calls = [&](FunctionInfo* func,
FunctionInfo* entry_point) {
auto stage = entry_point->declaration->pipeline_stage();

View File

@ -63,6 +63,55 @@ class FakeExpr : public ast::Expression {
void to_str(const sem::Info&, std::ostream&, size_t) const override {}
};
TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInVertexStage) {
Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
Global("dst", ty.vec4<f32>(), ast::StorageClass::kPrivate);
auto* stmt = Assign(Expr("dst"), Expr(Source{{3, 4}}, "wg"));
Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4<f32>(),
{stmt, Return(Expr("dst"))},
ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"3:4 error: workgroup memory cannot be used by vertex pipeline "
"stage\n1:2 note: variable is declared here");
}
TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) {
// var<workgroup> wg : vec4<f32>;
// var<workgroup> dst : vec4<f32>;
// fn f2(){ dst = wg; }
// fn f1() { f2(); }
// [[stage(fragment)]]
// fn f0() -> [[builtin(position)]] vec4<f32> {
// f1();
// return dst;
//}
Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
Global("dst", ty.vec4<f32>(), ast::StorageClass::kPrivate);
auto* stmt = Assign(Expr("dst"), Expr(Source{{3, 4}}, "wg"));
Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt});
Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(),
{Ignore(Call("f2"))});
Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4<f32>(),
{Ignore(Call("f1")), Return(Expr("dst"))},
ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(3:4 error: workgroup memory cannot be used by fragment pipeline stage
1:2 note: variable is declared here
5:6 note: called by function 'f2'
7:8 note: called by function 'f1'
9:10 note: called by entry point 'f0')");
}
TEST_F(ResolverValidationTest, Error_WithEmptySource) {
auto* s = create<FakeStmt>();
WrapInFunction(s);

View File

@ -129,7 +129,7 @@ OpName %11 "main"
TEST_F(BuilderTest, Decoration_Stage_WithUsedInterfaceIds) {
auto* v_in = Global("my_in", ty.f32(), ast::StorageClass::kInput);
auto* v_out = Global("my_out", ty.f32(), ast::StorageClass::kOutput);
auto* v_wg = Global("my_wg", ty.f32(), ast::StorageClass::kWorkgroup);
auto* v_wg = Global("my_wg", ty.f32(), ast::StorageClass::kPrivate);
auto* func = Func(
"main", {}, ty.void_(),
@ -159,8 +159,8 @@ OpName %11 "main"
%5 = OpTypePointer Output %3
%6 = OpConstantNull %3
%4 = OpVariable %5 Output %6
%8 = OpTypePointer Workgroup %3
%7 = OpVariable %8 Workgroup
%8 = OpTypePointer Private %3
%7 = OpVariable %8 Private %6
%10 = OpTypeVoid
%9 = OpTypeFunction %10
)");