writer/hlsl: Support overridable workgroup sizes
Use the WGSL_SPEC_CONSTANT preprocessor macros as parameters to [numthreads()] when the dimension is overridable. Remove the macro #undef to make this possible. Bug: tint:713 Change-Id: Icd927044a64a8b8a2f029f9e2db8168ec6a861de Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51264 Commit-Queue: James Price <jrprice@google.com> Auto-Submit: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
0a1927b4c3
commit
3d338315b3
|
@ -47,6 +47,7 @@ const char kOutStructNameSuffix[] = "out";
|
||||||
const char kTintStructInVarPrefix[] = "tint_in";
|
const char kTintStructInVarPrefix[] = "tint_in";
|
||||||
const char kTintStructOutVarPrefix[] = "tint_out";
|
const char kTintStructOutVarPrefix[] = "tint_out";
|
||||||
const char kTempNamePrefix[] = "tint_tmp";
|
const char kTempNamePrefix[] = "tint_tmp";
|
||||||
|
const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
||||||
|
|
||||||
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
|
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
|
||||||
if (stmts->empty()) {
|
if (stmts->empty()) {
|
||||||
|
@ -1994,18 +1995,26 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
|
||||||
auto* func_sem = builder_.Sem().Get(func);
|
auto* func_sem = builder_.Sem().Get(func);
|
||||||
|
|
||||||
if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
|
if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
|
||||||
|
// Emit the workgroup_size attribute.
|
||||||
auto wgsize = func_sem->workgroup_size();
|
auto wgsize = func_sem->workgroup_size();
|
||||||
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
|
out << "[numthreads(";
|
||||||
wgsize[2].overridable_const) {
|
for (int i = 0; i < 3; i++) {
|
||||||
// TODO(crbug.com/tint/713): Handle overridable constants.
|
if (i > 0) {
|
||||||
TINT_UNIMPLEMENTED(builder_.Diagnostics())
|
out << ", ";
|
||||||
<< "pipeline-overridable workgroup sizes are not implemented";
|
|
||||||
}
|
}
|
||||||
uint32_t x = wgsize[0].value;
|
|
||||||
uint32_t y = wgsize[1].value;
|
if (wgsize[i].overridable_const) {
|
||||||
uint32_t z = wgsize[2].value;
|
auto* sem_const = builder_.Sem().Get(wgsize[i].overridable_const);
|
||||||
out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y)
|
if (!sem_const->IsPipelineConstant()) {
|
||||||
<< ", " << std::to_string(z) << ")]" << std::endl;
|
TINT_ICE(builder_.Diagnostics())
|
||||||
|
<< "expected a pipeline-overridable constant";
|
||||||
|
}
|
||||||
|
out << kSpecConstantPrefix << sem_const->ConstantId();
|
||||||
|
} else {
|
||||||
|
out << std::to_string(wgsize[i].value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out << ")]" << std::endl;
|
||||||
make_indent(out);
|
make_indent(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2721,10 +2730,10 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
|
||||||
if (sem->IsPipelineConstant()) {
|
if (sem->IsPipelineConstant()) {
|
||||||
auto const_id = sem->ConstantId();
|
auto const_id = sem->ConstantId();
|
||||||
|
|
||||||
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
|
out << "#ifndef " << kSpecConstantPrefix << const_id << std::endl;
|
||||||
|
|
||||||
if (var->constructor() != nullptr) {
|
if (var->constructor() != nullptr) {
|
||||||
out << "#define WGSL_SPEC_CONSTANT_" << const_id << " "
|
out << "#define " << kSpecConstantPrefix << const_id << " "
|
||||||
<< constructor_out.str() << std::endl;
|
<< constructor_out.str() << std::endl;
|
||||||
} else {
|
} else {
|
||||||
out << "#error spec constant required for constant id " << const_id
|
out << "#error spec constant required for constant id " << const_id
|
||||||
|
@ -2736,9 +2745,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
|
||||||
builder_.Symbols().NameFor(var->symbol()))) {
|
builder_.Symbols().NameFor(var->symbol()))) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
out << " " << builder_.Symbols().NameFor(var->symbol())
|
out << " " << builder_.Symbols().NameFor(var->symbol()) << " = "
|
||||||
<< " = WGSL_SPEC_CONSTANT_" << const_id << ";" << std::endl;
|
<< kSpecConstantPrefix << const_id << ";" << std::endl;
|
||||||
out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
|
|
||||||
} else {
|
} else {
|
||||||
out << "static const ";
|
out << "static const ";
|
||||||
if (!EmitType(out, type, sem->StorageClass(), sem->AccessControl(),
|
if (!EmitType(out, type, sem->StorageClass(), sem->AccessControl(),
|
||||||
|
|
|
@ -919,11 +919,8 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Function,
|
TEST_F(HlslGeneratorImplTest_Function,
|
||||||
Emit_Decoration_EntryPoint_Compute_WithWorkgroup) {
|
Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Literal) {
|
||||||
Func("main", ast::VariableList{}, ty.void_(),
|
Func("main", ast::VariableList{}, ty.void_(), {},
|
||||||
{
|
|
||||||
Return(),
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Stage(ast::PipelineStage::kCompute),
|
Stage(ast::PipelineStage::kCompute),
|
||||||
WorkgroupSize(2, 4, 6),
|
WorkgroupSize(2, 4, 6),
|
||||||
|
@ -942,6 +939,69 @@ void main() {
|
||||||
Validate();
|
Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HlslGeneratorImplTest_Function,
|
||||||
|
Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Const) {
|
||||||
|
GlobalConst("width", ty.i32(), Construct(ty.i32(), 2));
|
||||||
|
GlobalConst("height", ty.i32(), Construct(ty.i32(), 3));
|
||||||
|
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4));
|
||||||
|
Func("main", ast::VariableList{}, ty.void_(), {},
|
||||||
|
{
|
||||||
|
Stage(ast::PipelineStage::kCompute),
|
||||||
|
WorkgroupSize("width", "height", "depth"),
|
||||||
|
});
|
||||||
|
|
||||||
|
GeneratorImpl& gen = Build();
|
||||||
|
|
||||||
|
ASSERT_TRUE(gen.Generate(out)) << gen.error();
|
||||||
|
EXPECT_EQ(result(), R"(static const int width = int(2);
|
||||||
|
static const int height = int(3);
|
||||||
|
static const int depth = int(4);
|
||||||
|
[numthreads(2, 3, 4)]
|
||||||
|
void main() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
|
||||||
|
Validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HlslGeneratorImplTest_Function,
|
||||||
|
Emit_Decoration_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
|
||||||
|
GlobalConst("width", ty.i32(), Construct(ty.i32(), 2), {Override(7u)});
|
||||||
|
GlobalConst("height", ty.i32(), Construct(ty.i32(), 3), {Override(8u)});
|
||||||
|
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4), {Override(9u)});
|
||||||
|
Func("main", ast::VariableList{}, ty.void_(), {},
|
||||||
|
{
|
||||||
|
Stage(ast::PipelineStage::kCompute),
|
||||||
|
WorkgroupSize("width", "height", "depth"),
|
||||||
|
});
|
||||||
|
|
||||||
|
GeneratorImpl& gen = Build();
|
||||||
|
|
||||||
|
ASSERT_TRUE(gen.Generate(out)) << gen.error();
|
||||||
|
EXPECT_EQ(result(), R"(#ifndef WGSL_SPEC_CONSTANT_7
|
||||||
|
#define WGSL_SPEC_CONSTANT_7 int(2)
|
||||||
|
#endif
|
||||||
|
static const int width = WGSL_SPEC_CONSTANT_7;
|
||||||
|
#ifndef WGSL_SPEC_CONSTANT_8
|
||||||
|
#define WGSL_SPEC_CONSTANT_8 int(3)
|
||||||
|
#endif
|
||||||
|
static const int height = WGSL_SPEC_CONSTANT_8;
|
||||||
|
#ifndef WGSL_SPEC_CONSTANT_9
|
||||||
|
#define WGSL_SPEC_CONSTANT_9 int(4)
|
||||||
|
#endif
|
||||||
|
static const int depth = WGSL_SPEC_CONSTANT_9;
|
||||||
|
[numthreads(WGSL_SPEC_CONSTANT_7, WGSL_SPEC_CONSTANT_8, WGSL_SPEC_CONSTANT_9)]
|
||||||
|
void main() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
|
||||||
|
Validate();
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
|
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
|
||||||
Func("my_func", ast::VariableList{Param("a", ty.array<f32, 5>())}, ty.void_(),
|
Func("my_func", ast::VariableList{Param("a", ty.array<f32, 5>())}, ty.void_(),
|
||||||
{
|
{
|
||||||
|
|
|
@ -45,7 +45,6 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) {
|
||||||
#define WGSL_SPEC_CONSTANT_23 3.0f
|
#define WGSL_SPEC_CONSTANT_23 3.0f
|
||||||
#endif
|
#endif
|
||||||
static const float pos = WGSL_SPEC_CONSTANT_23;
|
static const float pos = WGSL_SPEC_CONSTANT_23;
|
||||||
#undef WGSL_SPEC_CONSTANT_23
|
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +61,6 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) {
|
||||||
#error spec constant required for constant id 23
|
#error spec constant required for constant id 23
|
||||||
#endif
|
#endif
|
||||||
static const float pos = WGSL_SPEC_CONSTANT_23;
|
static const float pos = WGSL_SPEC_CONSTANT_23;
|
||||||
#undef WGSL_SPEC_CONSTANT_23
|
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,12 +82,10 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) {
|
||||||
#define WGSL_SPEC_CONSTANT_0 3.0f
|
#define WGSL_SPEC_CONSTANT_0 3.0f
|
||||||
#endif
|
#endif
|
||||||
static const float a = WGSL_SPEC_CONSTANT_0;
|
static const float a = WGSL_SPEC_CONSTANT_0;
|
||||||
#undef WGSL_SPEC_CONSTANT_0
|
|
||||||
#ifndef WGSL_SPEC_CONSTANT_1
|
#ifndef WGSL_SPEC_CONSTANT_1
|
||||||
#define WGSL_SPEC_CONSTANT_1 2.0f
|
#define WGSL_SPEC_CONSTANT_1 2.0f
|
||||||
#endif
|
#endif
|
||||||
static const float b = WGSL_SPEC_CONSTANT_1;
|
static const float b = WGSL_SPEC_CONSTANT_1;
|
||||||
#undef WGSL_SPEC_CONSTANT_1
|
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue