[ir] Add load instruction

When converting an AST expression to IR, check for a sem::Load node
and emit a load instruction if present.

Update conversion of compound assignment and increment/decrement to
load from the LHS.

Convert load instructions to inline variable references when going
back to the AST.

Bug: tint:1718
Change-Id: Ib2b850efb304a71eff95aadac825f015623b6eb3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133220
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
James Price 2023-05-17 18:41:27 +00:00 committed by Dawn LUCI CQ
parent 2731b76ded
commit 90b8cc1e93
12 changed files with 263 additions and 53 deletions

View File

@ -1230,6 +1230,8 @@ libtint_source_set("libtint_ir_src") {
"ir/if.h",
"ir/instruction.cc",
"ir/instruction.h",
"ir/load.cc",
"ir/load.h",
"ir/loop.cc",
"ir/loop.h",
"ir/module.cc",
@ -2250,6 +2252,7 @@ if (tint_build_unittests) {
"ir/from_program_test.cc",
"ir/from_program_unary_test.cc",
"ir/from_program_var_test.cc",
"ir/load_test.cc",
"ir/module_test.cc",
"ir/store_test.cc",
"ir/test_helper.h",

View File

@ -748,6 +748,8 @@ if(${TINT_BUILD_IR})
ir/if.h
ir/instruction.cc
ir/instruction.h
ir/load.cc
ir/load.h
ir/loop.cc
ir/loop.h
ir/module.cc
@ -1467,6 +1469,7 @@ if(TINT_BUILD_TESTS)
ir/from_program_test.cc
ir/from_program_unary_test.cc
ir/from_program_var_test.cc
ir/load_test.cc
ir/module_test.cc
ir/store_test.cc
ir/test_helper.h

View File

@ -17,6 +17,8 @@
#include <utility>
#include "src/tint/constant/scalar.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/reference.h"
namespace tint::ir {
@ -227,6 +229,12 @@ ir::Builtin* Builder::Builtin(const type::Type* type,
return ir.values.Create<ir::Builtin>(type, func, args);
}
ir::Load* Builder::Load(Value* from) {
auto* ptr = from->Type()->As<type::Pointer>();
TINT_ASSERT(IR, ptr);
return ir.values.Create<ir::Load>(ptr->StoreType(), from);
}
ir::Store* Builder::Store(Value* to, Value* from) {
return ir.values.Create<ir::Store>(to, from);
}

View File

@ -28,6 +28,7 @@
#include "src/tint/ir/function.h"
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/root_terminator.h"
@ -338,7 +339,12 @@ class Builder {
builtin::Function func,
utils::VectorRef<Value*> args);
/// Creates an store instruction
/// Creates a load instruction
/// @param from the expression being loaded from
/// @returns the instruction
ir::Load* Load(Value* from);
/// Creates a store instruction
/// @param to the expression being stored too
/// @param from the expression being stored
/// @returns the instruction

View File

@ -27,6 +27,7 @@
#include "src/tint/ir/discard.h"
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/root_terminator.h"
#include "src/tint/ir/store.h"
@ -432,6 +433,11 @@ void Disassembler::EmitInstruction(const Instruction* inst) {
out_ << " = convert " << c->FromType()->FriendlyName() << ", ";
EmitArgs(c);
},
[&](const ir::Load* l) {
EmitValue(l);
out_ << " = load ";
EmitValue(l->from);
},
[&](const ir::Store* s) {
out_ << "store ";
EmitValue(s->to);

View File

@ -74,6 +74,7 @@
#include "src/tint/sem/builtin.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/load.h"
#include "src/tint/sem/materialize.h"
#include "src/tint/sem/module.h"
#include "src/tint/sem/switch_statement.h"
@ -422,15 +423,20 @@ class Impl {
return;
}
auto* ty = lhs.Get()->Type();
auto* rhs = ty->UnwrapRef()->is_signed_integer_scalar() ? builder_.Constant(1_i)
: builder_.Constant(1_u);
// Load from the LHS.
auto* lhs_value = builder_.Load(lhs.Get());
current_flow_block_->instructions.Push(lhs_value);
auto* ty = lhs_value->Type();
auto* rhs =
ty->is_signed_integer_scalar() ? builder_.Constant(1_i) : builder_.Constant(1_u);
Binary* inst = nullptr;
if (stmt->increment) {
inst = builder_.Add(ty, lhs.Get(), rhs);
inst = builder_.Add(ty, lhs_value, rhs);
} else {
inst = builder_.Subtract(ty, lhs.Get(), rhs);
inst = builder_.Subtract(ty, lhs_value, rhs);
}
current_flow_block_->instructions.Push(inst);
@ -448,38 +454,44 @@ class Impl {
if (!rhs) {
return;
}
auto* ty = lhs.Get()->Type();
// Load from the LHS.
auto* lhs_value = builder_.Load(lhs.Get());
current_flow_block_->instructions.Push(lhs_value);
auto* ty = lhs_value->Type();
Binary* inst = nullptr;
switch (stmt->op) {
case ast::BinaryOp::kAnd:
inst = builder_.And(ty, lhs.Get(), rhs.Get());
inst = builder_.And(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kOr:
inst = builder_.Or(ty, lhs.Get(), rhs.Get());
inst = builder_.Or(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kXor:
inst = builder_.Xor(ty, lhs.Get(), rhs.Get());
inst = builder_.Xor(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kShiftLeft:
inst = builder_.ShiftLeft(ty, lhs.Get(), rhs.Get());
inst = builder_.ShiftLeft(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kShiftRight:
inst = builder_.ShiftRight(ty, lhs.Get(), rhs.Get());
inst = builder_.ShiftRight(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kAdd:
inst = builder_.Add(ty, lhs.Get(), rhs.Get());
inst = builder_.Add(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kSubtract:
inst = builder_.Subtract(ty, lhs.Get(), rhs.Get());
inst = builder_.Subtract(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kMultiply:
inst = builder_.Multiply(ty, lhs.Get(), rhs.Get());
inst = builder_.Multiply(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kDivide:
inst = builder_.Divide(ty, lhs.Get(), rhs.Get());
inst = builder_.Divide(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kModulo:
inst = builder_.Modulo(ty, lhs.Get(), rhs.Get());
inst = builder_.Modulo(ty, lhs_value, rhs.Get());
break;
case ast::BinaryOp::kLessThanEqual:
case ast::BinaryOp::kGreaterThanEqual:
@ -809,7 +821,8 @@ class Impl {
utils::Result<Value*> EmitExpression(const ast::Expression* expr) {
// If this is a value that has been const-eval'd return the result.
if (auto* sem = program_->Sem().Get(expr)->As<sem::ValueExpression>()) {
auto* sem = program_->Sem().GetVal(expr);
if (sem) {
if (auto* v = sem->ConstantValue()) {
if (auto* cv = v->Clone(clone_ctx_)) {
return builder_.Constant(cv);
@ -817,7 +830,7 @@ class Impl {
}
}
return tint::Switch(
auto result = tint::Switch(
expr,
// [&](const ast::IndexAccessorExpression* a) {
// TODO(dsinclair): Implement
@ -846,6 +859,15 @@ class Impl {
"unknown expression type: " + std::string(expr->TypeInfo().name));
return utils::Failure;
});
// If this expression maps to sem::Load, insert a load instruction to get the result.
if (result && sem->Is<sem::Load>()) {
auto* load = builder_.Load(result.Get());
current_flow_block_->instructions.Push(load);
return load;
}
return result;
}
void EmitVariable(const ast::Variable* var) {

View File

@ -64,8 +64,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Increment) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = add %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = add %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -87,8 +88,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundAdd) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = add %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = add %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -133,8 +135,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Decrement) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, i32, read_write> = sub %v1:ptr<private, i32, read_write>, 1u
store %v1:ptr<private, i32, read_write>, %2:ptr<private, i32, read_write>
%2:i32 = load %v1:ptr<private, i32, read_write>
%3:i32 = sub %2:i32, 1i
store %v1:ptr<private, i32, read_write>, %3:i32
} -> %func_end # return
} %func_end
@ -156,8 +159,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundSubtract) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = sub %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = sub %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -202,8 +206,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundMultiply) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = mul %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = mul %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -248,8 +253,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundDiv) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = div %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = div %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -294,8 +300,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundModulo) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = mod %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = mod %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -340,8 +347,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundAnd) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, bool, read_write> = and %v1:ptr<private, bool, read_write>, false
store %v1:ptr<private, bool, read_write>, %2:ptr<private, bool, read_write>
%2:bool = load %v1:ptr<private, bool, read_write>
%3:bool = and %2:bool, false
store %v1:ptr<private, bool, read_write>, %3:bool
} -> %func_end # return
} %func_end
@ -386,8 +394,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundOr) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, bool, read_write> = or %v1:ptr<private, bool, read_write>, false
store %v1:ptr<private, bool, read_write>, %2:ptr<private, bool, read_write>
%2:bool = load %v1:ptr<private, bool, read_write>
%3:bool = or %2:bool, false
store %v1:ptr<private, bool, read_write>, %3:bool
} -> %func_end # return
} %func_end
@ -432,8 +441,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundXor) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = xor %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = xor %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -712,8 +722,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundShiftLeft) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = shiftl %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = shiftl %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end
@ -758,8 +769,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundShiftRight) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%2:ptr<private, u32, read_write> = shiftr %v1:ptr<private, u32, read_write>, 1u
store %v1:ptr<private, u32, read_write>, %2:ptr<private, u32, read_write>
%2:u32 = load %v1:ptr<private, u32, read_write>
%3:u32 = shiftr %2:u32, 1u
store %v1:ptr<private, u32, read_write>, %3:u32
} -> %func_end # return
} %func_end

View File

@ -106,7 +106,8 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Convert) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%tint_symbol:f32 = convert i32, %i:ptr<private, i32, read_write>
%2:i32 = load %i:ptr<private, i32, read_write>
%tint_symbol:f32 = convert i32, %2:i32
} -> %func_end # return
} %func_end
@ -143,7 +144,8 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Construct) {
%fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] {
%fn3 = block {
%tint_symbol:vec3<f32> = construct 2.0f, 3.0f, %i:ptr<private, f32, read_write>
%2:f32 = load %i:ptr<private, f32, read_write>
%tint_symbol:vec3<f32> = construct 2.0f, 3.0f, %2:f32
} -> %func_end # return
} %func_end

30
src/tint/ir/load.cc Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ir/load.h"
#include "src/tint/debug.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Load);
namespace tint::ir {
Load::Load(const type::Type* type, Value* f) : Base(), result_type(type), from(f) {
TINT_ASSERT(IR, result_type);
TINT_ASSERT(IR, from);
from->AddUsage(this);
}
Load::~Load() = default;
} // namespace tint::ir

49
src/tint/ir/load.h Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_IR_LOAD_H_
#define SRC_TINT_IR_LOAD_H_
#include "src/tint/ir/instruction.h"
#include "src/tint/utils/castable.h"
namespace tint::ir {
/// A load instruction in the IR.
class Load : public utils::Castable<Load, Instruction> {
public:
/// Constructor
/// @param type the result type
/// @param from the value being loaded from
Load(const type::Type* type, Value* from);
Load(const Load& inst) = delete;
Load(Load&& inst) = delete;
~Load() override;
Load& operator=(const Load& inst) = delete;
Load& operator=(Load&& inst) = delete;
/// @returns the type of the value
const type::Type* Type() const override { return result_type; }
/// the result type of the instruction
const type::Type* result_type = nullptr;
/// the value being loaded
Value* from = nullptr;
};
} // namespace tint::ir
#endif // SRC_TINT_IR_LOAD_H_

59
src/tint/ir/load_test.cc Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ir/builder.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/test_helper.h"
namespace tint::ir {
namespace {
using namespace tint::number_suffixes; // NOLINT
using IR_InstructionTest = TestHelper;
TEST_F(IR_InstructionTest, CreateLoad) {
Module mod;
Builder b{mod};
auto* store_type = b.ir.types.Get<type::I32>();
auto* var = b.Declare(b.ir.types.Get<type::Pointer>(
store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
const auto* inst = b.Load(var);
ASSERT_TRUE(inst->Is<Load>());
ASSERT_EQ(inst->from, var);
EXPECT_EQ(inst->Type(), store_type);
ASSERT_TRUE(inst->from->Is<ir::Var>());
EXPECT_EQ(inst->from, var);
}
TEST_F(IR_InstructionTest, Load_Usage) {
Module mod;
Builder b{mod};
auto* store_type = b.ir.types.Get<type::I32>();
auto* var = b.Declare(b.ir.types.Get<type::Pointer>(
store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
const auto* inst = b.Load(var);
ASSERT_NE(inst->from, nullptr);
ASSERT_EQ(inst->from->Usage().Length(), 1u);
EXPECT_EQ(inst->from->Usage()[0], inst);
}
} // namespace
} // namespace tint::ir

View File

@ -23,6 +23,7 @@
#include "src/tint/ir/function_terminator.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/user_call.h"
@ -123,11 +124,13 @@ class State {
[&](const ir::Block* block) {
for (auto* inst : block->instructions) {
auto* stmt = Stmt(inst);
auto stmt = Stmt(inst);
if (TINT_UNLIKELY(!stmt)) {
return kError;
}
stmts.Push(stmt);
if (auto* s = stmt.Get()) {
stmts.Push(s);
}
}
branch = &block->branch;
return kContinue;
@ -239,8 +242,11 @@ class State {
const ir::FlowNode* NextNonEmptyNode(const ir::FlowNode* node) {
while (node) {
if (auto* block = node->As<ir::Block>()) {
if (block->instructions.Length() > 0) {
return node;
for (auto* inst : block->instructions) {
// Load instructions will be inlined, so ignore them.
if (!inst->Is<ir::Load>()) {
return node;
}
}
node = block->branch.target;
} else {
@ -250,15 +256,16 @@ class State {
return nullptr;
}
const ast::Statement* Stmt(const ir::Instruction* inst) {
return Switch(
utils::Result<const ast::Statement*> Stmt(const ir::Instruction* inst) {
return Switch<utils::Result<const ast::Statement*>>(
inst, //
[&](const ir::Call* i) { return CallStmt(i); }, //
[&](const ir::Var* i) { return Var(i); }, //
[&](const ir::Store* i) { return Store(i); }, //
[&](const ir::Load*) { return nullptr; },
[&](const ir::Store* i) { return Store(i); }, //
[&](Default) {
UNHANDLED_CASE(inst);
return nullptr;
return utils::Failure;
});
}
@ -318,6 +325,7 @@ class State {
return Switch(
val, //
[&](const ir::Constant* c) { return ConstExpr(c); },
[&](const ir::Load* l) { return LoadExpr(l); },
[&](const ir::Var* v) { return VarExpr(v); },
[&](Default) {
UNHANDLED_CASE(val);
@ -339,6 +347,8 @@ class State {
});
}
const ast::Expression* LoadExpr(const ir::Load* l) { return Expr(l->from); }
const ast::Expression* VarExpr(const ir::Var* v) { return b.Expr(NameOf(v)); }
utils::Result<ast::Type> Type(const type::Type* ty) {