[spirv-reader] Internally, generate typed expressions

The AST only wants expressions, not their result types.
But the SPIR-V reader wants to track the AST type as well.
So introduce a TypedExpression concept for internal use.

Bug: tint:3
Change-Id: Ia832f7422440ef0e8e04630cdca98cae20e18921
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20040
Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
David Neto 2020-04-20 21:06:43 +00:00
parent 53b5730dfc
commit b572d53bf2
4 changed files with 75 additions and 46 deletions

View File

@ -212,7 +212,8 @@ bool FunctionEmitter::EmitFunctionVariables() {
// (OpenCL also allows the ID of an OpVariable, but we don't handle that // (OpenCL also allows the ID of an OpVariable, but we don't handle that
// here.) // here.)
var->set_constructor( var->set_constructor(
parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))); parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))
.expr);
} }
// TODO(dneto): Add the initializer via Variable::set_constructor. // TODO(dneto): Add the initializer via Variable::set_constructor.
auto var_decl_stmt = auto var_decl_stmt =
@ -224,12 +225,14 @@ bool FunctionEmitter::EmitFunctionVariables() {
return success(); return success();
} }
std::unique_ptr<ast::Expression> FunctionEmitter::MakeExpression(uint32_t id) { TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
if (failed()) { if (failed()) {
return nullptr; return {};
} }
if (identifier_values_.count(id)) { if (identifier_values_.count(id)) {
return std::make_unique<ast::IdentifierExpression>(namer_.Name(id)); return TypedExpression(
parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
std::make_unique<ast::IdentifierExpression>(namer_.Name(id)));
} }
if (singly_used_values_.count(id)) { if (singly_used_values_.count(id)) {
auto expr = std::move(singly_used_values_[id]); auto expr = std::move(singly_used_values_[id]);
@ -243,18 +246,19 @@ std::unique_ptr<ast::Expression> FunctionEmitter::MakeExpression(uint32_t id) {
const auto* inst = def_use_mgr_->GetDef(id); const auto* inst = def_use_mgr_->GetDef(id);
if (inst == nullptr) { if (inst == nullptr) {
Fail() << "ID " << id << " does not have a defining SPIR-V instruction"; Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
return nullptr; return {};
} }
switch (inst->opcode()) { switch (inst->opcode()) {
case SpvOpVariable: case SpvOpVariable:
// This occurs for module-scope variables. // This occurs for module-scope variables.
return std::make_unique<ast::IdentifierExpression>( return TypedExpression(parser_impl_.ConvertType(inst->type_id()),
namer_.Name(inst->result_id())); std::make_unique<ast::IdentifierExpression>(
namer_.Name(inst->result_id())));
default: default:
break; break;
} }
Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint(); Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
return nullptr; return {};
} }
bool FunctionEmitter::EmitFunctionBodyStatements() { bool FunctionEmitter::EmitFunctionBodyStatements() {
@ -284,8 +288,8 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(
bool FunctionEmitter::EmitConstDefinition( bool FunctionEmitter::EmitConstDefinition(
const spvtools::opt::Instruction& inst, const spvtools::opt::Instruction& inst,
std::unique_ptr<ast::Expression> ast_expr) { TypedExpression ast_expr) {
if (!ast_expr) { if (!ast_expr.expr) {
return false; return false;
} }
auto ast_const = auto ast_const =
@ -294,7 +298,7 @@ bool FunctionEmitter::EmitConstDefinition(
if (!ast_const) { if (!ast_const) {
return false; return false;
} }
ast_const->set_constructor(std::move(ast_expr)); ast_const->set_constructor(std::move(ast_expr.expr));
ast_const->set_is_const(true); ast_const->set_is_const(true);
ast_body_.emplace_back( ast_body_.emplace_back(
std::make_unique<ast::VariableDeclStatement>(std::move(ast_const))); std::make_unique<ast::VariableDeclStatement>(std::move(ast_const)));
@ -306,11 +310,12 @@ bool FunctionEmitter::EmitConstDefinition(
bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
// Handle combinatorial instructions first. // Handle combinatorial instructions first.
auto combinatorial_expr = MaybeEmitCombinatorialValue(inst); auto combinatorial_expr = MaybeEmitCombinatorialValue(inst);
if (combinatorial_expr != nullptr) { if (combinatorial_expr.expr != nullptr) {
if (def_use_mgr_->NumUses(&inst) == 1) { if (def_use_mgr_->NumUses(&inst) == 1) {
// If it's used once, then defer emitting the expression until it's used. // If it's used once, then defer emitting the expression until it's used.
// Any supporting statements have already been emitted. // Any supporting statements have already been emitted.
singly_used_values_[inst.result_id()] = std::move(combinatorial_expr); singly_used_values_.insert(
std::make_pair(inst.result_id(), std::move(combinatorial_expr)));
return success(); return success();
} }
// Otherwise, generate a const definition for it now and later use // Otherwise, generate a const definition for it now and later use
@ -327,7 +332,7 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
auto lhs = MakeExpression(inst.GetSingleWordInOperand(0)); auto lhs = MakeExpression(inst.GetSingleWordInOperand(0));
auto rhs = MakeExpression(inst.GetSingleWordInOperand(1)); auto rhs = MakeExpression(inst.GetSingleWordInOperand(1));
ast_body_.emplace_back(std::make_unique<ast::AssignmentStatement>( ast_body_.emplace_back(std::make_unique<ast::AssignmentStatement>(
std::move(lhs), std::move(rhs))); std::move(lhs.expr), std::move(rhs.expr)));
return success(); return success();
} }
case SpvOpLoad: case SpvOpLoad:
@ -344,10 +349,10 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
return Fail() << "unhandled instruction with opcode " << inst.opcode(); return Fail() << "unhandled instruction with opcode " << inst.opcode();
} }
std::unique_ptr<ast::Expression> FunctionEmitter::MaybeEmitCombinatorialValue( TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
const spvtools::opt::Instruction& inst) { const spvtools::opt::Instruction& inst) {
if (inst.result_id() == 0) { if (inst.result_id() == 0) {
return nullptr; return {};
} }
// TODO(dneto): Fill in the following cases. // TODO(dneto): Fill in the following cases.
@ -356,10 +361,14 @@ std::unique_ptr<ast::Expression> FunctionEmitter::MaybeEmitCombinatorialValue(
return this->MakeExpression(inst.GetSingleWordInOperand(operand_index)); return this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
}; };
auto* ast_type =
inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
auto binary_op = ConvertBinaryOp(inst.opcode()); auto binary_op = ConvertBinaryOp(inst.opcode());
if (binary_op != ast::BinaryOp::kNone) { if (binary_op != ast::BinaryOp::kNone) {
return std::make_unique<ast::BinaryExpression>(binary_op, operand(0), return {ast_type, std::make_unique<ast::BinaryExpression>(
operand(1)); binary_op, std::move(operand(0).expr),
std::move(operand(1).expr))};
} }
// binary operator // binary operator
// unary operator // unary operator
@ -393,7 +402,7 @@ std::unique_ptr<ast::Expression> FunctionEmitter::MaybeEmitCombinatorialValue(
// OpCompositeExtract // OpCompositeExtract
// OpCompositeInsert // OpCompositeInsert
return nullptr; return {};
} }
} // namespace spirv } // namespace spirv

View File

@ -97,12 +97,12 @@ class FunctionEmitter {
/// @param ast_expr the already-computed AST expression for the value /// @param ast_expr the already-computed AST expression for the value
/// @returns false if emission failed. /// @returns false if emission failed.
bool EmitConstDefinition(const spvtools::opt::Instruction& inst, bool EmitConstDefinition(const spvtools::opt::Instruction& inst,
std::unique_ptr<ast::Expression> ast_expr); TypedExpression ast_expr);
/// Makes an expression /// Makes an expression
/// @param id the SPIR-V ID of the value /// @param id the SPIR-V ID of the value
/// @returns true if emission has not yet failed. /// @returns true if emission has not yet failed.
std::unique_ptr<ast::Expression> MakeExpression(uint32_t id); TypedExpression MakeExpression(uint32_t id);
/// Creates an expression and supporting statements for a combinatorial /// Creates an expression and supporting statements for a combinatorial
/// instruction, or returns null. A SPIR-V instruction is combinatorial /// instruction, or returns null. A SPIR-V instruction is combinatorial
@ -113,7 +113,7 @@ class FunctionEmitter {
/// combinatorial. /// combinatorial.
/// @param inst a SPIR-V instruction representing an exrpression /// @param inst a SPIR-V instruction representing an exrpression
/// @returns an AST expression for the instruction, or nullptr. /// @returns an AST expression for the instruction, or nullptr.
std::unique_ptr<ast::Expression> MaybeEmitCombinatorialValue( TypedExpression MaybeEmitCombinatorialValue(
const spvtools::opt::Instruction& inst); const spvtools::opt::Instruction& inst);
private: private:
@ -135,8 +135,7 @@ class FunctionEmitter {
// The set of IDs that have already had an identifier name generated for it. // The set of IDs that have already had an identifier name generated for it.
std::unordered_set<uint32_t> identifier_values_; std::unordered_set<uint32_t> identifier_values_;
// Mapping from SPIR-V ID that is used at most once, to its AST expression. // Mapping from SPIR-V ID that is used at most once, to its AST expression.
std::unordered_map<uint32_t, std::unique_ptr<ast::Expression>> std::unordered_map<uint32_t, TypedExpression> singly_used_values_;
singly_used_values_;
}; };
} // namespace spirv } // namespace spirv

View File

@ -715,7 +715,7 @@ bool ParserImpl::EmitModuleScopeVariables() {
// (OpenCL also allows the ID of an OpVariable, but we don't handle that // (OpenCL also allows the ID of an OpVariable, but we don't handle that
// here.) // here.)
ast_var->set_constructor( ast_var->set_constructor(
MakeConstantExpression(var.GetSingleWordInOperand(1))); MakeConstantExpression(var.GetSingleWordInOperand(1)).expr);
} }
// TODO(dneto): initializers (a.k.a. constructor expression) // TODO(dneto): initializers (a.k.a. constructor expression)
ast_module_.AddGlobalVariable(std::move(ast_var)); ast_module_.AddGlobalVariable(std::move(ast_var));
@ -763,48 +763,50 @@ std::unique_ptr<ast::Variable> ParserImpl::MakeVariable(uint32_t id,
return ast_var; return ast_var;
} }
std::unique_ptr<ast::Expression> ParserImpl::MakeConstantExpression( TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
uint32_t id) {
if (!success_) { if (!success_) {
return nullptr; return {};
} }
const auto* inst = def_use_mgr_->GetDef(id); const auto* inst = def_use_mgr_->GetDef(id);
if (inst == nullptr) { if (inst == nullptr) {
Fail() << "ID " << id << " is not a registered instruction"; Fail() << "ID " << id << " is not a registered instruction";
return nullptr; return {};
} }
auto* ast_type = ConvertType(inst->type_id()); auto* ast_type = ConvertType(inst->type_id());
if (ast_type == nullptr) { if (ast_type == nullptr) {
return nullptr; return {};
} }
// TODO(dneto): Handle spec constants too? // TODO(dneto): Handle spec constants too?
const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id); const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id);
if (spirv_const == nullptr) { if (spirv_const == nullptr) {
Fail() << "ID " << id << " is not a constant"; Fail() << "ID " << id << " is not a constant";
return nullptr; return {};
} }
// TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
// So canonicalization should map that way too. // So canonicalization should map that way too.
// Currently "null<type>" is missing from the WGSL parser. // Currently "null<type>" is missing from the WGSL parser.
// See https://bugs.chromium.org/p/tint/issues/detail?id=34 // See https://bugs.chromium.org/p/tint/issues/detail?id=34
if (ast_type->IsU32()) { if (ast_type->IsU32()) {
return std::make_unique<ast::ScalarConstructorExpression>( return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::UintLiteral>(ast_type, spirv_const->GetU32())); std::make_unique<ast::UintLiteral>(
ast_type, spirv_const->GetU32()))};
} }
if (ast_type->IsI32()) { if (ast_type->IsI32()) {
return std::make_unique<ast::ScalarConstructorExpression>( return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(ast_type, spirv_const->GetS32())); std::make_unique<ast::IntLiteral>(
ast_type, spirv_const->GetS32()))};
} }
if (ast_type->IsF32()) { if (ast_type->IsF32()) {
return std::make_unique<ast::ScalarConstructorExpression>( return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(ast_type, spirv_const->GetFloat())); std::make_unique<ast::FloatLiteral>(
ast_type, spirv_const->GetFloat()))};
} }
if (ast_type->IsBool()) { if (ast_type->IsBool()) {
const bool value = spirv_const->AsNullConstant() const bool value = spirv_const->AsNullConstant()
? false ? false
: spirv_const->AsBoolConstant()->value(); : spirv_const->AsBoolConstant()->value();
return std::make_unique<ast::ScalarConstructorExpression>( return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(ast_type, value)); std::make_unique<ast::BoolLiteral>(ast_type, value))};
} }
auto* spirv_composite_const = spirv_const->AsCompositeConstant(); auto* spirv_composite_const = spirv_const->AsCompositeConstant();
if (spirv_composite_const != nullptr) { if (spirv_composite_const != nullptr) {
@ -820,21 +822,21 @@ std::unique_ptr<ast::Expression> ParserImpl::MakeConstantExpression(
if (def == nullptr) { if (def == nullptr) {
Fail() << "internal error: SPIR-V constant doesn't have defining " Fail() << "internal error: SPIR-V constant doesn't have defining "
"instruction"; "instruction";
return nullptr; return {};
} }
auto ast_component = MakeConstantExpression(def->result_id()); auto ast_component = MakeConstantExpression(def->result_id());
if (!success_) { if (!success_) {
// We've already emitted a diagnostic. // We've already emitted a diagnostic.
return nullptr; return {};
} }
ast_components.emplace_back(std::move(ast_component)); ast_components.emplace_back(std::move(ast_component.expr));
} }
return std::make_unique<ast::TypeConstructorExpression>( return {ast_type, std::make_unique<ast::TypeConstructorExpression>(
ast_type, std::move(ast_components)); ast_type, std::move(ast_components))};
} }
Fail() << "Unhandled constant type " << inst->type_id() << " for value ID " Fail() << "Unhandled constant type " << inst->type_id() << " for value ID "
<< id; << id;
return nullptr; return {};
} }
bool ParserImpl::EmitFunctions() { bool ParserImpl::EmitFunctions() {

View File

@ -50,6 +50,25 @@ namespace spirv {
using Decoration = std::vector<uint32_t>; using Decoration = std::vector<uint32_t>;
using DecorationList = std::vector<Decoration>; using DecorationList = std::vector<Decoration>;
// An AST expression with its type.
struct TypedExpression {
/// Dummy constructor
TypedExpression() : type(nullptr), expr(nullptr) {}
/// Constructor
/// @param t the type
/// @param e the expression
TypedExpression(ast::type::Type* t, std::unique_ptr<ast::Expression> e)
: type(t), expr(std::move(e)) {}
/// Move constructor
/// @param other the other typed expression
TypedExpression(TypedExpression&& other)
: type(other.type), expr(std::move(other.expr)) {}
/// The type
ast::type::Type* type;
/// The expression
std::unique_ptr<ast::Expression> expr;
};
/// Parser implementation for SPIR-V. /// Parser implementation for SPIR-V.
class ParserImpl : Reader { class ParserImpl : Reader {
public: public:
@ -224,7 +243,7 @@ class ParserImpl : Reader {
/// Creates an AST expression node for a SPIR-V constant. /// Creates an AST expression node for a SPIR-V constant.
/// @param id the SPIR-V ID of the constant /// @param id the SPIR-V ID of the constant
/// @returns a new Literal node /// @returns a new Literal node
std::unique_ptr<ast::Expression> MakeConstantExpression(uint32_t id); TypedExpression MakeConstantExpression(uint32_t id);
private: private:
/// Converts a specific SPIR-V type to a Tint type. Integer case /// Converts a specific SPIR-V type to a Tint type. Integer case