diff --git a/src/tint/reader/spirv/enum_converter.cc b/src/tint/reader/spirv/enum_converter.cc index d5905077e1..8c701565a5 100644 --- a/src/tint/reader/spirv/enum_converter.cc +++ b/src/tint/reader/spirv/enum_converter.cc @@ -82,6 +82,8 @@ ast::BuiltinValue EnumConverter::ToBuiltin(spv::BuiltIn b) { return ast::BuiltinValue::kLocalInvocationIndex; case spv::BuiltIn::GlobalInvocationId: return ast::BuiltinValue::kGlobalInvocationId; + case spv::BuiltIn::NumWorkgroups: + return ast::BuiltinValue::kNumWorkgroups; case spv::BuiltIn::WorkgroupId: return ast::BuiltinValue::kWorkgroupId; case spv::BuiltIn::SampleId: diff --git a/src/tint/reader/spirv/enum_converter_test.cc b/src/tint/reader/spirv/enum_converter_test.cc index b366ebd1e2..bdb0247863 100644 --- a/src/tint/reader/spirv/enum_converter_test.cc +++ b/src/tint/reader/spirv/enum_converter_test.cc @@ -192,6 +192,7 @@ INSTANTIATE_TEST_SUITE_P( BuiltinCase{spv::BuiltIn::LocalInvocationIndex, true, ast::BuiltinValue::kLocalInvocationIndex}, BuiltinCase{spv::BuiltIn::GlobalInvocationId, true, ast::BuiltinValue::kGlobalInvocationId}, + BuiltinCase{spv::BuiltIn::NumWorkgroups, true, ast::BuiltinValue::kNumWorkgroups}, BuiltinCase{spv::BuiltIn::WorkgroupId, true, ast::BuiltinValue::kWorkgroupId}, BuiltinCase{spv::BuiltIn::SampleId, true, ast::BuiltinValue::kSampleIndex}, BuiltinCase{spv::BuiltIn::SampleMask, true, ast::BuiltinValue::kSampleMask})); @@ -208,8 +209,6 @@ INSTANTIATE_TEST_SUITE_P(EnumConverterBad, testing::Values(BuiltinCase{static_cast(9999), false, ast::BuiltinValue::kUndefined}, BuiltinCase{static_cast(9999), false, - ast::BuiltinValue::kUndefined}, - BuiltinCase{spv::BuiltIn::NumWorkgroups, false, ast::BuiltinValue::kUndefined})); // Dim diff --git a/src/tint/reader/spirv/parser_impl_module_var_test.cc b/src/tint/reader/spirv/parser_impl_module_var_test.cc index 44fc4b3ba1..be931a6815 100644 --- a/src/tint/reader/spirv/parser_impl_module_var_test.cc +++ b/src/tint/reader/spirv/parser_impl_module_var_test.cc @@ -3106,6 +3106,14 @@ TEST_F(SpvModuleScopeVarParserTest, InstanceIndex_U32_FunctParam) { // Returns the start of a shader for testing LocalInvocationIndex, // parameterized by store type of %int or %uint std::string ComputeBuiltinInputPreamble(std::string builtin, std::string store_type) { + std::string ptr_component_type; + if (store_type == "%v3int") { + ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %int\n"; + } + if (store_type == "%v3uint") { + ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %uint\n"; + } + return R"( OpCapability Shader OpMemoryModel Logical Simple @@ -3118,10 +3126,11 @@ std::string ComputeBuiltinInputPreamble(std::string builtin, std::string store_t %float = OpTypeFloat 32 %uint = OpTypeInt 32 0 %int = OpTypeInt 32 1 + %int_1 = OpConstant %int 1 %v3uint = OpTypeVector %uint 3 %v3int = OpTypeVector %int 3 %ptr_ty = OpTypePointer Input )" + - store_type + R"( + store_type + ptr_component_type + R"( %1 = OpVariable %ptr_ty Input )"; } @@ -3331,14 +3340,84 @@ INSTANTIATE_TEST_SUITE_P(Samples, {"LocalInvocationId", "%v3int", "local_invocation_id"}, {"GlobalInvocationId", "%v3uint", "global_invocation_id"}, {"GlobalInvocationId", "%v3int", "global_invocation_id"}, + {"NumWorkgroups", "%v3uint", "num_workgroups"}, + {"NumWorkgroups", "%v3int", "num_workgroups"}, {"WorkgroupId", "%v3uint", "workgroup_id"}, {"WorkgroupId", "%v3int", "workgroup_id"}})); -// TODO(dneto): crbug.com/tint/752 -// NumWorkgroups support is blocked by crbug.com/tint/752 -// When the AST supports NumWorkgroups, add these cases: -// {"NumWorkgroups", "%uint", "num_workgroups"} -// {"NumWorkgroups", "%int", "num_workgroups"} +// For compute shader builtins that are vectors, test loading one component. +struct ComputeBuiltinInputVectorCase { + std::string spirv_builtin; + std::string spirv_store_type; + std::string spirv_component_store_type; + std::string wgsl_builtin; +}; +inline std::ostream& operator<<(std::ostream& o, ComputeBuiltinInputVectorCase c) { + return o << "ComputeBuiltinInputVectorCase(" << c.spirv_builtin << " " << c.spirv_store_type + << " " << c.spirv_component_store_type << " " << c.wgsl_builtin << ")"; +} + +using SpvModuleScopeVarParserTest_ComputeBuiltinVector = + SpvParserTestBase<::testing::TestWithParam>; + +TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltinVector, Load_Component_Direct) { + const auto wgsl_type = WgslType(GetParam().spirv_store_type); + const auto wgsl_component_type = WgslType(GetParam().spirv_component_store_type); + const auto wgsl_builtin = GetParam().wgsl_builtin; + const auto unsigned_wgsl_type = UnsignedWgslType(wgsl_type); + const auto signed_wgsl_type = SignedWgslType(wgsl_type); + const std::string assembly = + ComputeBuiltinInputPreamble(GetParam().spirv_builtin, GetParam().spirv_store_type) + + R"( + %main = OpFunction %void None %voidfn + %entry = OpLabel + %3 = OpAccessChain %ptr_comp_ty %1 %int_1 + %2 = OpLoad )" + + GetParam().spirv_component_store_type + R"( %3 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + const auto module_str = test::ToString(p->program()); + std::string expected = R"(var x_1 : ${wgsl_type}; + +fn main_1() { + let x_2 : ${wgsl_component_type} = x_1.y; + return; +} + +@compute @workgroup_size(1i, 1i, 1i) +fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) { + x_1 = ${assignment_value}; + main_1(); +} +)"; + + expected = utils::ReplaceAll(expected, "${wgsl_type}", wgsl_type); + expected = utils::ReplaceAll(expected, "${wgsl_component_type}", wgsl_component_type); + expected = utils::ReplaceAll(expected, "${unsigned_wgsl_type}", unsigned_wgsl_type); + expected = utils::ReplaceAll(expected, "${wgsl_builtin}", wgsl_builtin); + expected = utils::ReplaceAll(expected, "${assignment_value}", + (wgsl_type == unsigned_wgsl_type) + ? "x_1_param" + : "bitcast<" + signed_wgsl_type + ">(x_1_param)"); + + EXPECT_EQ(module_str, expected) << module_str; +} + +INSTANTIATE_TEST_SUITE_P(Samples, + SpvModuleScopeVarParserTest_ComputeBuiltinVector, + ::testing::ValuesIn(std::vector{ + {"LocalInvocationId", "%v3uint", "%uint", "local_invocation_id"}, + {"LocalInvocationId", "%v3int", "%int", "local_invocation_id"}, + {"GlobalInvocationId", "%v3uint", "%uint", "global_invocation_id"}, + {"GlobalInvocationId", "%v3int", "%int", "global_invocation_id"}, + {"NumWorkgroups", "%v3uint", "%uint", "num_workgroups"}, + {"NumWorkgroups", "%v3int", "%int", "num_workgroups"}, + {"WorkgroupId", "%v3uint", "%uint", "workgroup_id"}, + {"WorkgroupId", "%v3int", "%int", "workgroup_id"}})); TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) { const std::string assembly =