diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 92c239f071..baca05c0cc 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -2151,6 +2151,7 @@ bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) { if (val_id == 0) { return false; } + val_id = GenerateLoadIfNeeded(stmt->value()->result_type(), val_id); push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)}); } else { push_function_inst(spv::Op::OpReturn, {}); diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 700d8fe944..266b7b75fd 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -24,6 +24,8 @@ #include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -60,10 +62,14 @@ TEST_F(BuilderTest, Function_WithParams) { ast::type::I32Type i32; ast::VariableList params; - params.push_back( - std::make_unique("a", ast::StorageClass::kFunction, &f32)); - params.push_back( - std::make_unique("b", ast::StorageClass::kFunction, &i32)); + auto var_a = + std::make_unique("a", ast::StorageClass::kFunction, &f32); + var_a->set_is_const(true); + params.push_back(std::move(var_a)); + auto var_b = + std::make_unique("b", ast::StorageClass::kFunction, &i32); + var_b->set_is_const(true); + params.push_back(std::move(var_b)); ast::Function func("a_func", std::move(params), &f32); @@ -72,7 +78,13 @@ TEST_F(BuilderTest, Function_WithParams) { std::make_unique("a"))); func.set_body(std::move(body)); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(func.params()[0].get()); + td.RegisterVariableForTesting(func.params()[1].get()); + EXPECT_TRUE(td.DetermineFunction(&func)); + Builder b(&mod); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %4 "a_func" @@ -87,7 +99,7 @@ OpName %6 "b" %7 = OpLabel OpReturnValue %5 OpFunctionEnd -)"); +)") << DumpBuilder(b); } TEST_F(BuilderTest, Function_WithBody) { diff --git a/src/writer/spirv/builder_return_test.cc b/src/writer/spirv/builder_return_test.cc index 48933a4a1c..3099ad1215 100644 --- a/src/writer/spirv/builder_return_test.cc +++ b/src/writer/spirv/builder_return_test.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "src/ast/float_literal.h" +#include "src/ast/identifier_expression.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type/f32_type.h" @@ -84,6 +85,39 @@ TEST_F(BuilderTest, Return_WithValue) { )"); } +TEST_F(BuilderTest, Return_WithValue_GeneratesLoad) { + ast::type::F32Type f32; + + ast::Variable var("param", ast::StorageClass::kFunction, &f32); + + ast::ReturnStatement ret( + std::make_unique("param")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&var); + EXPECT_TRUE(td.DetermineResultType(&ret)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); + EXPECT_TRUE(b.GenerateReturnStatement(&ret)) << b.error(); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypePointer Function %3 +%4 = OpConstantNull %3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), + R"(%1 = OpVariable %2 Function %4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%5 = OpLoad %3 %1 +OpReturnValue %5 +)"); +} + } // namespace } // namespace spirv } // namespace writer