diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 3137cc42d0..c3a6d9b088 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2605,7 +2605,8 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { default: break; } - return Fail() << "unhandled instruction with opcode " << inst.opcode(); + return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": " + << inst.PrettyPrint(); } TypedExpression FunctionEmitter::MakeOperand( @@ -2703,6 +2704,10 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( return {ast_type, parser_impl_.MakeNullValue(ast_type)}; } + if (opcode == SpvOpSelect) { + return MakeSimpleSelect(inst); + } + // builtin readonly function // glsl.std.450 readonly function @@ -3375,6 +3380,33 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) { return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr)); } +TypedExpression FunctionEmitter::MakeSimpleSelect( + const spvtools::opt::Instruction& inst) { + auto condition = MakeOperand(inst, 0); + auto operand1 = MakeOperand(inst, 1); + auto operand2 = MakeOperand(inst, 2); + + // SPIR-V validation requires: + // - the condition to be bool or bool vector, so we don't check it here. + // - operand1, operand2, and result type to match. + // - you can't select over pointers or pointer vectors, unless you also have + // a VariablePointers* capability, which is not allowed in by WebGPU. + auto* op_ty = operand1.type; + if (op_ty->IsVector() || op_ty->is_float_scalar() || + op_ty->is_integer_scalar() || op_ty->IsBool()) { + ast::ExpressionList params; + params.push_back(std::move(operand1.expr)); + params.push_back(std::move(operand2.expr)); + // The condition goes last. + params.push_back(std::move(condition.expr)); + return {operand1.type, + std::make_unique( + std::make_unique("select"), + std::move(params))}; + } + return {}; +} + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index bde237dc5e..5861570b62 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -676,6 +676,13 @@ class FunctionEmitter { /// @returns false if emission failed bool EmitFunctionCall(const spvtools::opt::Instruction& inst); + /// Returns an expression for an OpSelect, if its operands are scalars + /// or vectors. These translate directly to WGSL select. Otherwise, return + /// an expression with a null owned expression + /// @param inst the SPIR-V OpSelect instruction + /// @returns a typed expression, or one with a null owned expression + TypedExpression MakeSimpleSelect(const spvtools::opt::Instruction& inst); + /// Finds the header block for a structured construct that we can "break" /// out from, from deeply nested control flow, if such a block exists. /// If the construct is: diff --git a/src/reader/spirv/function_logical_test.cc b/src/reader/spirv/function_logical_test.cc index 4a0f1df6d1..8f4272c1f3 100644 --- a/src/reader/spirv/function_logical_test.cc +++ b/src/reader/spirv/function_logical_test.cc @@ -1104,12 +1104,187 @@ TEST_F(SpvFUnordTest, FUnordGreaterThanEqual_Vector) { << ToString(fe.ast_body()); } +TEST_F(SpvFUnordTest, Select_BoolCond_BoolParams) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSelect %bool %true %true %false + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{ + Variable{ + x_1 + none + __bool + { + Call{ + Identifier{select} + ( + ScalarConstructor{true} + ScalarConstructor{false} + ScalarConstructor{true} + ) + } + } + } +})")) << ToString(fe.ast_body()); +} + +TEST_F(SpvFUnordTest, Select_BoolCond_IntScalarParams) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSelect %uint %true %uint_10 %uint_20 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{ + Variable{ + x_1 + none + __u32 + { + Call{ + Identifier{select} + ( + ScalarConstructor{10} + ScalarConstructor{20} + ScalarConstructor{true} + ) + } + } + } +})")) << ToString(fe.ast_body()); +} + +TEST_F(SpvFUnordTest, Select_BoolCond_FloatScalarParams) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSelect %float %true %float_50 %float_60 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{ + Variable{ + x_1 + none + __f32 + { + Call{ + Identifier{select} + ( + ScalarConstructor{50.000000} + ScalarConstructor{60.000000} + ScalarConstructor{true} + ) + } + } + } +})")) << ToString(fe.ast_body()); +} + +TEST_F(SpvFUnordTest, Select_BoolCond_VectorParams) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSelect %v2uint %true %v2uint_10_20 %v2uint_20_10 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{ + Variable{ + x_1 + none + __vec_2__u32 + { + Call{ + Identifier{select} + ( + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{10} + ScalarConstructor{20} + } + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{20} + ScalarConstructor{10} + } + ScalarConstructor{true} + ) + } + } + } +})")) << ToString(fe.ast_body()); +} + +TEST_F(SpvFUnordTest, Select_VecBoolCond_VectorParams) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSelect %v2uint %v2bool_t_f %v2uint_10_20 %v2uint_20_10 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{ + Variable{ + x_1 + none + __vec_2__u32 + { + Call{ + Identifier{select} + ( + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{10} + ScalarConstructor{20} + } + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{20} + ScalarConstructor{10} + } + TypeConstructor{ + __vec_2__bool + ScalarConstructor{true} + ScalarConstructor{false} + } + ) + } + } + } +})")) << ToString(fe.ast_body()); +} + // TODO(dneto): OpAny - likely builtin function TBD // TODO(dneto): OpAll - likely builtin function TBD // TODO(dneto): OpIsNan - likely builtin function TBD // TODO(dneto): OpIsInf - likely builtin function TBD // TODO(dneto): Kernel-guarded instructions. -// TODO(dneto): OpSelect - likely builtin function TBD +// TODO(dneto): OpSelect over more general types, as in SPIR-V 1.4 } // namespace } // namespace spirv