tint/resolver: Ensure that total workgroup size fits in u32

This is below the 256x256x64 limits as defined by the WebGPU spec:
https://gpuweb.github.io/gpuweb/#limits

Fixed: tint:1692
Change-Id: I3608eb41094fbc7c77a40ea32f0f7418c31e0a05
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105401
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton 2022-10-12 17:03:57 +00:00 committed by Dawn LUCI CQ
parent fafeb9a327
commit 78c839be97
2 changed files with 80 additions and 0 deletions

View File

@ -767,6 +767,77 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_NestedZeroValueConstr
EXPECT_EQ(r()->error(), "12:34 error: workgroup_size argument must be at least 1");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_0x100_0x100) {
// @compute @workgroup_size(0x10000, 0x100, 0x100)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(0x10000_a, 0x100_a, Expr(Source{{12, 34}}, 0x100_a)),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_0x10000) {
// @compute @workgroup_size(0x10000, 0x10000)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(0x10000_a, Expr(Source{{12, 34}}, 0x10000_a)),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_C_0x10000) {
// const C = 1;
// @compute @workgroup_size(0x10000, C, 0x10000)
// fn main() {}
GlobalConst("C", ty.u32(), Expr(1_a));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(0x10000_a, "C", Expr(Source{{12, 34}}, 0x10000_a)),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_C) {
// const C = 0x10000;
// @compute @workgroup_size(0x10000, C)
// fn main() {}
GlobalConst("C", ty.u32(), Expr(0x10000_a));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(0x10000_a, Expr(Source{{12, 34}}, "C")),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_O_0x10000) {
// override O = 0;
// @compute @workgroup_size(0x10000, O, 0x10000)
// fn main() {}
Override("O", ty.u32(), Expr(0_a));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(0x10000_a, "O", Expr(Source{{12, 34}}, 0x10000_a)),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
// var<private> x = 64i;
// @compute @workgroup_size(x)

View File

@ -1125,6 +1125,15 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
}
}
uint64_t total_size = static_cast<uint64_t>(ws[0].value_or(1));
for (size_t i = 1; i < 3; i++) {
total_size *= static_cast<uint64_t>(ws[i].value_or(1));
if (total_size > 0xffffffff) {
AddError("total workgroup grid size cannot exceed 0xffffffff", values[i]->source);
return false;
}
}
current_function_->SetWorkgroupSize(std::move(ws));
return true;
}