spirv-reader: support NumWorkgroups

Fixed: tint:1065
Change-Id: Id2a8af247e7da79933703e634478f1dec25f9145
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110220
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Auto-Submit: David Neto <dneto@google.com>
This commit is contained in:
David Neto 2022-11-15 01:32:17 +00:00 committed by Dawn LUCI CQ
parent 267f1748c8
commit 7f06aa06ac
3 changed files with 88 additions and 8 deletions

View File

@ -82,6 +82,8 @@ ast::BuiltinValue EnumConverter::ToBuiltin(spv::BuiltIn b) {
return ast::BuiltinValue::kLocalInvocationIndex;
case spv::BuiltIn::GlobalInvocationId:
return ast::BuiltinValue::kGlobalInvocationId;
case spv::BuiltIn::NumWorkgroups:
return ast::BuiltinValue::kNumWorkgroups;
case spv::BuiltIn::WorkgroupId:
return ast::BuiltinValue::kWorkgroupId;
case spv::BuiltIn::SampleId:

View File

@ -192,6 +192,7 @@ INSTANTIATE_TEST_SUITE_P(
BuiltinCase{spv::BuiltIn::LocalInvocationIndex, true,
ast::BuiltinValue::kLocalInvocationIndex},
BuiltinCase{spv::BuiltIn::GlobalInvocationId, true, ast::BuiltinValue::kGlobalInvocationId},
BuiltinCase{spv::BuiltIn::NumWorkgroups, true, ast::BuiltinValue::kNumWorkgroups},
BuiltinCase{spv::BuiltIn::WorkgroupId, true, ast::BuiltinValue::kWorkgroupId},
BuiltinCase{spv::BuiltIn::SampleId, true, ast::BuiltinValue::kSampleIndex},
BuiltinCase{spv::BuiltIn::SampleMask, true, ast::BuiltinValue::kSampleMask}));
@ -208,8 +209,6 @@ INSTANTIATE_TEST_SUITE_P(EnumConverterBad,
testing::Values(BuiltinCase{static_cast<spv::BuiltIn>(9999), false,
ast::BuiltinValue::kUndefined},
BuiltinCase{static_cast<spv::BuiltIn>(9999), false,
ast::BuiltinValue::kUndefined},
BuiltinCase{spv::BuiltIn::NumWorkgroups, false,
ast::BuiltinValue::kUndefined}));
// Dim

View File

@ -3106,6 +3106,14 @@ TEST_F(SpvModuleScopeVarParserTest, InstanceIndex_U32_FunctParam) {
// Returns the start of a shader for testing LocalInvocationIndex,
// parameterized by store type of %int or %uint
std::string ComputeBuiltinInputPreamble(std::string builtin, std::string store_type) {
std::string ptr_component_type;
if (store_type == "%v3int") {
ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %int\n";
}
if (store_type == "%v3uint") {
ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %uint\n";
}
return R"(
OpCapability Shader
OpMemoryModel Logical Simple
@ -3118,10 +3126,11 @@ std::string ComputeBuiltinInputPreamble(std::string builtin, std::string store_t
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
%int_1 = OpConstant %int 1
%v3uint = OpTypeVector %uint 3
%v3int = OpTypeVector %int 3
%ptr_ty = OpTypePointer Input )" +
store_type + R"(
store_type + ptr_component_type + R"(
%1 = OpVariable %ptr_ty Input
)";
}
@ -3331,14 +3340,84 @@ INSTANTIATE_TEST_SUITE_P(Samples,
{"LocalInvocationId", "%v3int", "local_invocation_id"},
{"GlobalInvocationId", "%v3uint", "global_invocation_id"},
{"GlobalInvocationId", "%v3int", "global_invocation_id"},
{"NumWorkgroups", "%v3uint", "num_workgroups"},
{"NumWorkgroups", "%v3int", "num_workgroups"},
{"WorkgroupId", "%v3uint", "workgroup_id"},
{"WorkgroupId", "%v3int", "workgroup_id"}}));
// TODO(dneto): crbug.com/tint/752
// NumWorkgroups support is blocked by crbug.com/tint/752
// When the AST supports NumWorkgroups, add these cases:
// {"NumWorkgroups", "%uint", "num_workgroups"}
// {"NumWorkgroups", "%int", "num_workgroups"}
// For compute shader builtins that are vectors, test loading one component.
struct ComputeBuiltinInputVectorCase {
std::string spirv_builtin;
std::string spirv_store_type;
std::string spirv_component_store_type;
std::string wgsl_builtin;
};
inline std::ostream& operator<<(std::ostream& o, ComputeBuiltinInputVectorCase c) {
return o << "ComputeBuiltinInputVectorCase(" << c.spirv_builtin << " " << c.spirv_store_type
<< " " << c.spirv_component_store_type << " " << c.wgsl_builtin << ")";
}
using SpvModuleScopeVarParserTest_ComputeBuiltinVector =
SpvParserTestBase<::testing::TestWithParam<ComputeBuiltinInputVectorCase>>;
TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltinVector, Load_Component_Direct) {
const auto wgsl_type = WgslType(GetParam().spirv_store_type);
const auto wgsl_component_type = WgslType(GetParam().spirv_component_store_type);
const auto wgsl_builtin = GetParam().wgsl_builtin;
const auto unsigned_wgsl_type = UnsignedWgslType(wgsl_type);
const auto signed_wgsl_type = SignedWgslType(wgsl_type);
const std::string assembly =
ComputeBuiltinInputPreamble(GetParam().spirv_builtin, GetParam().spirv_store_type) +
R"(
%main = OpFunction %void None %voidfn
%entry = OpLabel
%3 = OpAccessChain %ptr_comp_ty %1 %int_1
%2 = OpLoad )" +
GetParam().spirv_component_store_type + R"( %3
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
EXPECT_TRUE(p->error().empty());
const auto module_str = test::ToString(p->program());
std::string expected = R"(var<private> x_1 : ${wgsl_type};
fn main_1() {
let x_2 : ${wgsl_component_type} = x_1.y;
return;
}
@compute @workgroup_size(1i, 1i, 1i)
fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) {
x_1 = ${assignment_value};
main_1();
}
)";
expected = utils::ReplaceAll(expected, "${wgsl_type}", wgsl_type);
expected = utils::ReplaceAll(expected, "${wgsl_component_type}", wgsl_component_type);
expected = utils::ReplaceAll(expected, "${unsigned_wgsl_type}", unsigned_wgsl_type);
expected = utils::ReplaceAll(expected, "${wgsl_builtin}", wgsl_builtin);
expected = utils::ReplaceAll(expected, "${assignment_value}",
(wgsl_type == unsigned_wgsl_type)
? "x_1_param"
: "bitcast<" + signed_wgsl_type + ">(x_1_param)");
EXPECT_EQ(module_str, expected) << module_str;
}
INSTANTIATE_TEST_SUITE_P(Samples,
SpvModuleScopeVarParserTest_ComputeBuiltinVector,
::testing::ValuesIn(std::vector<ComputeBuiltinInputVectorCase>{
{"LocalInvocationId", "%v3uint", "%uint", "local_invocation_id"},
{"LocalInvocationId", "%v3int", "%int", "local_invocation_id"},
{"GlobalInvocationId", "%v3uint", "%uint", "global_invocation_id"},
{"GlobalInvocationId", "%v3int", "%int", "global_invocation_id"},
{"NumWorkgroups", "%v3uint", "%uint", "num_workgroups"},
{"NumWorkgroups", "%v3int", "%int", "num_workgroups"},
{"WorkgroupId", "%v3uint", "%uint", "workgroup_id"},
{"WorkgroupId", "%v3int", "%int", "workgroup_id"}}));
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
const std::string assembly =