diff --git a/src/tint/ir/binary.h b/src/tint/ir/binary.h index 0f7fe57958..c2df24ef54 100644 --- a/src/tint/ir/binary.h +++ b/src/tint/ir/binary.h @@ -21,6 +21,7 @@ #include "src/tint/ir/instruction.h" #include "src/tint/ir/value.h" #include "src/tint/symbol_table.h" +#include "src/tint/type/type.h" namespace tint::ir { diff --git a/src/tint/ir/binary_test.cc b/src/tint/ir/binary_test.cc index f90661bae0..d7ea7dd8c9 100644 --- a/src/tint/ir/binary_test.cc +++ b/src/tint/ir/binary_test.cc @@ -28,7 +28,8 @@ TEST_F(IR_InstructionTest, CreateAnd) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.And(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.And(b.builder.ir.types.Get(), b.builder.Constant(4_i), + b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kAnd); @@ -54,7 +55,8 @@ TEST_F(IR_InstructionTest, CreateOr) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Or(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Or(b.builder.ir.types.Get(), b.builder.Constant(4_i), + b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kOr); @@ -80,7 +82,8 @@ TEST_F(IR_InstructionTest, CreateXor) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Xor(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Xor(b.builder.ir.types.Get(), b.builder.Constant(4_i), + b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kXor); @@ -106,7 +109,8 @@ TEST_F(IR_InstructionTest, CreateLogicalAnd) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.LogicalAnd(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.LogicalAnd(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kLogicalAnd); @@ -132,7 +136,8 @@ TEST_F(IR_InstructionTest, CreateLogicalOr) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.LogicalOr(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.LogicalOr(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kLogicalOr); @@ -158,7 +163,8 @@ TEST_F(IR_InstructionTest, CreateEqual) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Equal(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Equal(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kEqual); @@ -184,7 +190,8 @@ TEST_F(IR_InstructionTest, CreateNotEqual) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.NotEqual(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.NotEqual(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kNotEqual); @@ -210,7 +217,8 @@ TEST_F(IR_InstructionTest, CreateLessThan) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.LessThan(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.LessThan(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kLessThan); @@ -236,7 +244,8 @@ TEST_F(IR_InstructionTest, CreateGreaterThan) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.GreaterThan(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.GreaterThan(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kGreaterThan); @@ -262,7 +271,8 @@ TEST_F(IR_InstructionTest, CreateLessThanEqual) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.LessThanEqual(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.LessThanEqual(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kLessThanEqual); @@ -288,8 +298,8 @@ TEST_F(IR_InstructionTest, CreateGreaterThanEqual) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = - b.builder.GreaterThanEqual(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.GreaterThanEqual( + b.builder.ir.types.Get(), b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kGreaterThanEqual); @@ -315,7 +325,8 @@ TEST_F(IR_InstructionTest, CreateShiftLeft) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.ShiftLeft(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.ShiftLeft(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kShiftLeft); @@ -341,7 +352,8 @@ TEST_F(IR_InstructionTest, CreateShiftRight) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.ShiftRight(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.ShiftRight(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kShiftRight); @@ -367,7 +379,8 @@ TEST_F(IR_InstructionTest, CreateAdd) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Add(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Add(b.builder.ir.types.Get(), b.builder.Constant(4_i), + b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kAdd); @@ -393,7 +406,8 @@ TEST_F(IR_InstructionTest, CreateSubtract) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Subtract(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Subtract(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kSubtract); @@ -419,7 +433,8 @@ TEST_F(IR_InstructionTest, CreateMultiply) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Multiply(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Multiply(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kMultiply); @@ -445,7 +460,8 @@ TEST_F(IR_InstructionTest, CreateDivide) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Divide(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Divide(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kDivide); @@ -471,7 +487,8 @@ TEST_F(IR_InstructionTest, CreateModulo) { auto& b = CreateEmptyBuilder(); b.builder.next_temp_id = Temp::Id(42); - const auto* instr = b.builder.Modulo(b.builder.Constant(4_i), b.builder.Constant(2_i)); + const auto* instr = b.builder.Modulo(b.builder.ir.types.Get(), + b.builder.Constant(4_i), b.builder.Constant(2_i)); EXPECT_EQ(instr->GetKind(), Binary::Kind::kModulo); diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc index 35e4f700be..2c92f8ccd3 100644 --- a/src/tint/ir/builder.cc +++ b/src/tint/ir/builder.cc @@ -97,80 +97,85 @@ Temp::Id Builder::AllocateTempId() { return next_temp_id++; } -const Binary* Builder::CreateBinary(Binary::Kind kind, const Value* lhs, const Value* rhs) { - return ir.instructions.Create(kind, Temp(), lhs, rhs); +const Binary* Builder::CreateBinary(Binary::Kind kind, + const type::Type* type, + const Value* lhs, + const Value* rhs) { + return ir.instructions.Create(kind, Temp(type), lhs, rhs); } -const Binary* Builder::And(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kAnd, lhs, rhs); +const Binary* Builder::And(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kAnd, type, lhs, rhs); } -const Binary* Builder::Or(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kOr, lhs, rhs); +const Binary* Builder::Or(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kOr, type, lhs, rhs); } -const Binary* Builder::Xor(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kXor, lhs, rhs); +const Binary* Builder::Xor(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kXor, type, lhs, rhs); } -const Binary* Builder::LogicalAnd(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kLogicalAnd, lhs, rhs); +const Binary* Builder::LogicalAnd(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kLogicalAnd, type, lhs, rhs); } -const Binary* Builder::LogicalOr(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kLogicalOr, lhs, rhs); +const Binary* Builder::LogicalOr(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kLogicalOr, type, lhs, rhs); } -const Binary* Builder::Equal(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kEqual, lhs, rhs); +const Binary* Builder::Equal(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kEqual, type, lhs, rhs); } -const Binary* Builder::NotEqual(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kNotEqual, lhs, rhs); +const Binary* Builder::NotEqual(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kNotEqual, type, lhs, rhs); } -const Binary* Builder::LessThan(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kLessThan, lhs, rhs); +const Binary* Builder::LessThan(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kLessThan, type, lhs, rhs); } -const Binary* Builder::GreaterThan(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kGreaterThan, lhs, rhs); +const Binary* Builder::GreaterThan(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kGreaterThan, type, lhs, rhs); } -const Binary* Builder::LessThanEqual(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kLessThanEqual, lhs, rhs); +const Binary* Builder::LessThanEqual(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kLessThanEqual, type, lhs, rhs); } -const Binary* Builder::GreaterThanEqual(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kGreaterThanEqual, lhs, rhs); +const Binary* Builder::GreaterThanEqual(const type::Type* type, + const Value* lhs, + const Value* rhs) { + return CreateBinary(Binary::Kind::kGreaterThanEqual, type, lhs, rhs); } -const Binary* Builder::ShiftLeft(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kShiftLeft, lhs, rhs); +const Binary* Builder::ShiftLeft(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kShiftLeft, type, lhs, rhs); } -const Binary* Builder::ShiftRight(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kShiftRight, lhs, rhs); +const Binary* Builder::ShiftRight(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kShiftRight, type, lhs, rhs); } -const Binary* Builder::Add(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kAdd, lhs, rhs); +const Binary* Builder::Add(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kAdd, type, lhs, rhs); } -const Binary* Builder::Subtract(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kSubtract, lhs, rhs); +const Binary* Builder::Subtract(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kSubtract, type, lhs, rhs); } -const Binary* Builder::Multiply(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kMultiply, lhs, rhs); +const Binary* Builder::Multiply(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kMultiply, type, lhs, rhs); } -const Binary* Builder::Divide(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kDivide, lhs, rhs); +const Binary* Builder::Divide(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kDivide, type, lhs, rhs); } -const Binary* Builder::Modulo(const Value* lhs, const Value* rhs) { - return CreateBinary(Binary::Kind::kModulo, lhs, rhs); +const Binary* Builder::Modulo(const type::Type* type, const Value* lhs, const Value* rhs) { + return CreateBinary(Binary::Kind::kModulo, type, lhs, rhs); } } // namespace tint::ir diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h index ca21b8d278..835c8c0532 100644 --- a/src/tint/ir/builder.h +++ b/src/tint/ir/builder.h @@ -144,123 +144,148 @@ class Builder { } /// Creates a new Temporary + /// @param type the type of the temporary /// @returns the new temporary - const ir::Temp* Temp() { return ir.values.Create(AllocateTempId()); } + const ir::Temp* Temp(const type::Type* type) { + return ir.values.Create(type, AllocateTempId()); + } /// Creates an op for `lhs kind rhs` /// @param kind the kind of operation + /// @param type the result type of the binary expression /// @param lhs the left-hand-side of the operation /// @param rhs the right-hand-side of the operation /// @returns the operation - const Binary* CreateBinary(Binary::Kind kind, const Value* lhs, const Value* rhs); + const Binary* CreateBinary(Binary::Kind kind, + const type::Type* type, + const Value* lhs, + const Value* rhs); /// Creates an And operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* And(const Value* lhs, const Value* rhs); + const Binary* And(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Or operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Or(const Value* lhs, const Value* rhs); + const Binary* Or(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Xor operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Xor(const Value* lhs, const Value* rhs); + const Binary* Xor(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an LogicalAnd operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* LogicalAnd(const Value* lhs, const Value* rhs); + const Binary* LogicalAnd(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an LogicalOr operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* LogicalOr(const Value* lhs, const Value* rhs); + const Binary* LogicalOr(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Equal operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Equal(const Value* lhs, const Value* rhs); + const Binary* Equal(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an NotEqual operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* NotEqual(const Value* lhs, const Value* rhs); + const Binary* NotEqual(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an LessThan operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* LessThan(const Value* lhs, const Value* rhs); + const Binary* LessThan(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an GreaterThan operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* GreaterThan(const Value* lhs, const Value* rhs); + const Binary* GreaterThan(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an LessThanEqual operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* LessThanEqual(const Value* lhs, const Value* rhs); + const Binary* LessThanEqual(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an GreaterThanEqual operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* GreaterThanEqual(const Value* lhs, const Value* rhs); + const Binary* GreaterThanEqual(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an ShiftLeft operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* ShiftLeft(const Value* lhs, const Value* rhs); + const Binary* ShiftLeft(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an ShiftRight operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* ShiftRight(const Value* lhs, const Value* rhs); + const Binary* ShiftRight(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Add operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Add(const Value* lhs, const Value* rhs); + const Binary* Add(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Subtract operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Subtract(const Value* lhs, const Value* rhs); + const Binary* Subtract(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Multiply operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Multiply(const Value* lhs, const Value* rhs); + const Binary* Multiply(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Divide operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Divide(const Value* lhs, const Value* rhs); + const Binary* Divide(const type::Type* type, const Value* lhs, const Value* rhs); /// Creates an Modulo operation + /// @param type the result type of the expression /// @param lhs the lhs of the add /// @param rhs the rhs of the add /// @returns the operation - const Binary* Modulo(const Value* lhs, const Value* rhs); + const Binary* Modulo(const type::Type* type, const Value* lhs, const Value* rhs); /// @returns a unique temp id Temp::Id AllocateTempId(); diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc index afa56f51d1..8196f44f6a 100644 --- a/src/tint/ir/builder_impl.cc +++ b/src/tint/ir/builder_impl.cc @@ -563,61 +563,62 @@ utils::Result BuilderImpl::EmitBinary(const ast::BinaryExpression* return utils::Failure; } + auto* sem = builder.ir.program->Sem().Get(expr); const Binary* instr = nullptr; switch (expr->op) { case ast::BinaryOp::kAnd: - instr = builder.And(lhs.Get(), rhs.Get()); + instr = builder.And(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kOr: - instr = builder.Or(lhs.Get(), rhs.Get()); + instr = builder.Or(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kXor: - instr = builder.Xor(lhs.Get(), rhs.Get()); + instr = builder.Xor(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kLogicalAnd: - instr = builder.LogicalAnd(lhs.Get(), rhs.Get()); + instr = builder.LogicalAnd(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kLogicalOr: - instr = builder.LogicalOr(lhs.Get(), rhs.Get()); + instr = builder.LogicalOr(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kEqual: - instr = builder.Equal(lhs.Get(), rhs.Get()); + instr = builder.Equal(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kNotEqual: - instr = builder.NotEqual(lhs.Get(), rhs.Get()); + instr = builder.NotEqual(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kLessThan: - instr = builder.LessThan(lhs.Get(), rhs.Get()); + instr = builder.LessThan(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kGreaterThan: - instr = builder.GreaterThan(lhs.Get(), rhs.Get()); + instr = builder.GreaterThan(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kLessThanEqual: - instr = builder.LessThanEqual(lhs.Get(), rhs.Get()); + instr = builder.LessThanEqual(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kGreaterThanEqual: - instr = builder.GreaterThanEqual(lhs.Get(), rhs.Get()); + instr = builder.GreaterThanEqual(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kShiftLeft: - instr = builder.ShiftLeft(lhs.Get(), rhs.Get()); + instr = builder.ShiftLeft(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kShiftRight: - instr = builder.ShiftRight(lhs.Get(), rhs.Get()); + instr = builder.ShiftRight(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kAdd: - instr = builder.Add(lhs.Get(), rhs.Get()); + instr = builder.Add(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kSubtract: - instr = builder.Subtract(lhs.Get(), rhs.Get()); + instr = builder.Subtract(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kMultiply: - instr = builder.Multiply(lhs.Get(), rhs.Get()); + instr = builder.Multiply(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kDivide: - instr = builder.Divide(lhs.Get(), rhs.Get()); + instr = builder.Divide(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kModulo: - instr = builder.Modulo(lhs.Get(), rhs.Get()); + instr = builder.Modulo(sem->Type(), lhs.Get(), rhs.Get()); break; case ast::BinaryOp::kNone: TINT_ICE(IR, diagnostics_) << "missing binary operand type"; diff --git a/src/tint/ir/constant.h b/src/tint/ir/constant.h index c5ee079db4..10b41d337b 100644 --- a/src/tint/ir/constant.h +++ b/src/tint/ir/constant.h @@ -31,6 +31,9 @@ class Constant : public Castable { explicit Constant(const constant::Value* val); ~Constant() override; + /// @returns the type of the constant + const type::Type* Type() const override { return value->Type(); } + /// Write the constant to the given stream /// @param out the stream to write to /// @param st the symbol table diff --git a/src/tint/ir/temp.cc b/src/tint/ir/temp.cc index afd0921265..57def51544 100644 --- a/src/tint/ir/temp.cc +++ b/src/tint/ir/temp.cc @@ -20,7 +20,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Temp); namespace tint::ir { -Temp::Temp(Id id) : id_(id) {} +Temp::Temp(const type::Type* type, Id id) : type_(type), id_(id) {} Temp::~Temp() = default; diff --git a/src/tint/ir/temp.h b/src/tint/ir/temp.h index 2db81f3ed8..1a4a38defd 100644 --- a/src/tint/ir/temp.h +++ b/src/tint/ir/temp.h @@ -29,8 +29,9 @@ class Temp : public Castable { using Id = uint32_t; /// Constructor + /// @param type the type of the temporary /// @param id the id for the value - explicit Temp(Id id); + Temp(const type::Type* type, Id id); /// Destructor ~Temp() override; @@ -44,6 +45,9 @@ class Temp : public Castable { /// @returns the value data as an `Id`. Id AsId() const { return id_; } + /// @returns the type of the temporary + const type::Type* Type() const override { return type_; } + /// Write the temp to the given stream /// @param out the stream to write to /// @param st the symbol table @@ -51,6 +55,7 @@ class Temp : public Castable { std::ostream& ToString(std::ostream& out, const SymbolTable& st) const override; private: + const type::Type* type_ = nullptr; Id id_ = 0; }; diff --git a/src/tint/ir/temp_test.cc b/src/tint/ir/temp_test.cc index 6b79a51abb..f1fc24f1ec 100644 --- a/src/tint/ir/temp_test.cc +++ b/src/tint/ir/temp_test.cc @@ -30,7 +30,7 @@ TEST_F(IR_TempTest, id) { std::stringstream str; b.builder.next_temp_id = Temp::Id(4); - auto* val = b.builder.Temp(); + auto* val = b.builder.Temp(b.builder.ir.types.Get()); EXPECT_EQ(4u, val->AsId()); val->ToString(str, program->Symbols()); diff --git a/src/tint/ir/value.h b/src/tint/ir/value.h index 6dffc3e596..3b3965311b 100644 --- a/src/tint/ir/value.h +++ b/src/tint/ir/value.h @@ -19,6 +19,7 @@ #include "src/tint/castable.h" #include "src/tint/symbol_table.h" +#include "src/tint/type/type.h" namespace tint::ir { @@ -34,6 +35,9 @@ class Value : public Castable { Value& operator=(const Value&) = delete; Value& operator=(Value&&) = delete; + /// @returns the type of the value + virtual const type::Type* Type() const = 0; + /// Write the value to the given stream /// @param out the stream to write to /// @param st the symbol table