diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc index d8d456abf4..e25393730a 100644 --- a/src/tint/ir/builder.cc +++ b/src/tint/ir/builder.cc @@ -47,8 +47,10 @@ FunctionTerminator* Builder::CreateFunctionTerminator() { return ir.flow_nodes.Create(); } -Function* Builder::CreateFunction() { - auto* ir_func = ir.flow_nodes.Create(); +Function* Builder::CreateFunction(Symbol name, type::Type* return_type) { + TINT_ASSERT(IR, return_type); + + auto* ir_func = ir.flow_nodes.Create(name, return_type); ir_func->start_target = CreateBlock(); ir_func->end_target = CreateFunctionTerminator(); @@ -58,8 +60,10 @@ Function* Builder::CreateFunction() { return ir_func; } -If* Builder::CreateIf() { - auto* ir_if = ir.flow_nodes.Create(); +If* Builder::CreateIf(Value* condition) { + TINT_ASSERT(IR, condition); + + auto* ir_if = ir.flow_nodes.Create(condition); ir_if->true_.target = CreateBlock(); ir_if->false_.target = CreateBlock(); ir_if->merge.target = CreateBlock(); @@ -83,8 +87,8 @@ Loop* Builder::CreateLoop() { return ir_loop; } -Switch* Builder::CreateSwitch() { - auto* ir_switch = ir.flow_nodes.Create(); +Switch* Builder::CreateSwitch(Value* condition) { + auto* ir_switch = ir.flow_nodes.Create(condition); ir_switch->merge.target = CreateBlock(); return ir_switch; } diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h index 6dd555d43c..6f415c1ec4 100644 --- a/src/tint/ir/builder.h +++ b/src/tint/ir/builder.h @@ -65,20 +65,24 @@ class Builder { FunctionTerminator* CreateFunctionTerminator(); /// Creates a function flow node + /// @param name the function name + /// @param return_type the function return type /// @returns the flow node - Function* CreateFunction(); + Function* CreateFunction(Symbol name, type::Type* return_type); /// Creates an if flow node + /// @param condition the if condition /// @returns the flow node - If* CreateIf(); + If* CreateIf(Value* condition); /// Creates a loop flow node /// @returns the flow node Loop* CreateLoop(); /// Creates a switch flow node + /// @param condition the switch condition /// @returns the flow node - Switch* CreateSwitch(); + Switch* CreateSwitch(Value* condition); /// Creates a case flow node for the given case branch. /// @param s the switch to create the case into diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc index 46b4cb3046..57cd02c465 100644 --- a/src/tint/ir/builder_impl.cc +++ b/src/tint/ir/builder_impl.cc @@ -205,14 +205,16 @@ void BuilderImpl::EmitFunction(const ast::Function* ast_func) { // The flow stack should have been emptied when the previous function finished building. TINT_ASSERT(IR, flow_stack.IsEmpty()); - auto* ir_func = builder.CreateFunction(); - ir_func->name = CloneSymbol(ast_func->name->symbol); + const auto* sem = program_->Sem().Get(ast_func); + + auto* ir_func = builder.CreateFunction(CloneSymbol(ast_func->name->symbol), + sem->ReturnType()->Clone(clone_ctx_.type_ctx)); + current_function_ = ir_func; builder.ir.functions.Push(ir_func); ast_to_flow_[ast_func] = ir_func; - const auto* sem = program_->Sem().Get(ast_func); if (ast_func->IsEntryPoint()) { builder.ir.entry_points.Push(ir_func); @@ -280,7 +282,6 @@ void BuilderImpl::EmitFunction(const ast::Function* ast_func) { }); } } - ir_func->return_type = sem->ReturnType()->Clone(clone_ctx_.type_ctx); ir_func->return_location = sem->ReturnLocation(); { @@ -430,14 +431,12 @@ void BuilderImpl::EmitBlock(const ast::BlockStatement* block) { } void BuilderImpl::EmitIf(const ast::IfStatement* stmt) { - auto* if_node = builder.CreateIf(); - // Emit the if condition into the end of the preceding block auto reg = EmitExpression(stmt->condition); if (!reg) { return; } - if_node->condition = reg.Get(); + auto* if_node = builder.CreateIf(reg.Get()); BranchTo(if_node); @@ -525,13 +524,12 @@ void BuilderImpl::EmitWhile(const ast::WhileStatement* stmt) { } // Create an `if (cond) {} else {break;}` control flow - auto* if_node = builder.CreateIf(); + auto* if_node = builder.CreateIf(reg.Get()); TINT_ASSERT(IR, if_node->true_.target->Is()); builder.Branch(if_node->true_.target->As(), if_node->merge.target, utils::Empty); TINT_ASSERT(IR, if_node->false_.target->Is()); builder.Branch(if_node->false_.target->As(), loop_node->merge.target, utils::Empty); - if_node->condition = reg.Get(); BranchTo(if_node); @@ -577,14 +575,14 @@ void BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) { } // Create an `if (cond) {} else {break;}` control flow - auto* if_node = builder.CreateIf(); + auto* if_node = builder.CreateIf(reg.Get()); + TINT_ASSERT(IR, if_node->true_.target->Is()); builder.Branch(if_node->true_.target->As(), if_node->merge.target, utils::Empty); TINT_ASSERT(IR, if_node->false_.target->Is()); builder.Branch(if_node->false_.target->As(), loop_node->merge.target, utils::Empty); - if_node->condition = reg.Get(); BranchTo(if_node); current_flow_block = if_node->merge.target->As(); @@ -605,14 +603,12 @@ void BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) { } void BuilderImpl::EmitSwitch(const ast::SwitchStatement* stmt) { - auto* switch_node = builder.CreateSwitch(); - // Emit the condition into the preceding block auto reg = EmitExpression(stmt->condition); if (!reg) { return; } - switch_node->condition = reg.Get(); + auto* switch_node = builder.CreateSwitch(reg.Get()); BranchTo(switch_node); @@ -692,14 +688,12 @@ void BuilderImpl::EmitDiscard(const ast::DiscardStatement*) { } void BuilderImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) { - auto* if_node = builder.CreateIf(); - // Emit the break-if condition into the end of the preceding block auto reg = EmitExpression(stmt->condition); if (!reg) { return; } - if_node->condition = reg.Get(); + auto* if_node = builder.CreateIf(reg.Get()); BranchTo(if_node); @@ -876,8 +870,7 @@ utils::Result BuilderImpl::EmitShortCircuit(const ast::BinaryExpression* auto* lhs_store = builder.Store(result_var, lhs.Get()); current_flow_block->instructions.Push(lhs_store); - auto* if_node = builder.CreateIf(); - if_node->condition = lhs.Get(); + auto* if_node = builder.CreateIf(lhs.Get()); BranchTo(if_node); utils::Result rhs; diff --git a/src/tint/ir/function.cc b/src/tint/ir/function.cc index a03812e596..59a04f7714 100644 --- a/src/tint/ir/function.cc +++ b/src/tint/ir/function.cc @@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Function); namespace tint::ir { -Function::Function() : Base() {} +Function::Function(Symbol n, type::Type* rt) : Base(), name(n), return_type(rt) {} Function::~Function() = default; diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h index e3d5529305..06aa29ec4f 100644 --- a/src/tint/ir/function.h +++ b/src/tint/ir/function.h @@ -62,7 +62,9 @@ class Function : public utils::Castable { }; /// Constructor - Function(); + /// @param n the function name + /// @param rt the function return type + Function(Symbol n, type::Type* rt); ~Function() override; /// The function name diff --git a/src/tint/ir/if.cc b/src/tint/ir/if.cc index 22a90787bb..b59d87f56d 100644 --- a/src/tint/ir/if.cc +++ b/src/tint/ir/if.cc @@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::If); namespace tint::ir { -If::If() : Base() {} +If::If(Value* cond) : Base(), condition(cond) {} If::~If() = default; diff --git a/src/tint/ir/if.h b/src/tint/ir/if.h index 255b165851..aadc5c9912 100644 --- a/src/tint/ir/if.h +++ b/src/tint/ir/if.h @@ -30,7 +30,8 @@ namespace tint::ir { class If : public utils::Castable { public: /// Constructor - If(); + /// @param cond the if condition + explicit If(Value* cond); ~If() override; /// The true branch block diff --git a/src/tint/ir/switch.cc b/src/tint/ir/switch.cc index 9ad6d3030f..ad6a1453f7 100644 --- a/src/tint/ir/switch.cc +++ b/src/tint/ir/switch.cc @@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Switch); namespace tint::ir { -Switch::Switch() : Base() {} +Switch::Switch(Value* cond) : Base(), condition(cond) {} Switch::~Switch() = default; diff --git a/src/tint/ir/switch.h b/src/tint/ir/switch.h index 2ff927ef25..6be7b62efc 100644 --- a/src/tint/ir/switch.h +++ b/src/tint/ir/switch.h @@ -44,7 +44,8 @@ class Switch : public utils::Castable { }; /// Constructor - Switch(); + /// @param cond the condition + explicit Switch(Value* cond); ~Switch() override; /// The switch merge target diff --git a/src/tint/writer/spirv/generator_impl_binary_test.cc b/src/tint/writer/spirv/generator_impl_binary_test.cc index 3b2f5eb099..578109fdf7 100644 --- a/src/tint/writer/spirv/generator_impl_binary_test.cc +++ b/src/tint/writer/spirv/generator_impl_binary_test.cc @@ -20,13 +20,11 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, Binary_Add_I32) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("foo"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); func->start_target->branch.target = func->end_target; - func->start_target->instructions.Push(b.CreateBinary( - ir::Binary::Kind::kAdd, mod.types.Get(), b.Constant(1_i), b.Constant(2_i))); + func->start_target->instructions.Push( + b.Add(mod.types.Get(), b.Constant(1_i), b.Constant(2_i))); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" @@ -44,13 +42,11 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Add_U32) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("foo"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); func->start_target->branch.target = func->end_target; - func->start_target->instructions.Push(b.CreateBinary( - ir::Binary::Kind::kAdd, mod.types.Get(), b.Constant(1_u), b.Constant(2_u))); + func->start_target->instructions.Push( + b.Add(mod.types.Get(), b.Constant(1_u), b.Constant(2_u))); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" @@ -68,13 +64,11 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Add_F32) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("foo"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); func->start_target->branch.target = func->end_target; - func->start_target->instructions.Push(b.CreateBinary( - ir::Binary::Kind::kAdd, mod.types.Get(), b.Constant(1_f), b.Constant(2_f))); + func->start_target->instructions.Push( + b.Add(mod.types.Get(), b.Constant(1_f), b.Constant(2_f))); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" @@ -92,16 +86,12 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Add_Chain) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("foo"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); func->start_target->branch.target = func->end_target; - auto* a = b.CreateBinary(ir::Binary::Kind::kAdd, mod.types.Get(), b.Constant(1_i), - b.Constant(2_i)); + auto* a = b.Add(mod.types.Get(), b.Constant(1_i), b.Constant(2_i)); func->start_target->instructions.Push(a); - func->start_target->instructions.Push( - b.CreateBinary(ir::Binary::Kind::kAdd, mod.types.Get(), a, a)); + func->start_target->instructions.Push(b.Add(mod.types.Get(), a, a)); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" diff --git a/src/tint/writer/spirv/generator_impl_function_test.cc b/src/tint/writer/spirv/generator_impl_function_test.cc index fd08165dcb..283190ded1 100644 --- a/src/tint/writer/spirv/generator_impl_function_test.cc +++ b/src/tint/writer/spirv/generator_impl_function_test.cc @@ -18,9 +18,7 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, Function_Empty) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("foo"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); func->start_target->branch.target = func->end_target; generator_.EmitFunction(func); @@ -36,8 +34,7 @@ OpFunctionEnd // Test that we do not emit the same function type more than once. TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) { - auto* func = b.CreateFunction(); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); func->start_target->branch.target = func->end_target; generator_.EmitFunction(func); @@ -49,9 +46,7 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) { } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("main"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get()); func->pipeline_stage = ir::Function::PipelineStage::kCompute; func->workgroup_size = {32, 4, 1}; func->start_target->branch.target = func->end_target; @@ -70,9 +65,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("main"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get()); func->pipeline_stage = ir::Function::PipelineStage::kFragment; func->start_target->branch.target = func->end_target; @@ -90,9 +83,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) { - auto* func = b.CreateFunction(); - func->name = mod.symbols.Register("main"); - func->return_type = mod.types.Get(); + auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get()); func->pipeline_stage = ir::Function::PipelineStage::kVertex; func->start_target->branch.target = func->end_target; @@ -109,23 +100,17 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) { - auto* f1 = b.CreateFunction(); - f1->name = mod.symbols.Register("main1"); - f1->return_type = mod.types.Get(); + auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get()); f1->pipeline_stage = ir::Function::PipelineStage::kCompute; f1->workgroup_size = {32, 4, 1}; f1->start_target->branch.target = f1->end_target; - auto* f2 = b.CreateFunction(); - f2->name = mod.symbols.Register("main2"); - f2->return_type = mod.types.Get(); + auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get()); f2->pipeline_stage = ir::Function::PipelineStage::kCompute; f2->workgroup_size = {8, 2, 16}; f2->start_target->branch.target = f2->end_target; - auto* f3 = b.CreateFunction(); - f3->name = mod.symbols.Register("main3"); - f3->return_type = mod.types.Get(); + auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get()); f3->pipeline_stage = ir::Function::PipelineStage::kFragment; f3->start_target->branch.target = f3->end_target;