diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 2d520cb62b..0a17fd5b82 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -212,7 +212,8 @@ bool FunctionEmitter::EmitFunctionVariables() { // (OpenCL also allows the ID of an OpVariable, but we don't handle that // here.) var->set_constructor( - parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))); + parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)) + .expr); } // TODO(dneto): Add the initializer via Variable::set_constructor. auto var_decl_stmt = @@ -224,12 +225,14 @@ bool FunctionEmitter::EmitFunctionVariables() { return success(); } -std::unique_ptr FunctionEmitter::MakeExpression(uint32_t id) { +TypedExpression FunctionEmitter::MakeExpression(uint32_t id) { if (failed()) { - return nullptr; + return {}; } if (identifier_values_.count(id)) { - return std::make_unique(namer_.Name(id)); + return TypedExpression( + parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()), + std::make_unique(namer_.Name(id))); } if (singly_used_values_.count(id)) { auto expr = std::move(singly_used_values_[id]); @@ -243,18 +246,19 @@ std::unique_ptr FunctionEmitter::MakeExpression(uint32_t id) { const auto* inst = def_use_mgr_->GetDef(id); if (inst == nullptr) { Fail() << "ID " << id << " does not have a defining SPIR-V instruction"; - return nullptr; + return {}; } switch (inst->opcode()) { case SpvOpVariable: // This occurs for module-scope variables. - return std::make_unique( - namer_.Name(inst->result_id())); + return TypedExpression(parser_impl_.ConvertType(inst->type_id()), + std::make_unique( + namer_.Name(inst->result_id()))); default: break; } Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint(); - return nullptr; + return {}; } bool FunctionEmitter::EmitFunctionBodyStatements() { @@ -284,8 +288,8 @@ bool FunctionEmitter::EmitStatementsInBasicBlock( bool FunctionEmitter::EmitConstDefinition( const spvtools::opt::Instruction& inst, - std::unique_ptr ast_expr) { - if (!ast_expr) { + TypedExpression ast_expr) { + if (!ast_expr.expr) { return false; } auto ast_const = @@ -294,7 +298,7 @@ bool FunctionEmitter::EmitConstDefinition( if (!ast_const) { return false; } - ast_const->set_constructor(std::move(ast_expr)); + ast_const->set_constructor(std::move(ast_expr.expr)); ast_const->set_is_const(true); ast_body_.emplace_back( std::make_unique(std::move(ast_const))); @@ -306,11 +310,12 @@ bool FunctionEmitter::EmitConstDefinition( bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { // Handle combinatorial instructions first. auto combinatorial_expr = MaybeEmitCombinatorialValue(inst); - if (combinatorial_expr != nullptr) { + if (combinatorial_expr.expr != nullptr) { if (def_use_mgr_->NumUses(&inst) == 1) { // If it's used once, then defer emitting the expression until it's used. // Any supporting statements have already been emitted. - singly_used_values_[inst.result_id()] = std::move(combinatorial_expr); + singly_used_values_.insert( + std::make_pair(inst.result_id(), std::move(combinatorial_expr))); return success(); } // Otherwise, generate a const definition for it now and later use @@ -327,7 +332,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { auto lhs = MakeExpression(inst.GetSingleWordInOperand(0)); auto rhs = MakeExpression(inst.GetSingleWordInOperand(1)); ast_body_.emplace_back(std::make_unique( - std::move(lhs), std::move(rhs))); + std::move(lhs.expr), std::move(rhs.expr))); return success(); } case SpvOpLoad: @@ -344,10 +349,10 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { return Fail() << "unhandled instruction with opcode " << inst.opcode(); } -std::unique_ptr FunctionEmitter::MaybeEmitCombinatorialValue( +TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( const spvtools::opt::Instruction& inst) { if (inst.result_id() == 0) { - return nullptr; + return {}; } // TODO(dneto): Fill in the following cases. @@ -356,10 +361,14 @@ std::unique_ptr FunctionEmitter::MaybeEmitCombinatorialValue( return this->MakeExpression(inst.GetSingleWordInOperand(operand_index)); }; + auto* ast_type = + inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr; + auto binary_op = ConvertBinaryOp(inst.opcode()); if (binary_op != ast::BinaryOp::kNone) { - return std::make_unique(binary_op, operand(0), - operand(1)); + return {ast_type, std::make_unique( + binary_op, std::move(operand(0).expr), + std::move(operand(1).expr))}; } // binary operator // unary operator @@ -393,7 +402,7 @@ std::unique_ptr FunctionEmitter::MaybeEmitCombinatorialValue( // OpCompositeExtract // OpCompositeInsert - return nullptr; + return {}; } } // namespace spirv diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index a22140cd46..433b078f39 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -97,12 +97,12 @@ class FunctionEmitter { /// @param ast_expr the already-computed AST expression for the value /// @returns false if emission failed. bool EmitConstDefinition(const spvtools::opt::Instruction& inst, - std::unique_ptr ast_expr); + TypedExpression ast_expr); /// Makes an expression /// @param id the SPIR-V ID of the value /// @returns true if emission has not yet failed. - std::unique_ptr MakeExpression(uint32_t id); + TypedExpression MakeExpression(uint32_t id); /// Creates an expression and supporting statements for a combinatorial /// instruction, or returns null. A SPIR-V instruction is combinatorial @@ -113,7 +113,7 @@ class FunctionEmitter { /// combinatorial. /// @param inst a SPIR-V instruction representing an exrpression /// @returns an AST expression for the instruction, or nullptr. - std::unique_ptr MaybeEmitCombinatorialValue( + TypedExpression MaybeEmitCombinatorialValue( const spvtools::opt::Instruction& inst); private: @@ -135,8 +135,7 @@ class FunctionEmitter { // The set of IDs that have already had an identifier name generated for it. std::unordered_set identifier_values_; // Mapping from SPIR-V ID that is used at most once, to its AST expression. - std::unordered_map> - singly_used_values_; + std::unordered_map singly_used_values_; }; } // namespace spirv diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index cd369711ff..795dc7d73b 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -715,7 +715,7 @@ bool ParserImpl::EmitModuleScopeVariables() { // (OpenCL also allows the ID of an OpVariable, but we don't handle that // here.) ast_var->set_constructor( - MakeConstantExpression(var.GetSingleWordInOperand(1))); + MakeConstantExpression(var.GetSingleWordInOperand(1)).expr); } // TODO(dneto): initializers (a.k.a. constructor expression) ast_module_.AddGlobalVariable(std::move(ast_var)); @@ -763,48 +763,50 @@ std::unique_ptr ParserImpl::MakeVariable(uint32_t id, return ast_var; } -std::unique_ptr ParserImpl::MakeConstantExpression( - uint32_t id) { +TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { if (!success_) { - return nullptr; + return {}; } const auto* inst = def_use_mgr_->GetDef(id); if (inst == nullptr) { Fail() << "ID " << id << " is not a registered instruction"; - return nullptr; + return {}; } auto* ast_type = ConvertType(inst->type_id()); if (ast_type == nullptr) { - return nullptr; + return {}; } // TODO(dneto): Handle spec constants too? const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id); if (spirv_const == nullptr) { Fail() << "ID " << id << " is not a constant"; - return nullptr; + return {}; } // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // So canonicalization should map that way too. // Currently "null" is missing from the WGSL parser. // See https://bugs.chromium.org/p/tint/issues/detail?id=34 if (ast_type->IsU32()) { - return std::make_unique( - std::make_unique(ast_type, spirv_const->GetU32())); + return {ast_type, std::make_unique( + std::make_unique( + ast_type, spirv_const->GetU32()))}; } if (ast_type->IsI32()) { - return std::make_unique( - std::make_unique(ast_type, spirv_const->GetS32())); + return {ast_type, std::make_unique( + std::make_unique( + ast_type, spirv_const->GetS32()))}; } if (ast_type->IsF32()) { - return std::make_unique( - std::make_unique(ast_type, spirv_const->GetFloat())); + return {ast_type, std::make_unique( + std::make_unique( + ast_type, spirv_const->GetFloat()))}; } if (ast_type->IsBool()) { const bool value = spirv_const->AsNullConstant() ? false : spirv_const->AsBoolConstant()->value(); - return std::make_unique( - std::make_unique(ast_type, value)); + return {ast_type, std::make_unique( + std::make_unique(ast_type, value))}; } auto* spirv_composite_const = spirv_const->AsCompositeConstant(); if (spirv_composite_const != nullptr) { @@ -820,21 +822,21 @@ std::unique_ptr ParserImpl::MakeConstantExpression( if (def == nullptr) { Fail() << "internal error: SPIR-V constant doesn't have defining " "instruction"; - return nullptr; + return {}; } auto ast_component = MakeConstantExpression(def->result_id()); if (!success_) { // We've already emitted a diagnostic. - return nullptr; + return {}; } - ast_components.emplace_back(std::move(ast_component)); + ast_components.emplace_back(std::move(ast_component.expr)); } - return std::make_unique( - ast_type, std::move(ast_components)); + return {ast_type, std::make_unique( + ast_type, std::move(ast_components))}; } Fail() << "Unhandled constant type " << inst->type_id() << " for value ID " << id; - return nullptr; + return {}; } bool ParserImpl::EmitFunctions() { diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index da73bcf824..91c22e01eb 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -50,6 +50,25 @@ namespace spirv { using Decoration = std::vector; using DecorationList = std::vector; +// An AST expression with its type. +struct TypedExpression { + /// Dummy constructor + TypedExpression() : type(nullptr), expr(nullptr) {} + /// Constructor + /// @param t the type + /// @param e the expression + TypedExpression(ast::type::Type* t, std::unique_ptr e) + : type(t), expr(std::move(e)) {} + /// Move constructor + /// @param other the other typed expression + TypedExpression(TypedExpression&& other) + : type(other.type), expr(std::move(other.expr)) {} + /// The type + ast::type::Type* type; + /// The expression + std::unique_ptr expr; +}; + /// Parser implementation for SPIR-V. class ParserImpl : Reader { public: @@ -224,7 +243,7 @@ class ParserImpl : Reader { /// Creates an AST expression node for a SPIR-V constant. /// @param id the SPIR-V ID of the constant /// @returns a new Literal node - std::unique_ptr MakeConstantExpression(uint32_t id); + TypedExpression MakeConstantExpression(uint32_t id); private: /// Converts a specific SPIR-V type to a Tint type. Integer case