[ir] Set default flow node values.

For the `if`, `switch` and `function` flow nodes this Cl makes a few
required fields part of the constructor. This cuts down boilerplate when
creating new nodes.

Bug: tint:1718
Change-Id: I739bcefc2ed36b0203b57974b50bb2b79f6e1684
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132980
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair 2023-05-15 23:44:04 +00:00 committed by Dawn LUCI CQ
parent 809187c579
commit c9923d2ee3
11 changed files with 59 additions and 79 deletions

View File

@ -47,8 +47,10 @@ FunctionTerminator* Builder::CreateFunctionTerminator() {
return ir.flow_nodes.Create<FunctionTerminator>();
}
Function* Builder::CreateFunction() {
auto* ir_func = ir.flow_nodes.Create<Function>();
Function* Builder::CreateFunction(Symbol name, type::Type* return_type) {
TINT_ASSERT(IR, return_type);
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type);
ir_func->start_target = CreateBlock();
ir_func->end_target = CreateFunctionTerminator();
@ -58,8 +60,10 @@ Function* Builder::CreateFunction() {
return ir_func;
}
If* Builder::CreateIf() {
auto* ir_if = ir.flow_nodes.Create<If>();
If* Builder::CreateIf(Value* condition) {
TINT_ASSERT(IR, condition);
auto* ir_if = ir.flow_nodes.Create<If>(condition);
ir_if->true_.target = CreateBlock();
ir_if->false_.target = CreateBlock();
ir_if->merge.target = CreateBlock();
@ -83,8 +87,8 @@ Loop* Builder::CreateLoop() {
return ir_loop;
}
Switch* Builder::CreateSwitch() {
auto* ir_switch = ir.flow_nodes.Create<Switch>();
Switch* Builder::CreateSwitch(Value* condition) {
auto* ir_switch = ir.flow_nodes.Create<Switch>(condition);
ir_switch->merge.target = CreateBlock();
return ir_switch;
}

View File

@ -65,20 +65,24 @@ class Builder {
FunctionTerminator* CreateFunctionTerminator();
/// Creates a function flow node
/// @param name the function name
/// @param return_type the function return type
/// @returns the flow node
Function* CreateFunction();
Function* CreateFunction(Symbol name, type::Type* return_type);
/// Creates an if flow node
/// @param condition the if condition
/// @returns the flow node
If* CreateIf();
If* CreateIf(Value* condition);
/// Creates a loop flow node
/// @returns the flow node
Loop* CreateLoop();
/// Creates a switch flow node
/// @param condition the switch condition
/// @returns the flow node
Switch* CreateSwitch();
Switch* CreateSwitch(Value* condition);
/// Creates a case flow node for the given case branch.
/// @param s the switch to create the case into

View File

@ -205,14 +205,16 @@ void BuilderImpl::EmitFunction(const ast::Function* ast_func) {
// The flow stack should have been emptied when the previous function finished building.
TINT_ASSERT(IR, flow_stack.IsEmpty());
auto* ir_func = builder.CreateFunction();
ir_func->name = CloneSymbol(ast_func->name->symbol);
const auto* sem = program_->Sem().Get(ast_func);
auto* ir_func = builder.CreateFunction(CloneSymbol(ast_func->name->symbol),
sem->ReturnType()->Clone(clone_ctx_.type_ctx));
current_function_ = ir_func;
builder.ir.functions.Push(ir_func);
ast_to_flow_[ast_func] = ir_func;
const auto* sem = program_->Sem().Get(ast_func);
if (ast_func->IsEntryPoint()) {
builder.ir.entry_points.Push(ir_func);
@ -280,7 +282,6 @@ void BuilderImpl::EmitFunction(const ast::Function* ast_func) {
});
}
}
ir_func->return_type = sem->ReturnType()->Clone(clone_ctx_.type_ctx);
ir_func->return_location = sem->ReturnLocation();
{
@ -430,14 +431,12 @@ void BuilderImpl::EmitBlock(const ast::BlockStatement* block) {
}
void BuilderImpl::EmitIf(const ast::IfStatement* stmt) {
auto* if_node = builder.CreateIf();
// Emit the if condition into the end of the preceding block
auto reg = EmitExpression(stmt->condition);
if (!reg) {
return;
}
if_node->condition = reg.Get();
auto* if_node = builder.CreateIf(reg.Get());
BranchTo(if_node);
@ -525,13 +524,12 @@ void BuilderImpl::EmitWhile(const ast::WhileStatement* stmt) {
}
// Create an `if (cond) {} else {break;}` control flow
auto* if_node = builder.CreateIf();
auto* if_node = builder.CreateIf(reg.Get());
TINT_ASSERT(IR, if_node->true_.target->Is<Block>());
builder.Branch(if_node->true_.target->As<Block>(), if_node->merge.target, utils::Empty);
TINT_ASSERT(IR, if_node->false_.target->Is<Block>());
builder.Branch(if_node->false_.target->As<Block>(), loop_node->merge.target, utils::Empty);
if_node->condition = reg.Get();
BranchTo(if_node);
@ -577,14 +575,14 @@ void BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
}
// Create an `if (cond) {} else {break;}` control flow
auto* if_node = builder.CreateIf();
auto* if_node = builder.CreateIf(reg.Get());
TINT_ASSERT(IR, if_node->true_.target->Is<Block>());
builder.Branch(if_node->true_.target->As<Block>(), if_node->merge.target, utils::Empty);
TINT_ASSERT(IR, if_node->false_.target->Is<Block>());
builder.Branch(if_node->false_.target->As<Block>(), loop_node->merge.target,
utils::Empty);
if_node->condition = reg.Get();
BranchTo(if_node);
current_flow_block = if_node->merge.target->As<Block>();
@ -605,14 +603,12 @@ void BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
}
void BuilderImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
auto* switch_node = builder.CreateSwitch();
// Emit the condition into the preceding block
auto reg = EmitExpression(stmt->condition);
if (!reg) {
return;
}
switch_node->condition = reg.Get();
auto* switch_node = builder.CreateSwitch(reg.Get());
BranchTo(switch_node);
@ -692,14 +688,12 @@ void BuilderImpl::EmitDiscard(const ast::DiscardStatement*) {
}
void BuilderImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
auto* if_node = builder.CreateIf();
// Emit the break-if condition into the end of the preceding block
auto reg = EmitExpression(stmt->condition);
if (!reg) {
return;
}
if_node->condition = reg.Get();
auto* if_node = builder.CreateIf(reg.Get());
BranchTo(if_node);
@ -876,8 +870,7 @@ utils::Result<Value*> BuilderImpl::EmitShortCircuit(const ast::BinaryExpression*
auto* lhs_store = builder.Store(result_var, lhs.Get());
current_flow_block->instructions.Push(lhs_store);
auto* if_node = builder.CreateIf();
if_node->condition = lhs.Get();
auto* if_node = builder.CreateIf(lhs.Get());
BranchTo(if_node);
utils::Result<Value*> rhs;

View File

@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Function);
namespace tint::ir {
Function::Function() : Base() {}
Function::Function(Symbol n, type::Type* rt) : Base(), name(n), return_type(rt) {}
Function::~Function() = default;

View File

@ -62,7 +62,9 @@ class Function : public utils::Castable<Function, FlowNode> {
};
/// Constructor
Function();
/// @param n the function name
/// @param rt the function return type
Function(Symbol n, type::Type* rt);
~Function() override;
/// The function name

View File

@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::If);
namespace tint::ir {
If::If() : Base() {}
If::If(Value* cond) : Base(), condition(cond) {}
If::~If() = default;

View File

@ -30,7 +30,8 @@ namespace tint::ir {
class If : public utils::Castable<If, FlowNode> {
public:
/// Constructor
If();
/// @param cond the if condition
explicit If(Value* cond);
~If() override;
/// The true branch block

View File

@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Switch);
namespace tint::ir {
Switch::Switch() : Base() {}
Switch::Switch(Value* cond) : Base(), condition(cond) {}
Switch::~Switch() = default;

View File

@ -44,7 +44,8 @@ class Switch : public utils::Castable<Switch, FlowNode> {
};
/// Constructor
Switch();
/// @param cond the condition
explicit Switch(Value* cond);
~Switch() override;
/// The switch merge target

View File

@ -20,13 +20,11 @@ namespace tint::writer::spirv {
namespace {
TEST_F(SpvGeneratorImplTest, Binary_Add_I32) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("foo");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
func->start_target->branch.target = func->end_target;
func->start_target->instructions.Push(b.CreateBinary(
ir::Binary::Kind::kAdd, mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i)));
func->start_target->instructions.Push(
b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i)));
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
@ -44,13 +42,11 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Binary_Add_U32) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("foo");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
func->start_target->branch.target = func->end_target;
func->start_target->instructions.Push(b.CreateBinary(
ir::Binary::Kind::kAdd, mod.types.Get<type::U32>(), b.Constant(1_u), b.Constant(2_u)));
func->start_target->instructions.Push(
b.Add(mod.types.Get<type::U32>(), b.Constant(1_u), b.Constant(2_u)));
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
@ -68,13 +64,11 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Binary_Add_F32) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("foo");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
func->start_target->branch.target = func->end_target;
func->start_target->instructions.Push(b.CreateBinary(
ir::Binary::Kind::kAdd, mod.types.Get<type::F32>(), b.Constant(1_f), b.Constant(2_f)));
func->start_target->instructions.Push(
b.Add(mod.types.Get<type::F32>(), b.Constant(1_f), b.Constant(2_f)));
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
@ -92,16 +86,12 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Binary_Add_Chain) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("foo");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
func->start_target->branch.target = func->end_target;
auto* a = b.CreateBinary(ir::Binary::Kind::kAdd, mod.types.Get<type::I32>(), b.Constant(1_i),
b.Constant(2_i));
auto* a = b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(2_i));
func->start_target->instructions.Push(a);
func->start_target->instructions.Push(
b.CreateBinary(ir::Binary::Kind::kAdd, mod.types.Get<type::I32>(), a, a));
func->start_target->instructions.Push(b.Add(mod.types.Get<type::I32>(), a, a));
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"

View File

@ -18,9 +18,7 @@ namespace tint::writer::spirv {
namespace {
TEST_F(SpvGeneratorImplTest, Function_Empty) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("foo");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
func->start_target->branch.target = func->end_target;
generator_.EmitFunction(func);
@ -36,8 +34,7 @@ OpFunctionEnd
// Test that we do not emit the same function type more than once.
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
auto* func = b.CreateFunction();
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
func->start_target->branch.target = func->end_target;
generator_.EmitFunction(func);
@ -49,9 +46,7 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("main");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
func->pipeline_stage = ir::Function::PipelineStage::kCompute;
func->workgroup_size = {32, 4, 1};
func->start_target->branch.target = func->end_target;
@ -70,9 +65,7 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("main");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
func->pipeline_stage = ir::Function::PipelineStage::kFragment;
func->start_target->branch.target = func->end_target;
@ -90,9 +83,7 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
auto* func = b.CreateFunction();
func->name = mod.symbols.Register("main");
func->return_type = mod.types.Get<type::Void>();
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
func->pipeline_stage = ir::Function::PipelineStage::kVertex;
func->start_target->branch.target = func->end_target;
@ -109,23 +100,17 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
auto* f1 = b.CreateFunction();
f1->name = mod.symbols.Register("main1");
f1->return_type = mod.types.Get<type::Void>();
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>());
f1->pipeline_stage = ir::Function::PipelineStage::kCompute;
f1->workgroup_size = {32, 4, 1};
f1->start_target->branch.target = f1->end_target;
auto* f2 = b.CreateFunction();
f2->name = mod.symbols.Register("main2");
f2->return_type = mod.types.Get<type::Void>();
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>());
f2->pipeline_stage = ir::Function::PipelineStage::kCompute;
f2->workgroup_size = {8, 2, 16};
f2->start_target->branch.target = f2->end_target;
auto* f3 = b.CreateFunction();
f3->name = mod.symbols.Register("main3");
f3->return_type = mod.types.Get<type::Void>();
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>());
f3->pipeline_stage = ir::Function::PipelineStage::kFragment;
f3->start_target->branch.target = f3->end_target;