spirv-reader: Use GenerateExpressionWithLoadIfNeeded more

* Rename GenerateNonReferenceExpression to
  GenerateExpressionWithLoadIfNeeded.
  This version takes an ast::Expression
* Add a variant that takes a sem::Expression, because the sem
  expression already knows the resolved type, and so we can save
  a lookup.
* Replace most uses of GenerateExpression ... GenerateLoadIfNeeded
  with a call to one of the above.

This is a non-functional change.
Followup to the fix in tint:1343.

Bug: tint:1343
Change-Id: If19a1bc7670edd2badc1533861d8b42f0825c7b8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72720
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
David Neto 2021-12-14 22:15:39 +00:00 committed by Tint LUCI CQ
parent 5ad482744d
commit 9360046a86
2 changed files with 45 additions and 65 deletions

View File

@ -408,15 +408,10 @@ bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) {
if (lhs_id == 0) {
return false;
}
auto rhs_id = GenerateExpression(assign->rhs);
auto rhs_id = GenerateExpressionWithLoadIfNeeded(assign->rhs);
if (rhs_id == 0) {
return false;
}
// If the thing we're assigning is a reference then we must load it first.
auto* type = TypeOf(assign->rhs);
rhs_id = GenerateLoadIfNeeded(type, rhs_id);
return GenerateStore(lhs_id, rhs_id);
}
}
@ -706,14 +701,10 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
uint32_t init_id = 0;
if (var->constructor) {
init_id = GenerateExpression(var->constructor);
init_id = GenerateExpressionWithLoadIfNeeded(var->constructor);
if (init_id == 0) {
return false;
}
auto* type = TypeOf(var->constructor);
if (type->Is<sem::Reference>()) {
init_id = GenerateLoadIfNeeded(type, init_id);
}
}
if (var->is_const) {
@ -914,12 +905,10 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr,
AccessorInfo* info) {
auto idx_id = GenerateExpression(expr->index);
auto idx_id = GenerateExpressionWithLoadIfNeeded(expr->index);
if (idx_id == 0) {
return 0;
}
auto* type = TypeOf(expr->index);
idx_id = GenerateLoadIfNeeded(type, idx_id);
// If the source is a reference, we access chain into it.
// In the future, pointers may support access-chaining.
@ -1183,8 +1172,19 @@ uint32_t Builder::GenerateIdentifierExpression(
return val;
}
uint32_t Builder::GenerateNonReferenceExpression(const ast::Expression* expr) {
uint32_t Builder::GenerateExpressionWithLoadIfNeeded(
const sem::Expression* expr) {
// The semantic node directly knows both the AST node and the resolved type.
if (const auto id = GenerateExpression(expr->Declaration())) {
return GenerateLoadIfNeeded(expr->Type(), id);
}
return 0;
}
uint32_t Builder::GenerateExpressionWithLoadIfNeeded(
const ast::Expression* expr) {
if (const auto id = GenerateExpression(expr)) {
// Perform a lookup to get the resolved type.
return GenerateLoadIfNeeded(TypeOf(expr), id);
}
return 0;
@ -1212,11 +1212,6 @@ uint32_t Builder::GenerateUnaryOpExpression(
auto result = result_op();
auto result_id = result.to_i();
auto val_id = GenerateExpression(expr->expr);
if (val_id == 0) {
return 0;
}
spv::Op op = spv::Op::OpNop;
switch (expr->op) {
case ast::UnaryOp::kComplement:
@ -1237,10 +1232,13 @@ uint32_t Builder::GenerateUnaryOpExpression(
// Address-of converts a reference to a pointer, and dereference converts
// a pointer to a reference. These are the same thing in SPIR-V, so this
// is a no-op.
return val_id;
return GenerateExpression(expr->expr);
}
val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr);
if (val_id == 0) {
return 0;
}
auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) {
@ -1380,11 +1378,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(
OperandList ops;
for (auto* e : args) {
uint32_t id = 0;
id = GenerateExpression(e->Declaration());
if (id == 0) {
return 0;
}
id = GenerateLoadIfNeeded(e->Type(), id);
id = GenerateExpressionWithLoadIfNeeded(e);
if (id == 0) {
return 0;
}
@ -1532,11 +1526,10 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(
return 0;
}
auto val_id = GenerateExpression(from_expr);
auto val_id = GenerateExpressionWithLoadIfNeeded(from_expr);
if (val_id == 0) {
return 0;
}
val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
auto* from_type = TypeOf(from_expr)->UnwrapRef();
@ -1804,11 +1797,10 @@ uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const sem::Vector* type,
uint32_t Builder::GenerateShortCircuitBinaryExpression(
const ast::BinaryExpression* expr) {
auto lhs_id = GenerateExpression(expr->lhs);
auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs);
if (lhs_id == 0) {
return false;
}
lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
// Get the ID of the basic block where control flow will diverge. It's the
// last basic block generated for the left-hand-side of the operator.
@ -1848,11 +1840,10 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression(
if (!GenerateLabel(block_id)) {
return 0;
}
auto rhs_id = GenerateExpression(expr->rhs);
auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs);
if (rhs_id == 0) {
return 0;
}
rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
// Get the block ID of the last basic block generated for the right-hand-side
// expression. That block will be an immediate predecessor to the merge block.
@ -1971,17 +1962,15 @@ uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) {
return GenerateShortCircuitBinaryExpression(expr);
}
auto lhs_id = GenerateExpression(expr->lhs);
auto lhs_id = GenerateExpressionWithLoadIfNeeded(expr->lhs);
if (lhs_id == 0) {
return 0;
}
lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
auto rhs_id = GenerateExpression(expr->rhs);
auto rhs_id = GenerateExpressionWithLoadIfNeeded(expr->rhs);
if (rhs_id == 0) {
return 0;
}
rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
auto result = result_op();
auto result_id = result.to_i();
@ -2258,11 +2247,7 @@ uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
size_t arg_idx = 0;
for (auto* arg : expr->args) {
auto id = GenerateExpression(arg);
if (id == 0) {
return 0;
}
id = GenerateLoadIfNeeded(TypeOf(arg), id);
auto id = GenerateExpressionWithLoadIfNeeded(arg);
if (id == 0) {
return 0;
}
@ -2715,12 +2700,7 @@ bool Builder::GenerateTextureIntrinsic(const sem::Call* call,
// Generates the given expression, returning the operand ID
auto gen = [&](const sem::Expression* expr) {
auto val_id = GenerateExpression(expr->Declaration());
if (val_id == 0) {
return Operand::Int(0);
}
val_id = GenerateLoadIfNeeded(expr->Type(), val_id);
const auto val_id = GenerateExpressionWithLoadIfNeeded(expr);
return Operand::Int(val_id);
};
@ -3218,11 +3198,7 @@ bool Builder::GenerateAtomicIntrinsic(const sem::Call* call,
uint32_t value_id = 0;
if (call->Arguments().size() > 1) {
value_id = GenerateExpression(call->Arguments().back()->Declaration());
if (value_id == 0) {
return false;
}
value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
value_id = GenerateExpressionWithLoadIfNeeded(call->Arguments().back());
if (value_id == 0) {
return false;
}
@ -3458,11 +3434,10 @@ uint32_t Builder::GenerateBitcastExpression(
return 0;
}
auto val_id = GenerateExpression(expr->expr);
auto val_id = GenerateExpressionWithLoadIfNeeded(expr->expr);
if (val_id == 0) {
return 0;
}
val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
// Bitcast does not allow same types, just emit a CopyObject
auto* to_type = TypeOf(expr)->UnwrapRef();
@ -3489,11 +3464,10 @@ bool Builder::GenerateConditionalBlock(
const ast::BlockStatement* true_body,
size_t cur_else_idx,
const ast::ElseStatementList& else_stmts) {
auto cond_id = GenerateExpression(cond);
auto cond_id = GenerateExpressionWithLoadIfNeeded(cond);
if (cond_id == 0) {
return false;
}
cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);
auto merge_block = result_op();
auto merge_block_id = merge_block.to_i();
@ -3585,7 +3559,7 @@ bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) {
if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) {
// It's a break-if.
TINT_ASSERT(Writer, !backedge_stack_.empty());
const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
if (!cond_id) {
return false;
}
@ -3600,7 +3574,8 @@ bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) {
is_just_a_break(es.back()->body)) {
// It's a break-unless.
TINT_ASSERT(Writer, !backedge_stack_.empty());
const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
const auto cond_id =
GenerateExpressionWithLoadIfNeeded(stmt->condition);
if (!cond_id) {
return false;
}
@ -3626,11 +3601,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
merge_stack_.push_back(merge_block_id);
auto cond_id = GenerateExpression(stmt->condition);
auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
if (cond_id == 0) {
return false;
}
cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition), cond_id);
auto default_block = result_op();
auto default_block_id = default_block.to_i();
@ -3724,11 +3698,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) {
if (stmt->value) {
auto val_id = GenerateExpression(stmt->value);
auto val_id = GenerateExpressionWithLoadIfNeeded(stmt->value);
if (val_id == 0) {
return false;
}
val_id = GenerateLoadIfNeeded(TypeOf(stmt->value), val_id);
if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
return false;
}

View File

@ -458,9 +458,16 @@ class Builder {
/// type, then return the SPIR-V ID for the expression. Otherwise implement
/// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
/// Returns 0 if the expression could not be generated.
/// @param expr the expression to be generate
/// @param expr the semantic expression node to be generated
/// @returns the the ID of the expression, or loaded expression
uint32_t GenerateNonReferenceExpression(const ast::Expression* expr);
uint32_t GenerateExpressionWithLoadIfNeeded(const sem::Expression* expr);
/// Generates an expression. If the WGSL expression does not have reference
/// type, then return the SPIR-V ID for the expression. Otherwise implement
/// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
/// Returns 0 if the expression could not be generated.
/// @param expr the AST expression to be generated
/// @returns the the ID of the expression, or loaded expression
uint32_t GenerateExpressionWithLoadIfNeeded(const ast::Expression* expr);
/// Generates an OpLoad on the given ID if it has reference type in WGSL,
/// othewrise return the ID itself.
/// @param type the type of the expression