// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_TINT_IR_BUILDER_H_
|
|
#define SRC_TINT_IR_BUILDER_H_
|
|
|
|
#include <utility>
|
|
|
|
#include "src/tint/constant/scalar.h"
|
|
#include "src/tint/ir/binary.h"
|
|
#include "src/tint/ir/bitcast.h"
|
|
#include "src/tint/ir/block_param.h"
|
|
#include "src/tint/ir/builtin.h"
|
|
#include "src/tint/ir/constant.h"
|
|
#include "src/tint/ir/construct.h"
|
|
#include "src/tint/ir/convert.h"
|
|
#include "src/tint/ir/discard.h"
|
|
#include "src/tint/ir/function.h"
|
|
#include "src/tint/ir/function_param.h"
|
|
#include "src/tint/ir/function_terminator.h"
|
|
#include "src/tint/ir/if.h"
|
|
#include "src/tint/ir/load.h"
|
|
#include "src/tint/ir/loop.h"
|
|
#include "src/tint/ir/module.h"
|
|
#include "src/tint/ir/root_terminator.h"
|
|
#include "src/tint/ir/store.h"
|
|
#include "src/tint/ir/switch.h"
|
|
#include "src/tint/ir/unary.h"
|
|
#include "src/tint/ir/user_call.h"
|
|
#include "src/tint/ir/value.h"
|
|
#include "src/tint/ir/var.h"
|
|
#include "src/tint/type/bool.h"
|
|
#include "src/tint/type/f16.h"
|
|
#include "src/tint/type/f32.h"
|
|
#include "src/tint/type/i32.h"
|
|
#include "src/tint/type/u32.h"
|
|
#include "src/tint/type/vector.h"
|
|
#include "src/tint/type/void.h"
|
|
|
|
namespace tint::ir {
|
|
|
|
/// Builds an ir::Module
|
|
class Builder {
|
|
public:
|
|
/// Constructor
|
|
/// @param mod the ir::Module to wrap with this builder
|
|
explicit Builder(Module& mod);
|
|
/// Destructor
|
|
~Builder();
|
|
|
|
/// @returns a new block flow node
|
|
Block* CreateBlock();
|
|
|
|
/// @returns a new root terminator flow node
|
|
RootTerminator* CreateRootTerminator();
|
|
|
|
/// @returns a new function terminator flow node
|
|
FunctionTerminator* CreateFunctionTerminator();
|
|
|
|
/// Creates a function flow node
|
|
/// @param name the function name
|
|
/// @param return_type the function return type
|
|
/// @param stage the function stage
|
|
/// @param wg_size the workgroup_size
|
|
/// @returns the flow node
|
|
Function* CreateFunction(std::string_view name,
|
|
const type::Type* return_type,
|
|
Function::PipelineStage stage = Function::PipelineStage::kUndefined,
|
|
std::optional<std::array<uint32_t, 3>> wg_size = {});
|
|
|
|
/// Creates an if flow node
|
|
/// @param condition the if condition
|
|
/// @returns the flow node
|
|
If* CreateIf(Value* condition);
|
|
|
|
/// Creates a loop flow node
|
|
/// @returns the flow node
|
|
Loop* CreateLoop();
|
|
|
|
/// Creates a switch flow node
|
|
/// @param condition the switch condition
|
|
/// @returns the flow node
|
|
Switch* CreateSwitch(Value* condition);
|
|
|
|
/// Creates a case flow node for the given case branch.
|
|
/// @param s the switch to create the case into
|
|
/// @param selectors the case selectors for the case statement
|
|
/// @returns the start block for the case flow node
|
|
Block* CreateCase(Switch* s, utils::VectorRef<Switch::CaseSelector> selectors);
|
|
|
|
/// Creates a new ir::Constant
|
|
/// @param val the constant value
|
|
/// @returns the new constant
|
|
ir::Constant* Constant(const constant::Value* val) {
|
|
return ir.constants.GetOrCreate(val, [&]() { return ir.values.Create<ir::Constant>(val); });
|
|
}
|
|
|
|
/// Creates a ir::Constant for an i32 Scalar
|
|
/// @param v the value
|
|
/// @returns the new constant
|
|
ir::Constant* Constant(i32 v) { return Constant(ir.constant_values.Get(v)); }
|
|
|
|
/// Creates a ir::Constant for a u32 Scalar
|
|
/// @param v the value
|
|
/// @returns the new constant
|
|
ir::Constant* Constant(u32 v) { return Constant(ir.constant_values.Get(v)); }
|
|
|
|
/// Creates a ir::Constant for a f32 Scalar
|
|
/// @param v the value
|
|
/// @returns the new constant
|
|
ir::Constant* Constant(f32 v) { return Constant(ir.constant_values.Get(v)); }
|
|
|
|
/// Creates a ir::Constant for a f16 Scalar
|
|
/// @param v the value
|
|
/// @returns the new constant
|
|
ir::Constant* Constant(f16 v) { return Constant(ir.constant_values.Get(v)); }
|
|
|
|
/// Creates a ir::Constant for a bool Scalar
|
|
/// @param v the value
|
|
/// @returns the new constant
|
|
ir::Constant* Constant(bool v) { return Constant(ir.constant_values.Get(v)); }
|
|
|
|
/// Creates an op for `lhs kind rhs`
|
|
/// @param kind the kind of operation
|
|
/// @param type the result type of the binary expression
|
|
/// @param lhs the left-hand-side of the operation
|
|
/// @param rhs the right-hand-side of the operation
|
|
/// @returns the operation
|
|
Binary* CreateBinary(enum Binary::Kind kind, const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an And operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* And(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Or operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Or(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Xor operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Xor(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Equal operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Equal(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an NotEqual operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* NotEqual(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an LessThan operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* LessThan(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an GreaterThan operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* GreaterThan(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an LessThanEqual operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* LessThanEqual(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an GreaterThanEqual operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* GreaterThanEqual(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an ShiftLeft operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* ShiftLeft(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an ShiftRight operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* ShiftRight(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Add operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Add(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Subtract operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Subtract(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Multiply operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Multiply(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Divide operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Divide(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an Modulo operation
|
|
/// @param type the result type of the expression
|
|
/// @param lhs the lhs of the add
|
|
/// @param rhs the rhs of the add
|
|
/// @returns the operation
|
|
Binary* Modulo(const type::Type* type, Value* lhs, Value* rhs);
|
|
|
|
/// Creates an op for `kind val`
|
|
/// @param kind the kind of operation
|
|
/// @param type the result type of the binary expression
|
|
/// @param val the value of the operation
|
|
/// @returns the operation
|
|
Unary* CreateUnary(enum Unary::Kind kind, const type::Type* type, Value* val);
|
|
|
|
/// Creates a Complement operation
|
|
/// @param type the result type of the expression
|
|
/// @param val the value
|
|
/// @returns the operation
|
|
Unary* Complement(const type::Type* type, Value* val);
|
|
|
|
/// Creates a Negation operation
|
|
/// @param type the result type of the expression
|
|
/// @param val the value
|
|
/// @returns the operation
|
|
Unary* Negation(const type::Type* type, Value* val);
|
|
|
|
/// Creates a Not operation
|
|
/// @param type the result type of the expression
|
|
/// @param val the value
|
|
/// @returns the operation
|
|
Binary* Not(const type::Type* type, Value* val);
|
|
|
|
/// Creates a bitcast instruction
|
|
/// @param type the result type of the bitcast
|
|
/// @param val the value being bitcast
|
|
/// @returns the instruction
|
|
ir::Bitcast* Bitcast(const type::Type* type, Value* val);
|
|
|
|
/// Creates a discard instruction
|
|
/// @returns the instruction
|
|
ir::Discard* Discard();
|
|
|
|
/// Creates a user function call instruction
|
|
/// @param type the return type of the call
|
|
/// @param func the function being called
|
|
/// @param args the call arguments
|
|
/// @returns the instruction
|
|
ir::UserCall* UserCall(const type::Type* type, Function* func, utils::VectorRef<Value*> args);
|
|
|
|
/// Creates a value conversion instruction
|
|
/// @param to the type converted to
|
|
/// @param from the type converted from
|
|
/// @param args the arguments to be converted
|
|
/// @returns the instruction
|
|
ir::Convert* Convert(const type::Type* to,
|
|
const type::Type* from,
|
|
utils::VectorRef<Value*> args);
|
|
|
|
/// Creates a value constructor instruction
|
|
/// @param to the type being converted
|
|
/// @param args the arguments to be converted
|
|
/// @returns the instruction
|
|
ir::Construct* Construct(const type::Type* to, utils::VectorRef<Value*> args);
|
|
|
|
/// Creates a builtin call instruction
|
|
/// @param type the return type
|
|
/// @param func the builtin function
|
|
/// @param args the arguments to be converted
|
|
/// @returns the instruction
|
|
ir::Builtin* Builtin(const type::Type* type,
|
|
builtin::Function func,
|
|
utils::VectorRef<Value*> args);
|
|
|
|
/// Creates a load instruction
|
|
/// @param from the expression being loaded from
|
|
/// @returns the instruction
|
|
ir::Load* Load(Value* from);
|
|
|
|
/// Creates a store instruction
|
|
/// @param to the expression being stored too
|
|
/// @param from the expression being stored
|
|
/// @returns the instruction
|
|
ir::Store* Store(Value* to, Value* from);
|
|
|
|
/// Creates a new `var` declaration
|
|
/// @param type the var type
|
|
/// @returns the instruction
|
|
ir::Var* Declare(const type::Type* type);
|
|
|
|
/// Creates a branch declaration
|
|
/// @param to the node being branched too
|
|
/// @param args the branch arguments
|
|
/// @returns the instruction
|
|
ir::Branch* Branch(Block* to, utils::VectorRef<Value*> args = {});
|
|
|
|
/// Creates a new `BlockParam`
|
|
/// @param type the parameter type
|
|
/// @returns the value
|
|
ir::BlockParam* BlockParam(const type::Type* type);
|
|
|
|
/// Creates a new `FunctionParam`
|
|
/// @param type the parameter type
|
|
/// @returns the value
|
|
ir::FunctionParam* FunctionParam(const type::Type* type);
|
|
|
|
/// Retrieves the root block for the module, creating if necessary
|
|
/// @returns the root block
|
|
ir::Block* CreateRootBlockIfNeeded();
|
|
|
|
/// The IR module.
|
|
Module& ir;
|
|
};
|
|
|
|
} // namespace tint::ir
|
|
|
|
#endif // SRC_TINT_IR_BUILDER_H_
|