tint/resolver: Ensure that total workgroup size fits in u32
This is below the 256x256x64 limits as defined by the WebGPU spec: https://gpuweb.github.io/gpuweb/#limits Fixed: tint:1692 Change-Id: I3608eb41094fbc7c77a40ea32f0f7418c31e0a05 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105401 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: David Neto <dneto@google.com> Commit-Queue: David Neto <dneto@google.com> Auto-Submit: Ben Clayton <bclayton@google.com> Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
parent
fafeb9a327
commit
78c839be97
|
@ -767,6 +767,77 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_NestedZeroValueConstr
|
|||
EXPECT_EQ(r()->error(), "12:34 error: workgroup_size argument must be at least 1");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_0x100_0x100) {
|
||||
// @compute @workgroup_size(0x10000, 0x100, 0x100)
|
||||
// fn main() {}
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(0x10000_a, 0x100_a, Expr(Source{{12, 34}}, 0x100_a)),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_0x10000) {
|
||||
// @compute @workgroup_size(0x10000, 0x10000)
|
||||
// fn main() {}
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(0x10000_a, Expr(Source{{12, 34}}, 0x10000_a)),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_C_0x10000) {
|
||||
// const C = 1;
|
||||
// @compute @workgroup_size(0x10000, C, 0x10000)
|
||||
// fn main() {}
|
||||
GlobalConst("C", ty.u32(), Expr(1_a));
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(0x10000_a, "C", Expr(Source{{12, 34}}, 0x10000_a)),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_C) {
|
||||
// const C = 0x10000;
|
||||
// @compute @workgroup_size(0x10000, C)
|
||||
// fn main() {}
|
||||
GlobalConst("C", ty.u32(), Expr(0x10000_a));
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(0x10000_a, Expr(Source{{12, 34}}, "C")),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_OverflowsU32_0x10000_O_0x10000) {
|
||||
// override O = 0;
|
||||
// @compute @workgroup_size(0x10000, O, 0x10000)
|
||||
// fn main() {}
|
||||
Override("O", ty.u32(), Expr(0_a));
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(0x10000_a, "O", Expr(Source{{12, 34}}, 0x10000_a)),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: total workgroup grid size cannot exceed 0xffffffff");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
|
||||
// var<private> x = 64i;
|
||||
// @compute @workgroup_size(x)
|
||||
|
|
|
@ -1125,6 +1125,15 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
|
|||
}
|
||||
}
|
||||
|
||||
uint64_t total_size = static_cast<uint64_t>(ws[0].value_or(1));
|
||||
for (size_t i = 1; i < 3; i++) {
|
||||
total_size *= static_cast<uint64_t>(ws[i].value_or(1));
|
||||
if (total_size > 0xffffffff) {
|
||||
AddError("total workgroup grid size cannot exceed 0xffffffff", values[i]->source);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
current_function_->SetWorkgroupSize(std::move(ws));
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue