diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3d2e507057..0977215d87 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -328,6 +328,7 @@ if(${TINT_BUILD_SPV_READER}) list(APPEND TINT_TEST_SRCS reader/spirv/enum_converter_test.cc reader/spirv/fail_stream_test.cc + reader/spirv/function_arithmetic_test.cc reader/spirv/function_decl_test.cc reader/spirv/function_var_test.cc reader/spirv/function_memory_test.cc diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index e80741552d..5a2da1f807 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -21,6 +21,7 @@ #include "source/opt/instruction.h" #include "source/opt/module.h" #include "src/ast/assignment_statement.h" +#include "src/ast/binary_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/storage_class.h" @@ -34,6 +35,19 @@ namespace tint { namespace reader { namespace spirv { +namespace { +// @returns the AST binary op for the given opcode, or kNone +ast::BinaryOp ConvertBinaryOp(SpvOp opcode) { + switch (opcode) { + case SpvOpIAdd: + return ast::BinaryOp::kAdd; + default: + break; + } + return ast::BinaryOp::kNone; +} +} // namespace + FunctionEmitter::FunctionEmitter(ParserImpl* pi, const spvtools::opt::Function& function) : parser_impl_(*pi), @@ -180,6 +194,11 @@ std::unique_ptr FunctionEmitter::MakeExpression(uint32_t id) { if (identifier_values_.count(id)) { return std::make_unique(namer_.Name(id)); } + if (singly_used_values_.count(id)) { + auto expr = std::move(singly_used_values_[id]); + singly_used_values_.erase(id); + return expr; + } const auto* spirv_constant = constant_mgr_->FindDeclaredConstant(id); if (spirv_constant) { return parser_impl_.MakeConstantExpression(id); @@ -222,7 +241,45 @@ bool FunctionEmitter::EmitStatementsInBasicBlock( return true; } +bool FunctionEmitter::EmitConstDefinition( + const spvtools::opt::Instruction& inst, + std::unique_ptr ast_expr) { + if (!ast_expr) { + return false; + } + auto ast_const = + parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, + parser_impl_.ConvertType(inst.type_id())); + if (!ast_const) { + return false; + } + ast_const->set_constructor(std::move(ast_expr)); + ast_const->set_is_const(true); + ast_body_.emplace_back( + std::make_unique(std::move(ast_const))); + // Save this as an already-named value. + identifier_values_.insert(inst.result_id()); + return success(); +} + bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { + // Handle combinatorial instructions first. + auto combinatorial_expr = MaybeEmitCombinatorialValue(inst); + if (combinatorial_expr != nullptr) { + if (def_use_mgr_->NumUses(&inst) == 1) { + // If it's used once, then defer emitting the expression until it's used. + // Any supporting statements have already been emitted. + singly_used_values_[inst.result_id()] = std::move(combinatorial_expr); + return success(); + } + // Otherwise, generate a const definition for it now and later use + // the const's name at the uses of the value. + return EmitConstDefinition(inst, std::move(combinatorial_expr)); + } + if (failed()) { + return false; + } + switch (inst.opcode()) { case SpvOpStore: { // TODO(dneto): Order of evaluation? @@ -232,27 +289,11 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { std::move(lhs), std::move(rhs))); return success(); } - case SpvOpLoad: { + case SpvOpLoad: // Memory accesses must be issued in SPIR-V program order. // So represent a load by a new const definition. - auto ast_initializer = MakeExpression(inst.GetSingleWordInOperand(0)); - if (!ast_initializer) { - return false; - } - auto ast_const = - parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, - parser_impl_.ConvertType(inst.type_id())); - if (!ast_const) { - return false; - } - ast_const->set_constructor(std::move(ast_initializer)); - ast_const->set_is_const(true); - ast_body_.emplace_back( - std::make_unique(std::move(ast_const))); - // Save this as an already-named value. - identifier_values_.insert(inst.result_id()); - return success(); - } + return EmitConstDefinition( + inst, MakeExpression(inst.GetSingleWordInOperand(0))); case SpvOpFunctionCall: // TODO(dneto): Fill this out. Make this pass, for existing tests return success(); @@ -262,6 +303,58 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { return Fail() << "unhandled instruction with opcode " << inst.opcode(); } +std::unique_ptr FunctionEmitter::MaybeEmitCombinatorialValue( + const spvtools::opt::Instruction& inst) { + if (inst.result_id() == 0) { + return nullptr; + } + + // TODO(dneto): Fill in the following cases. + + auto operand = [this, &inst](uint32_t operand_index) { + return this->MakeExpression(inst.GetSingleWordInOperand(operand_index)); + }; + + auto binary_op = ConvertBinaryOp(inst.opcode()); + if (binary_op != ast::BinaryOp::kNone) { + return std::make_unique(binary_op, operand(0), + operand(1)); + } + // binary operator + // unary operator + // builtin readonly function + // glsl.std.450 readonly function + + // Instructions: + // OpCopyObject + // OpUndef + // OpBitcast + // OpSatConvertSToU + // OpSatConvertUToS + // OpSatConvertFToS + // OpSatConvertFToU + // OpSatConvertSToF + // OpSatConvertUToF + // OpUConvert + // OpSConvert + // OpFConvert + // OpConvertPtrToU // Not in WebGPU + // OpConvertUToPtr // Not in WebGPU + // OpPtrCastToGeneric // Not in Vulkan + // OpGenericCastToPtr // Not in Vulkan + // OpGenericCastToPtrExplicit // Not in Vulkan + // + // OpAccessChain + // OpInBoundsAccessChain + // OpArrayLength + // OpVectorExtractDynamic + // OpVectorInsertDynamic + // OpCompositeExtract + // OpCompositeInsert + + return nullptr; +} + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 86f518d5ec..b5d588652e 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -90,11 +90,30 @@ class FunctionEmitter { /// @returns false if emission failed. bool EmitStatement(const spvtools::opt::Instruction& inst); + /// Emits a const definition for a SPIR-V value. + /// @param inst the SPIR-V instruction defining the value + /// @param ast_expr the already-computed AST expression for the value + /// @returns false if emission failed. + bool EmitConstDefinition(const spvtools::opt::Instruction& inst, + std::unique_ptr ast_expr); + /// Makes an expression /// @param id the SPIR-V ID of the value /// @returns true if emission has not yet failed. std::unique_ptr MakeExpression(uint32_t id); + /// Creates an expression and supporting statements for a combinatorial + /// instruction, or returns null. A SPIR-V instruction is combinatorial + /// if it has no side effects and its result depends only on its operands, + /// and not on accessing external state like memory or the state of other + /// invocations. Statements are only created if required to provide values + /// to the expression. Supporting statements are not required to be + /// combinatorial. + /// @param inst a SPIR-V instruction representing an exrpression + /// @returns an AST expression for the instruction, or nullptr. + std::unique_ptr MaybeEmitCombinatorialValue( + const spvtools::opt::Instruction& inst); + private: /// @returns the store type for the OpVariable instruction, or /// null on failure. @@ -113,6 +132,9 @@ class FunctionEmitter { ast::StatementList ast_body_; // The set of IDs that have already had an identifier name generated for it. std::unordered_set identifier_values_; + // Mapping from SPIR-V ID that is used at most once, to its AST expression. + std::unordered_map> + singly_used_values_; }; } // namespace spirv diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc new file mode 100644 index 0000000000..426e80efbe --- /dev/null +++ b/src/reader/spirv/function_arithmetic_test.cc @@ -0,0 +1,188 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/reader/spirv/function.h" + +#include +#include + +#include "gmock/gmock.h" +#include "src/reader/spirv/parser_impl.h" +#include "src/reader/spirv/parser_impl_test_helper.h" +#include "src/reader/spirv/spirv_tools_helpers_test.h" + +namespace tint { +namespace reader { +namespace spirv { +namespace { + +using ::testing::HasSubstr; + +std::string CommonTypes() { + return R"( + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + + %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 + %float = OpTypeFloat 32 + + %uint_10 = OpConstant %uint 10 + %uint_20 = OpConstant %uint 20 + %int_30 = OpConstant %int 30 + %int_40 = OpConstant %int 40 + %float_50 = OpConstant %uint 50 + %float_60 = OpConstant %uint 60 + + %ptr_uint = OpTypePointer Function %uint + %ptr_int = OpTypePointer Function %int + %ptr_float = OpTypePointer Function %float + + %v2uint = OpTypeVector %uint 2 + %v2int = OpTypeVector %int 2 + + %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20 + %v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10 + %v2int_30_40 = OpConstantComposite %v2int %int_30 %int_40 + %v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30 +)"; +} + +// Returns the AST dump for a given SPIR-V assembly constant. +std::string AstFor(std::string assembly) { + if (assembly == "v2uint_10_20") { + return R"(TypeConstructor{ + __vec_2__u32 + ScalarConstructor{10} + ScalarConstructor{20} + })"; + } + if (assembly == "v2uint_20_10") { + return R"(TypeConstructor{ + __vec_2__u32 + ScalarConstructor{20} + ScalarConstructor{10} + })"; + } + if (assembly == "v2int_30_40") { + return R"(TypeConstructor{ + __vec_2__i32 + ScalarConstructor{30} + ScalarConstructor{40} + })"; + } + if (assembly == "v2int_40_30") { + return R"(TypeConstructor{ + __vec_2__i32 + ScalarConstructor{40} + ScalarConstructor{30} + })"; + } + return "bad case"; +} + +struct BinaryData { + const std::string res_type; + const std::string lhs; + const std::string op; + const std::string rhs; + const std::string ast_type; + const std::string ast_lhs; + const std::string ast_op; + const std::string ast_rhs; +}; +inline std::ostream& operator<<(std::ostream& out, BinaryData data) { + out << "BinaryData{" << data.res_type << "," << data.lhs << "," << data.op + << "," << data.rhs << "," << data.ast_type << "," << data.ast_lhs << "," + << data.ast_op << "," << data.ast_rhs << "}"; + return out; +} + +using SpvBinaryTest = SpvParserTestBase<::testing::TestWithParam>; + +TEST_P(SpvBinaryTest, EmitExpression) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = )" + GetParam().op + + " %" + GetParam().res_type + " %" + GetParam().lhs + + " %" + GetParam().rhs + R"( + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) + << p->error() << "\n" + << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + std::ostringstream ss; + ss << R"(Variable{ + x_1 + none + )" + << GetParam().ast_type << "\n {\n Binary{" + << "\n " << GetParam().ast_lhs << "\n " << GetParam().ast_op + << "\n " << GetParam().ast_rhs; + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(ss.str())) << assembly; +} +INSTANTIATE_TEST_SUITE_P(SpvParserTest, + SpvBinaryTest, + ::testing::Values( + // Both uint + BinaryData{ + "uint", "uint_10", "OpIAdd", "uint_20", "__u32", + "ScalarConstructor{10}", "add", + "ScalarConstructor{20}"}, + // Both int + BinaryData{ + "int", "int_30", "OpIAdd", "int_40", "__i32", + "ScalarConstructor{30}", "add", + "ScalarConstructor{40}"}, + // Mixed, returning uint + BinaryData{ + "uint", "int_30", "OpIAdd", "uint_10", "__u32", + "ScalarConstructor{30}", "add", + "ScalarConstructor{10}"}, + // Mixed, returning int + BinaryData{ + "int", "int_30", "OpIAdd", "uint_10", "__i32", + "ScalarConstructor{30}", "add", + "ScalarConstructor{10}"}, + // Both v2uint + BinaryData{ + "v2uint", "v2uint_10_20", "OpIAdd", "v2uint_20_10", "__vec_2__u32", + AstFor("v2uint_10_20"), "add", + AstFor("v2uint_20_10")}, + // Both v2int + BinaryData{ + "v2int", "v2int_30_40", "OpIAdd", "v2int_40_30", "__vec_2__i32", + AstFor("v2int_30_40"), "add", + AstFor("v2int_40_30")}, + // Mixed, returning v2uint + BinaryData{ + "v2uint", "v2int_30_40", "OpIAdd", "v2uint_10_20", "__vec_2__u32", + AstFor("v2int_30_40"), "add", + AstFor("v2uint_10_20")}, + // Mixed, returning v2int + BinaryData{ + "v2int", "v2int_40_30", "OpIAdd", "v2uint_20_10", "__vec_2__i32", + AstFor("v2int_40_30"), "add", + AstFor("v2uint_20_10")} + )); + +} // namespace +} // namespace spirv +} // namespace reader +} // namespace tint diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h index e6cb026baa..4e98d26d1b 100644 --- a/src/reader/spirv/parser_impl_test_helper.h +++ b/src/reader/spirv/parser_impl_test_helper.h @@ -30,10 +30,11 @@ namespace reader { namespace spirv { /// SPIR-V Parser test class -class SpvParserTest : public testing::Test { +template +class SpvParserTestBase : public T { public: - SpvParserTest() = default; - ~SpvParserTest() = default; + SpvParserTestBase() = default; + ~SpvParserTestBase() = default; /// Sets up the test helper void SetUp() { ctx_.Reset(); } @@ -63,6 +64,9 @@ class SpvParserTest : public testing::Test { Context ctx_; }; +// Use this form when you don't need to template any further. +using SpvParserTest = SpvParserTestBase<::testing::Test>; + /// Returns the string dump of a function body. /// @param body the statement in the body /// @returnss the string dump of a function body.