diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 617d1f6e47..59bb68d9c9 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1230,6 +1230,8 @@ libtint_source_set("libtint_ir_src") { "ir/if.h", "ir/instruction.cc", "ir/instruction.h", + "ir/load.cc", + "ir/load.h", "ir/loop.cc", "ir/loop.h", "ir/module.cc", @@ -2250,6 +2252,7 @@ if (tint_build_unittests) { "ir/from_program_test.cc", "ir/from_program_unary_test.cc", "ir/from_program_var_test.cc", + "ir/load_test.cc", "ir/module_test.cc", "ir/store_test.cc", "ir/test_helper.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 9bf2714fa0..78132891fe 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -748,6 +748,8 @@ if(${TINT_BUILD_IR}) ir/if.h ir/instruction.cc ir/instruction.h + ir/load.cc + ir/load.h ir/loop.cc ir/loop.h ir/module.cc @@ -1467,6 +1469,7 @@ if(TINT_BUILD_TESTS) ir/from_program_test.cc ir/from_program_unary_test.cc ir/from_program_var_test.cc + ir/load_test.cc ir/module_test.cc ir/store_test.cc ir/test_helper.h diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc index 36311088b1..983815ef30 100644 --- a/src/tint/ir/builder.cc +++ b/src/tint/ir/builder.cc @@ -17,6 +17,8 @@ #include #include "src/tint/constant/scalar.h" +#include "src/tint/type/pointer.h" +#include "src/tint/type/reference.h" namespace tint::ir { @@ -227,6 +229,12 @@ ir::Builtin* Builder::Builtin(const type::Type* type, return ir.values.Create(type, func, args); } +ir::Load* Builder::Load(Value* from) { + auto* ptr = from->Type()->As(); + TINT_ASSERT(IR, ptr); + return ir.values.Create(ptr->StoreType(), from); +} + ir::Store* Builder::Store(Value* to, Value* from) { return ir.values.Create(to, from); } diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h index f4bc2c5008..7c0cab6067 100644 --- a/src/tint/ir/builder.h +++ b/src/tint/ir/builder.h @@ -28,6 +28,7 @@ #include "src/tint/ir/function.h" #include "src/tint/ir/function_terminator.h" #include "src/tint/ir/if.h" +#include "src/tint/ir/load.h" #include "src/tint/ir/loop.h" #include "src/tint/ir/module.h" #include "src/tint/ir/root_terminator.h" @@ -338,7 +339,12 @@ class Builder { builtin::Function func, utils::VectorRef args); - /// Creates an store instruction + /// Creates a load instruction + /// @param from the expression being loaded from + /// @returns the instruction + ir::Load* Load(Value* from); + + /// Creates a store instruction /// @param to the expression being stored too /// @param from the expression being stored /// @returns the instruction diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc index ed5cc259e5..577d646101 100644 --- a/src/tint/ir/disassembler.cc +++ b/src/tint/ir/disassembler.cc @@ -27,6 +27,7 @@ #include "src/tint/ir/discard.h" #include "src/tint/ir/function_terminator.h" #include "src/tint/ir/if.h" +#include "src/tint/ir/load.h" #include "src/tint/ir/loop.h" #include "src/tint/ir/root_terminator.h" #include "src/tint/ir/store.h" @@ -432,6 +433,11 @@ void Disassembler::EmitInstruction(const Instruction* inst) { out_ << " = convert " << c->FromType()->FriendlyName() << ", "; EmitArgs(c); }, + [&](const ir::Load* l) { + EmitValue(l); + out_ << " = load "; + EmitValue(l->from); + }, [&](const ir::Store* s) { out_ << "store "; EmitValue(s->to); diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc index 62fd3b9833..bb9c3f9f1f 100644 --- a/src/tint/ir/from_program.cc +++ b/src/tint/ir/from_program.cc @@ -74,6 +74,7 @@ #include "src/tint/sem/builtin.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" +#include "src/tint/sem/load.h" #include "src/tint/sem/materialize.h" #include "src/tint/sem/module.h" #include "src/tint/sem/switch_statement.h" @@ -422,15 +423,20 @@ class Impl { return; } - auto* ty = lhs.Get()->Type(); - auto* rhs = ty->UnwrapRef()->is_signed_integer_scalar() ? builder_.Constant(1_i) - : builder_.Constant(1_u); + // Load from the LHS. + auto* lhs_value = builder_.Load(lhs.Get()); + current_flow_block_->instructions.Push(lhs_value); + + auto* ty = lhs_value->Type(); + + auto* rhs = + ty->is_signed_integer_scalar() ? builder_.Constant(1_i) : builder_.Constant(1_u); Binary* inst = nullptr; if (stmt->increment) { - inst = builder_.Add(ty, lhs.Get(), rhs); + inst = builder_.Add(ty, lhs_value, rhs); } else { - inst = builder_.Subtract(ty, lhs.Get(), rhs); + inst = builder_.Subtract(ty, lhs_value, rhs); } current_flow_block_->instructions.Push(inst); @@ -448,38 +454,44 @@ class Impl { if (!rhs) { return; } - auto* ty = lhs.Get()->Type(); + + // Load from the LHS. + auto* lhs_value = builder_.Load(lhs.Get()); + current_flow_block_->instructions.Push(lhs_value); + + auto* ty = lhs_value->Type(); + Binary* inst = nullptr; switch (stmt->op) { case ast::BinaryOp::kAnd: - inst = builder_.And(ty, lhs.Get(), rhs.Get()); + inst = builder_.And(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kOr: - inst = builder_.Or(ty, lhs.Get(), rhs.Get()); + inst = builder_.Or(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kXor: - inst = builder_.Xor(ty, lhs.Get(), rhs.Get()); + inst = builder_.Xor(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kShiftLeft: - inst = builder_.ShiftLeft(ty, lhs.Get(), rhs.Get()); + inst = builder_.ShiftLeft(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kShiftRight: - inst = builder_.ShiftRight(ty, lhs.Get(), rhs.Get()); + inst = builder_.ShiftRight(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kAdd: - inst = builder_.Add(ty, lhs.Get(), rhs.Get()); + inst = builder_.Add(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kSubtract: - inst = builder_.Subtract(ty, lhs.Get(), rhs.Get()); + inst = builder_.Subtract(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kMultiply: - inst = builder_.Multiply(ty, lhs.Get(), rhs.Get()); + inst = builder_.Multiply(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kDivide: - inst = builder_.Divide(ty, lhs.Get(), rhs.Get()); + inst = builder_.Divide(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kModulo: - inst = builder_.Modulo(ty, lhs.Get(), rhs.Get()); + inst = builder_.Modulo(ty, lhs_value, rhs.Get()); break; case ast::BinaryOp::kLessThanEqual: case ast::BinaryOp::kGreaterThanEqual: @@ -809,7 +821,8 @@ class Impl { utils::Result EmitExpression(const ast::Expression* expr) { // If this is a value that has been const-eval'd return the result. - if (auto* sem = program_->Sem().Get(expr)->As()) { + auto* sem = program_->Sem().GetVal(expr); + if (sem) { if (auto* v = sem->ConstantValue()) { if (auto* cv = v->Clone(clone_ctx_)) { return builder_.Constant(cv); @@ -817,7 +830,7 @@ class Impl { } } - return tint::Switch( + auto result = tint::Switch( expr, // [&](const ast::IndexAccessorExpression* a) { // TODO(dsinclair): Implement @@ -846,6 +859,15 @@ class Impl { "unknown expression type: " + std::string(expr->TypeInfo().name)); return utils::Failure; }); + + // If this expression maps to sem::Load, insert a load instruction to get the result. + if (result && sem->Is()) { + auto* load = builder_.Load(result.Get()); + current_flow_block_->instructions.Push(load); + return load; + } + + return result; } void EmitVariable(const ast::Variable* var) { diff --git a/src/tint/ir/from_program_binary_test.cc b/src/tint/ir/from_program_binary_test.cc index 76e24fb10f..1319083c75 100644 --- a/src/tint/ir/from_program_binary_test.cc +++ b/src/tint/ir/from_program_binary_test.cc @@ -64,8 +64,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Increment) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = add %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = add %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -87,8 +88,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundAdd) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = add %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = add %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -133,8 +135,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Decrement) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = sub %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:i32 = load %v1:ptr + %3:i32 = sub %2:i32, 1i + store %v1:ptr, %3:i32 } -> %func_end # return } %func_end @@ -156,8 +159,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundSubtract) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = sub %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = sub %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -202,8 +206,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundMultiply) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = mul %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = mul %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -248,8 +253,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundDiv) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = div %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = div %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -294,8 +300,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundModulo) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = mod %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = mod %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -340,8 +347,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundAnd) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = and %v1:ptr, false - store %v1:ptr, %2:ptr + %2:bool = load %v1:ptr + %3:bool = and %2:bool, false + store %v1:ptr, %3:bool } -> %func_end # return } %func_end @@ -386,8 +394,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundOr) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = or %v1:ptr, false - store %v1:ptr, %2:ptr + %2:bool = load %v1:ptr + %3:bool = or %2:bool, false + store %v1:ptr, %3:bool } -> %func_end # return } %func_end @@ -432,8 +441,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundXor) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = xor %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = xor %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -712,8 +722,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundShiftLeft) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = shiftl %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = shiftl %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end @@ -758,8 +769,9 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_CompoundShiftRight) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %2:ptr = shiftr %v1:ptr, 1u - store %v1:ptr, %2:ptr + %2:u32 = load %v1:ptr + %3:u32 = shiftr %2:u32, 1u + store %v1:ptr, %3:u32 } -> %func_end # return } %func_end diff --git a/src/tint/ir/from_program_call_test.cc b/src/tint/ir/from_program_call_test.cc index f4c0ee79fc..f6422e267a 100644 --- a/src/tint/ir/from_program_call_test.cc +++ b/src/tint/ir/from_program_call_test.cc @@ -106,7 +106,8 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Convert) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %tint_symbol:f32 = convert i32, %i:ptr + %2:i32 = load %i:ptr + %tint_symbol:f32 = convert i32, %2:i32 } -> %func_end # return } %func_end @@ -143,7 +144,8 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Construct) { %fn2 = func test_function():void [@compute @workgroup_size(1, 1, 1)] { %fn3 = block { - %tint_symbol:vec3 = construct 2.0f, 3.0f, %i:ptr + %2:f32 = load %i:ptr + %tint_symbol:vec3 = construct 2.0f, 3.0f, %2:f32 } -> %func_end # return } %func_end diff --git a/src/tint/ir/load.cc b/src/tint/ir/load.cc new file mode 100644 index 0000000000..1fe55c01bb --- /dev/null +++ b/src/tint/ir/load.cc @@ -0,0 +1,30 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ir/load.h" +#include "src/tint/debug.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ir::Load); + +namespace tint::ir { + +Load::Load(const type::Type* type, Value* f) : Base(), result_type(type), from(f) { + TINT_ASSERT(IR, result_type); + TINT_ASSERT(IR, from); + from->AddUsage(this); +} + +Load::~Load() = default; + +} // namespace tint::ir diff --git a/src/tint/ir/load.h b/src/tint/ir/load.h new file mode 100644 index 0000000000..b15ecedc2e --- /dev/null +++ b/src/tint/ir/load.h @@ -0,0 +1,49 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_IR_LOAD_H_ +#define SRC_TINT_IR_LOAD_H_ + +#include "src/tint/ir/instruction.h" +#include "src/tint/utils/castable.h" + +namespace tint::ir { + +/// A load instruction in the IR. +class Load : public utils::Castable { + public: + /// Constructor + /// @param type the result type + /// @param from the value being loaded from + Load(const type::Type* type, Value* from); + Load(const Load& inst) = delete; + Load(Load&& inst) = delete; + ~Load() override; + + Load& operator=(const Load& inst) = delete; + Load& operator=(Load&& inst) = delete; + + /// @returns the type of the value + const type::Type* Type() const override { return result_type; } + + /// the result type of the instruction + const type::Type* result_type = nullptr; + + /// the value being loaded + Value* from = nullptr; +}; + +} // namespace tint::ir + +#endif // SRC_TINT_IR_LOAD_H_ diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc new file mode 100644 index 0000000000..2c6e5c1402 --- /dev/null +++ b/src/tint/ir/load_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ir/builder.h" +#include "src/tint/ir/instruction.h" +#include "src/tint/ir/test_helper.h" + +namespace tint::ir { +namespace { + +using namespace tint::number_suffixes; // NOLINT + +using IR_InstructionTest = TestHelper; + +TEST_F(IR_InstructionTest, CreateLoad) { + Module mod; + Builder b{mod}; + + auto* store_type = b.ir.types.Get(); + auto* var = b.Declare(b.ir.types.Get( + store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite)); + const auto* inst = b.Load(var); + + ASSERT_TRUE(inst->Is()); + ASSERT_EQ(inst->from, var); + + EXPECT_EQ(inst->Type(), store_type); + + ASSERT_TRUE(inst->from->Is()); + EXPECT_EQ(inst->from, var); +} + +TEST_F(IR_InstructionTest, Load_Usage) { + Module mod; + Builder b{mod}; + + auto* store_type = b.ir.types.Get(); + auto* var = b.Declare(b.ir.types.Get( + store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite)); + const auto* inst = b.Load(var); + + ASSERT_NE(inst->from, nullptr); + ASSERT_EQ(inst->from->Usage().Length(), 1u); + EXPECT_EQ(inst->from->Usage()[0], inst); +} + +} // namespace +} // namespace tint::ir diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc index 219490982a..5fed9f590a 100644 --- a/src/tint/ir/to_program.cc +++ b/src/tint/ir/to_program.cc @@ -23,6 +23,7 @@ #include "src/tint/ir/function_terminator.h" #include "src/tint/ir/if.h" #include "src/tint/ir/instruction.h" +#include "src/tint/ir/load.h" #include "src/tint/ir/module.h" #include "src/tint/ir/store.h" #include "src/tint/ir/user_call.h" @@ -123,11 +124,13 @@ class State { [&](const ir::Block* block) { for (auto* inst : block->instructions) { - auto* stmt = Stmt(inst); + auto stmt = Stmt(inst); if (TINT_UNLIKELY(!stmt)) { return kError; } - stmts.Push(stmt); + if (auto* s = stmt.Get()) { + stmts.Push(s); + } } branch = &block->branch; return kContinue; @@ -239,8 +242,11 @@ class State { const ir::FlowNode* NextNonEmptyNode(const ir::FlowNode* node) { while (node) { if (auto* block = node->As()) { - if (block->instructions.Length() > 0) { - return node; + for (auto* inst : block->instructions) { + // Load instructions will be inlined, so ignore them. + if (!inst->Is()) { + return node; + } } node = block->branch.target; } else { @@ -250,15 +256,16 @@ class State { return nullptr; } - const ast::Statement* Stmt(const ir::Instruction* inst) { - return Switch( + utils::Result Stmt(const ir::Instruction* inst) { + return Switch>( inst, // [&](const ir::Call* i) { return CallStmt(i); }, // [&](const ir::Var* i) { return Var(i); }, // - [&](const ir::Store* i) { return Store(i); }, // + [&](const ir::Load*) { return nullptr; }, + [&](const ir::Store* i) { return Store(i); }, // [&](Default) { UNHANDLED_CASE(inst); - return nullptr; + return utils::Failure; }); } @@ -318,6 +325,7 @@ class State { return Switch( val, // [&](const ir::Constant* c) { return ConstExpr(c); }, + [&](const ir::Load* l) { return LoadExpr(l); }, [&](const ir::Var* v) { return VarExpr(v); }, [&](Default) { UNHANDLED_CASE(val); @@ -339,6 +347,8 @@ class State { }); } + const ast::Expression* LoadExpr(const ir::Load* l) { return Expr(l->from); } + const ast::Expression* VarExpr(const ir::Var* v) { return b.Expr(NameOf(v)); } utils::Result Type(const type::Type* ty) {