[ir] Set default flow node values.
For the `if`, `switch` and `function` flow nodes this Cl makes a few required fields part of the constructor. This cuts down boilerplate when creating new nodes. Bug: tint:1718 Change-Id: I739bcefc2ed36b0203b57974b50bb2b79f6e1684 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132980 Commit-Queue: Dan Sinclair <dsinclair@chromium.org> Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
809187c579
commit
c9923d2ee3
|
@ -47,8 +47,10 @@ FunctionTerminator* Builder::CreateFunctionTerminator() {
|
|||
return ir.flow_nodes.Create<FunctionTerminator>();
|
||||
}
|
||||
|
||||
Function* Builder::CreateFunction() {
|
||||
auto* ir_func = ir.flow_nodes.Create<Function>();
|
||||
Function* Builder::CreateFunction(Symbol name, type::Type* return_type) {
|
||||
TINT_ASSERT(IR, return_type);
|
||||
|
||||
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type);
|
||||
ir_func->start_target = CreateBlock();
|
||||
ir_func->end_target = CreateFunctionTerminator();
|
||||
|
||||
|
@ -58,8 +60,10 @@ Function* Builder::CreateFunction() {
|
|||
return ir_func;
|
||||
}
|
||||
|
||||
If* Builder::CreateIf() {
|
||||
auto* ir_if = ir.flow_nodes.Create<If>();
|
||||
If* Builder::CreateIf(Value* condition) {
|
||||
TINT_ASSERT(IR, condition);
|
||||
|
||||
auto* ir_if = ir.flow_nodes.Create<If>(condition);
|
||||
ir_if->true_.target = CreateBlock();
|
||||
ir_if->false_.target = CreateBlock();
|
||||
ir_if->merge.target = CreateBlock();
|
||||
|
@ -83,8 +87,8 @@ Loop* Builder::CreateLoop() {
|
|||
return ir_loop;
|
||||
}
|
||||
|
||||
Switch* Builder::CreateSwitch() {
|
||||
auto* ir_switch = ir.flow_nodes.Create<Switch>();
|
||||
Switch* Builder::CreateSwitch(Value* condition) {
|
||||
auto* ir_switch = ir.flow_nodes.Create<Switch>(condition);
|
||||
ir_switch->merge.target = CreateBlock();
|
||||
return ir_switch;
|
||||
}
|
||||
|
|
|
@ -65,20 +65,24 @@ class Builder {
|
|||
FunctionTerminator* CreateFunctionTerminator();
|
||||
|
||||
/// Creates a function flow node
|
||||
/// @param name the function name
|
||||
/// @param return_type the function return type
|
||||
/// @returns the flow node
|
||||
Function* CreateFunction();
|
||||
Function* CreateFunction(Symbol name, type::Type* return_type);
|
||||
|
||||
/// Creates an if flow node
|
||||
/// @param condition the if condition
|
||||
/// @returns the flow node
|
||||
If* CreateIf();
|
||||
If* CreateIf(Value* condition);
|
||||
|
||||
/// Creates a loop flow node
|
||||
/// @returns the flow node
|
||||
Loop* CreateLoop();
|
||||
|
||||
/// Creates a switch flow node
|
||||
/// @param condition the switch condition
|
||||
/// @returns the flow node
|
||||
Switch* CreateSwitch();
|
||||
Switch* CreateSwitch(Value* condition);
|
||||
|
||||
/// Creates a case flow node for the given case branch.
|
||||
/// @param s the switch to create the case into
|
||||
|
|
|
@ -205,14 +205,16 @@ void BuilderImpl::EmitFunction(const ast::Function* ast_func) {
|
|||
// The flow stack should have been emptied when the previous function finished building.
|
||||
TINT_ASSERT(IR, flow_stack.IsEmpty());
|
||||
|
||||
auto* ir_func = builder.CreateFunction();
|
||||
ir_func->name = CloneSymbol(ast_func->name->symbol);
|
||||
const auto* sem = program_->Sem().Get(ast_func);
|
||||
|
||||
auto* ir_func = builder.CreateFunction(CloneSymbol(ast_func->name->symbol),
|
||||
sem->ReturnType()->Clone(clone_ctx_.type_ctx));
|
||||
|
||||
current_function_ = ir_func;
|
||||
builder.ir.functions.Push(ir_func);
|
||||
|
||||
ast_to_flow_[ast_func] = ir_func;
|
||||
|
||||
const auto* sem = program_->Sem().Get(ast_func);
|
||||
if (ast_func->IsEntryPoint()) {
|
||||
builder.ir.entry_points.Push(ir_func);
|
||||
|
||||
|
@ -280,7 +282,6 @@ void BuilderImpl::EmitFunction(const ast::Function* ast_func) {
|
|||
});
|
||||
}
|
||||
}
|
||||
ir_func->return_type = sem->ReturnType()->Clone(clone_ctx_.type_ctx);
|
||||
ir_func->return_location = sem->ReturnLocation();
|
||||
|
||||
{
|
||||
|
@ -430,14 +431,12 @@ void BuilderImpl::EmitBlock(const ast::BlockStatement* block) {
|
|||
}
|
||||
|
||||
void BuilderImpl::EmitIf(const ast::IfStatement* stmt) {
|
||||
auto* if_node = builder.CreateIf();
|
||||
|
||||
// Emit the if condition into the end of the preceding block
|
||||
auto reg = EmitExpression(stmt->condition);
|
||||
if (!reg) {
|
||||
return;
|
||||
}
|
||||
if_node->condition = reg.Get();
|
||||
auto* if_node = builder.CreateIf(reg.Get());
|
||||
|
||||
BranchTo(if_node);
|
||||
|
||||
|
@ -525,13 +524,12 @@ void BuilderImpl::EmitWhile(const ast::WhileStatement* stmt) {
|
|||
}
|
||||
|
||||
// Create an `if (cond) {} else {break;}` control flow
|
||||
auto* if_node = builder.CreateIf();
|
||||
auto* if_node = builder.CreateIf(reg.Get());
|
||||
TINT_ASSERT(IR, if_node->true_.target->Is<Block>());
|
||||
builder.Branch(if_node->true_.target->As<Block>(), if_node->merge.target, utils::Empty);
|
||||
|
||||
TINT_ASSERT(IR, if_node->false_.target->Is<Block>());
|
||||
builder.Branch(if_node->false_.target->As<Block>(), loop_node->merge.target, utils::Empty);
|
||||
if_node->condition = reg.Get();
|
||||
|
||||
BranchTo(if_node);
|
||||
|
||||
|
@ -577,14 +575,14 @@ void BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
|
|||
}
|
||||
|
||||
// Create an `if (cond) {} else {break;}` control flow
|
||||
auto* if_node = builder.CreateIf();
|
||||
auto* if_node = builder.CreateIf(reg.Get());
|
||||
|
||||
TINT_ASSERT(IR, if_node->true_.target->Is<Block>());
|
||||
builder.Branch(if_node->true_.target->As<Block>(), if_node->merge.target, utils::Empty);
|
||||
|
||||
TINT_ASSERT(IR, if_node->false_.target->Is<Block>());
|
||||
builder.Branch(if_node->false_.target->As<Block>(), loop_node->merge.target,
|
||||
utils::Empty);
|
||||
if_node->condition = reg.Get();
|
||||
|
||||
BranchTo(if_node);
|
||||
current_flow_block = if_node->merge.target->As<Block>();
|
||||
|
@ -605,14 +603,12 @@ void BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
|
|||
}
|
||||
|
||||
void BuilderImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
|
||||
auto* switch_node = builder.CreateSwitch();
|
||||
|
||||
// Emit the condition into the preceding block
|
||||
auto reg = EmitExpression(stmt->condition);
|
||||
if (!reg) {
|
||||
return;
|
||||
}
|
||||
switch_node->condition = reg.Get();
|
||||
auto* switch_node = builder.CreateSwitch(reg.Get());
|
||||
|
||||
BranchTo(switch_node);
|
||||
|
||||
|
@ -692,14 +688,12 @@ void BuilderImpl::EmitDiscard(const ast::DiscardStatement*) {
|
|||
}
|
||||
|
||||
void BuilderImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
|
||||
auto* if_node = builder.CreateIf();
|
||||
|
||||
// Emit the break-if condition into the end of the preceding block
|
||||
auto reg = EmitExpression(stmt->condition);
|
||||
if (!reg) {
|
||||
return;
|
||||
}
|
||||
if_node->condition = reg.Get();
|
||||
auto* if_node = builder.CreateIf(reg.Get());
|
||||
|
||||
BranchTo(if_node);
|
||||
|
||||
|
@ -876,8 +870,7 @@ utils::Result<Value*> BuilderImpl::EmitShortCircuit(const ast::BinaryExpression*
|
|||
auto* lhs_store = builder.Store(result_var, lhs.Get());
|
||||
current_flow_block->instructions.Push(lhs_store);
|
||||
|
||||
auto* if_node = builder.CreateIf();
|
||||
if_node->condition = lhs.Get();
|
||||
auto* if_node = builder.CreateIf(lhs.Get());
|
||||
BranchTo(if_node);
|
||||
|
||||
utils::Result<Value*> rhs;
|
||||
|
|
|
@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Function);
|
|||
|
||||
namespace tint::ir {
|
||||
|
||||
Function::Function() : Base() {}
|
||||
Function::Function(Symbol n, type::Type* rt) : Base(), name(n), return_type(rt) {}
|
||||
|
||||
Function::~Function() = default;
|
||||
|
||||
|
|
|
@ -62,7 +62,9 @@ class Function : public utils::Castable<Function, FlowNode> {
|
|||
};
|
||||
|
||||
/// Constructor
|
||||
Function();
|
||||
/// @param n the function name
|
||||
/// @param rt the function return type
|
||||
Function(Symbol n, type::Type* rt);
|
||||
~Function() override;
|
||||
|
||||
/// The function name
|
||||
|
|
|
@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::If);
|
|||
|
||||
namespace tint::ir {
|
||||
|
||||
If::If() : Base() {}
|
||||
If::If(Value* cond) : Base(), condition(cond) {}
|
||||
|
||||
If::~If() = default;
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ namespace tint::ir {
|
|||
class If : public utils::Castable<If, FlowNode> {
|
||||
public:
|
||||
/// Constructor
|
||||
If();
|
||||
/// @param cond the if condition
|
||||
explicit If(Value* cond);
|
||||
~If() override;
|
||||
|
||||
/// The true branch block
|
||||
|
|
|
@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Switch);
|
|||
|
||||
namespace tint::ir {
|
||||
|
||||
Switch::Switch() : Base() {}
|
||||
Switch::Switch(Value* cond) : Base(), condition(cond) {}
|
||||
|
||||
Switch::~Switch() = default;
|
||||
|
||||
|
|
|
@ -44,7 +44,8 @@ class Switch : public utils::Castable<Switch, FlowNode> {
|
|||
};
|
||||
|
||||
/// Constructor
|
||||
Switch();
|
||||
/// @param cond the condition
|
||||
explicit Switch(Value* cond);
|
||||
~Switch() override;
|
||||
|
||||
/// The switch merge target
|
||||
|
|
|
@ -20,13 +20,11 @@ namespace tint::writer::spirv {
|
|||
namespace {
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Binary_Add_I32) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("foo");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
func->start_target->instructions.Push(b.CreateBinary(
|
||||
ir::Binary::Kind::kAdd, mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i)));
|
||||
func->start_target->instructions.Push(
|
||||
b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i)));
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
|
@ -44,13 +42,11 @@ OpFunctionEnd
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Binary_Add_U32) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("foo");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
func->start_target->instructions.Push(b.CreateBinary(
|
||||
ir::Binary::Kind::kAdd, mod.types.Get<type::U32>(), b.Constant(1_u), b.Constant(2_u)));
|
||||
func->start_target->instructions.Push(
|
||||
b.Add(mod.types.Get<type::U32>(), b.Constant(1_u), b.Constant(2_u)));
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
|
@ -68,13 +64,11 @@ OpFunctionEnd
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Binary_Add_F32) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("foo");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
func->start_target->instructions.Push(b.CreateBinary(
|
||||
ir::Binary::Kind::kAdd, mod.types.Get<type::F32>(), b.Constant(1_f), b.Constant(2_f)));
|
||||
func->start_target->instructions.Push(
|
||||
b.Add(mod.types.Get<type::F32>(), b.Constant(1_f), b.Constant(2_f)));
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
|
@ -92,16 +86,12 @@ OpFunctionEnd
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Binary_Add_Chain) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("foo");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
auto* a = b.CreateBinary(ir::Binary::Kind::kAdd, mod.types.Get<type::I32>(), b.Constant(1_i),
|
||||
b.Constant(2_i));
|
||||
auto* a = b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i));
|
||||
func->start_target->instructions.Push(a);
|
||||
func->start_target->instructions.Push(
|
||||
b.CreateBinary(ir::Binary::Kind::kAdd, mod.types.Get<type::I32>(), a, a));
|
||||
func->start_target->instructions.Push(b.Add(mod.types.Get<type::I32>(), a, a));
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
|
|
|
@ -18,9 +18,7 @@ namespace tint::writer::spirv {
|
|||
namespace {
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Function_Empty) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("foo");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
|
@ -36,8 +34,7 @@ OpFunctionEnd
|
|||
|
||||
// Test that we do not emit the same function type more than once.
|
||||
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
|
@ -49,9 +46,7 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("main");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
|
||||
func->pipeline_stage = ir::Function::PipelineStage::kCompute;
|
||||
func->workgroup_size = {32, 4, 1};
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
@ -70,9 +65,7 @@ OpFunctionEnd
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("main");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
|
||||
func->pipeline_stage = ir::Function::PipelineStage::kFragment;
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
|
@ -90,9 +83,7 @@ OpFunctionEnd
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
|
||||
auto* func = b.CreateFunction();
|
||||
func->name = mod.symbols.Register("main");
|
||||
func->return_type = mod.types.Get<type::Void>();
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
|
||||
func->pipeline_stage = ir::Function::PipelineStage::kVertex;
|
||||
func->start_target->branch.target = func->end_target;
|
||||
|
||||
|
@ -109,23 +100,17 @@ OpFunctionEnd
|
|||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
|
||||
auto* f1 = b.CreateFunction();
|
||||
f1->name = mod.symbols.Register("main1");
|
||||
f1->return_type = mod.types.Get<type::Void>();
|
||||
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>());
|
||||
f1->pipeline_stage = ir::Function::PipelineStage::kCompute;
|
||||
f1->workgroup_size = {32, 4, 1};
|
||||
f1->start_target->branch.target = f1->end_target;
|
||||
|
||||
auto* f2 = b.CreateFunction();
|
||||
f2->name = mod.symbols.Register("main2");
|
||||
f2->return_type = mod.types.Get<type::Void>();
|
||||
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>());
|
||||
f2->pipeline_stage = ir::Function::PipelineStage::kCompute;
|
||||
f2->workgroup_size = {8, 2, 16};
|
||||
f2->start_target->branch.target = f2->end_target;
|
||||
|
||||
auto* f3 = b.CreateFunction();
|
||||
f3->name = mod.symbols.Register("main3");
|
||||
f3->return_type = mod.types.Get<type::Void>();
|
||||
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>());
|
||||
f3->pipeline_stage = ir::Function::PipelineStage::kFragment;
|
||||
f3->start_target->branch.target = f3->end_target;
|
||||
|
||||
|
|
Loading…
Reference in New Issue