From 9360046a86f67db1794f05ad33655e3305bfa695 Mon Sep 17 00:00:00 2001 From: David Neto Date: Tue, 14 Dec 2021 22:15:39 +0000 Subject: [PATCH] spirv-reader: Use GenerateExpressionWithLoadIfNeeded more * Rename GenerateNonReferenceExpression to GenerateExpressionWithLoadIfNeeded. This version takes an ast::Expression * Add a variant that takes a sem::Expression, because the sem expression already knows the resolved type, and so we can save a lookup. * Replace most uses of GenerateExpression ... GenerateLoadIfNeeded with a call to one of the above. This is a non-functional change. Followup to the fix in tint:1343. Bug: tint:1343 Change-Id: If19a1bc7670edd2badc1533861d8b42f0825c7b8 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72720 Auto-Submit: David Neto Kokoro: Kokoro Reviewed-by: Antonio Maiorano Commit-Queue: David Neto --- src/writer/spirv/builder.cc | 99 ++++++++++++++----------------------- src/writer/spirv/builder.h | 11 ++++- 2 files changed, 45 insertions(+), 65 deletions(-) diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index a0eba5a161..04334a20c1 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -408,15 +408,10 @@ bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) { if (lhs_id == 0) { return false; } - auto rhs_id = GenerateExpression(assign->rhs); + auto rhs_id = GenerateExpressionWithLoadIfNeeded(assign->rhs); if (rhs_id == 0) { return false; } - - // If the thing we're assigning is a reference then we must load it first. - auto* type = TypeOf(assign->rhs); - rhs_id = GenerateLoadIfNeeded(type, rhs_id); - return GenerateStore(lhs_id, rhs_id); } } @@ -706,14 +701,10 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) { bool Builder::GenerateFunctionVariable(const ast::Variable* var) { uint32_t init_id = 0; if (var->constructor) { - init_id = GenerateExpression(var->constructor); + init_id = GenerateExpressionWithLoadIfNeeded(var->constructor); if (init_id == 0) { return false; } - auto* type = TypeOf(var->constructor); - if (type->Is()) { - init_id = GenerateLoadIfNeeded(type, init_id); - } } if (var->is_const) { @@ -914,12 +905,10 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) { bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, AccessorInfo* info) { - auto idx_id = GenerateExpression(expr->index); + auto idx_id = GenerateExpressionWithLoadIfNeeded(expr->index); if (idx_id == 0) { return 0; } - auto* type = TypeOf(expr->index); - idx_id = GenerateLoadIfNeeded(type, idx_id); // If the source is a reference, we access chain into it. // In the future, pointers may support access-chaining. @@ -1183,8 +1172,19 @@ uint32_t Builder::GenerateIdentifierExpression( return val; } -uint32_t Builder::GenerateNonReferenceExpression(const ast::Expression* expr) { +uint32_t Builder::GenerateExpressionWithLoadIfNeeded( + const sem::Expression* expr) { + // The semantic node directly knows both the AST node and the resolved type. + if (const auto id = GenerateExpression(expr->Declaration())) { + return GenerateLoadIfNeeded(expr->Type(), id); + } + return 0; +} + +uint32_t Builder::GenerateExpressionWithLoadIfNeeded( + const ast::Expression* expr) { if (const auto id = GenerateExpression(expr)) { + // Perform a lookup to get the resolved type. return GenerateLoadIfNeeded(TypeOf(expr), id); } return 0; @@ -1212,11 +1212,6 @@ uint32_t Builder::GenerateUnaryOpExpression( auto result = result_op(); auto result_id = result.to_i(); - auto val_id = GenerateExpression(expr->expr); - if (val_id == 0) { - return 0; - } - spv::Op op = spv::Op::OpNop; switch (expr->op) { case ast::UnaryOp::kComplement: @@ -1237,10 +1232,13 @@ uint32_t Builder::GenerateUnaryOpExpression( // Address-of converts a reference to a pointer, and dereference converts // a pointer to a reference. These are the same thing in SPIR-V, so this // is a no-op. - return val_id; + return GenerateExpression(expr->expr); } - val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id); + auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr); + if (val_id == 0) { + return 0; + } auto type_id = GenerateTypeIfNeeded(TypeOf(expr)); if (type_id == 0) { @@ -1380,11 +1378,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion( OperandList ops; for (auto* e : args) { uint32_t id = 0; - id = GenerateExpression(e->Declaration()); - if (id == 0) { - return 0; - } - id = GenerateLoadIfNeeded(e->Type(), id); + id = GenerateExpressionWithLoadIfNeeded(e); if (id == 0) { return 0; } @@ -1532,11 +1526,10 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough( return 0; } - auto val_id = GenerateExpression(from_expr); + auto val_id = GenerateExpressionWithLoadIfNeeded(from_expr); if (val_id == 0) { return 0; } - val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id); auto* from_type = TypeOf(from_expr)->UnwrapRef(); @@ -1804,11 +1797,10 @@ uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const sem::Vector* type, uint32_t Builder::GenerateShortCircuitBinaryExpression( const ast::BinaryExpression* expr) { - auto lhs_id = GenerateExpression(expr->lhs); + auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs); if (lhs_id == 0) { return false; } - lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id); // Get the ID of the basic block where control flow will diverge. It's the // last basic block generated for the left-hand-side of the operator. @@ -1848,11 +1840,10 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression( if (!GenerateLabel(block_id)) { return 0; } - auto rhs_id = GenerateExpression(expr->rhs); + auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs); if (rhs_id == 0) { return 0; } - rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id); // Get the block ID of the last basic block generated for the right-hand-side // expression. That block will be an immediate predecessor to the merge block. @@ -1971,17 +1962,15 @@ uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) { return GenerateShortCircuitBinaryExpression(expr); } - auto lhs_id = GenerateExpression(expr->lhs); + auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs); if (lhs_id == 0) { return 0; } - lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id); - auto rhs_id = GenerateExpression(expr->rhs); + auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs); if (rhs_id == 0) { return 0; } - rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id); auto result = result_op(); auto result_id = result.to_i(); @@ -2258,11 +2247,7 @@ uint32_t Builder::GenerateFunctionCall(const sem::Call* call, size_t arg_idx = 0; for (auto* arg : expr->args) { - auto id = GenerateExpression(arg); - if (id == 0) { - return 0; - } - id = GenerateLoadIfNeeded(TypeOf(arg), id); + auto id = GenerateExpressionWithLoadIfNeeded(arg); if (id == 0) { return 0; } @@ -2715,12 +2700,7 @@ bool Builder::GenerateTextureIntrinsic(const sem::Call* call, // Generates the given expression, returning the operand ID auto gen = [&](const sem::Expression* expr) { - auto val_id = GenerateExpression(expr->Declaration()); - if (val_id == 0) { - return Operand::Int(0); - } - val_id = GenerateLoadIfNeeded(expr->Type(), val_id); - + const auto val_id = GenerateExpressionWithLoadIfNeeded(expr); return Operand::Int(val_id); }; @@ -3218,11 +3198,7 @@ bool Builder::GenerateAtomicIntrinsic(const sem::Call* call, uint32_t value_id = 0; if (call->Arguments().size() > 1) { - value_id = GenerateExpression(call->Arguments().back()->Declaration()); - if (value_id == 0) { - return false; - } - value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id); + value_id = GenerateExpressionWithLoadIfNeeded(call->Arguments().back()); if (value_id == 0) { return false; } @@ -3458,11 +3434,10 @@ uint32_t Builder::GenerateBitcastExpression( return 0; } - auto val_id = GenerateExpression(expr->expr); + auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr); if (val_id == 0) { return 0; } - val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id); // Bitcast does not allow same types, just emit a CopyObject auto* to_type = TypeOf(expr)->UnwrapRef(); @@ -3489,11 +3464,10 @@ bool Builder::GenerateConditionalBlock( const ast::BlockStatement* true_body, size_t cur_else_idx, const ast::ElseStatementList& else_stmts) { - auto cond_id = GenerateExpression(cond); + auto cond_id = GenerateExpressionWithLoadIfNeeded(cond); if (cond_id == 0) { return false; } - cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id); auto merge_block = result_op(); auto merge_block_id = merge_block.to_i(); @@ -3585,7 +3559,7 @@ bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) { if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) { // It's a break-if. TINT_ASSERT(Writer, !backedge_stack_.empty()); - const auto cond_id = GenerateNonReferenceExpression(stmt->condition); + const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition); if (!cond_id) { return false; } @@ -3600,7 +3574,8 @@ bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) { is_just_a_break(es.back()->body)) { // It's a break-unless. TINT_ASSERT(Writer, !backedge_stack_.empty()); - const auto cond_id = GenerateNonReferenceExpression(stmt->condition); + const auto cond_id = + GenerateExpressionWithLoadIfNeeded(stmt->condition); if (!cond_id) { return false; } @@ -3626,11 +3601,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { merge_stack_.push_back(merge_block_id); - auto cond_id = GenerateExpression(stmt->condition); + auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition); if (cond_id == 0) { return false; } - cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition), cond_id); auto default_block = result_op(); auto default_block_id = default_block.to_i(); @@ -3724,11 +3698,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) { if (stmt->value) { - auto val_id = GenerateExpression(stmt->value); + auto val_id = GenerateExpressionWithLoadIfNeeded(stmt->value); if (val_id == 0) { return false; } - val_id = GenerateLoadIfNeeded(TypeOf(stmt->value), val_id); if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) { return false; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 1b43ee418c..d2b5237100 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -458,9 +458,16 @@ class Builder { /// type, then return the SPIR-V ID for the expression. Otherwise implement /// the WGSL Load Rule: generate an OpLoad and return the ID of the result. /// Returns 0 if the expression could not be generated. - /// @param expr the expression to be generate + /// @param expr the semantic expression node to be generated /// @returns the the ID of the expression, or loaded expression - uint32_t GenerateNonReferenceExpression(const ast::Expression* expr); + uint32_t GenerateExpressionWithLoadIfNeeded(const sem::Expression* expr); + /// Generates an expression. If the WGSL expression does not have reference + /// type, then return the SPIR-V ID for the expression. Otherwise implement + /// the WGSL Load Rule: generate an OpLoad and return the ID of the result. + /// Returns 0 if the expression could not be generated. + /// @param expr the AST expression to be generated + /// @returns the the ID of the expression, or loaded expression + uint32_t GenerateExpressionWithLoadIfNeeded(const ast::Expression* expr); /// Generates an OpLoad on the given ID if it has reference type in WGSL, /// othewrise return the ID itself. /// @param type the type of the expression