diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 3f2d289484..de986e7bdd 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -104,8 +104,6 @@ bool LastIsTerminator(const ast::StatementList& stmts) { } auto* last = stmts.back().get(); - // TODO(dneto): Conditional break and conditional continue should return - // false. return last->IsBreak() || last->IsContinue() || last->IsReturn() || last->IsKill() || last->IsFallthrough(); } @@ -391,16 +389,32 @@ bool Builder::GenerateFunction(ast::Function* func) { return false; } - // TODO(dsinclair): Handle parameters + scope_stack_.push_scope(); auto definition_inst = Instruction{ spv::Op::OpFunction, {Operand::Int(ret_id), func_op, Operand::Int(SpvFunctionControlMaskNone), Operand::Int(func_type_id)}}; - std::vector params; - push_function(Function{definition_inst, result_op(), std::move(params)}); - scope_stack_.push_scope(); + std::vector params; + for (const auto& param : func->params()) { + auto param_op = result_op(); + auto param_id = param_op.to_i(); + + auto param_type_id = GenerateTypeIfNeeded(param->type()); + if (param_type_id == 0) { + return false; + } + + push_debug(spv::Op::OpName, + {Operand::Int(param_id), Operand::String(param->name())}); + params.push_back(Instruction{spv::Op::OpFunctionParameter, + {Operand::Int(param_type_id), param_op}}); + + scope_stack_.set(param->name(), param_id); + } + + push_function(Function{definition_inst, result_op(), std::move(params)}); for (const auto& stmt : func->body()) { if (!GenerateStatement(stmt.get())) { @@ -428,8 +442,16 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(ast::Function* func) { return 0; } - // TODO(dsinclair): Handle parameters - push_type(spv::Op::OpTypeFunction, {func_op, Operand::Int(ret_id)}); + std::vector ops = {func_op, Operand::Int(ret_id)}; + for (const auto& param : func->params()) { + auto param_type_id = GenerateTypeIfNeeded(param->type()); + if (param_type_id == 0) { + return 0; + } + ops.push_back(Operand::Int(param_type_id)); + } + + push_type(spv::Op::OpTypeFunction, std::move(ops)); type_name_to_id_[func->type_name()] = func_type_id; return func_type_id; diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 4ee186347b..a0fb76f3a7 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -18,8 +18,12 @@ #include "spirv/unified1/spirv.h" #include "spirv/unified1/spirv.hpp11" #include "src/ast/function.h" +#include "src/ast/identifier_expression.h" #include "src/ast/return_statement.h" +#include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" +#include "src/ast/variable.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -50,7 +54,41 @@ TEST_F(BuilderTest, Function_Empty) { )"); } -TEST_F(BuilderTest, DISABLED_Function_WithParams) {} +TEST_F(BuilderTest, Function_WithParams) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::VariableList params; + params.push_back( + std::make_unique("a", ast::StorageClass::kFunction, &f32)); + params.push_back( + std::make_unique("b", ast::StorageClass::kFunction, &i32)); + + ast::Function func("a_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("a"))); + func.set_body(std::move(body)); + + ast::Module mod; + Builder b(&mod); + ASSERT_TRUE(b.GenerateFunction(&func)); + EXPECT_EQ(DumpBuilder(b), R"(OpName %4 "a_func" +OpName %5 "a" +OpName %6 "b" +%2 = OpTypeFloat 32 +%3 = OpTypeInt 32 1 +%1 = OpTypeFunction %2 %2 %3 +%4 = OpFunction %2 None %1 +%5 = OpFunctionParameter %2 +%6 = OpFunctionParameter %3 +%7 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"); +} TEST_F(BuilderTest, Function_WithBody) { ast::type::VoidType void_type;