[spirv-writer] Handle non-struct entry point return values

Generate a global variable for the return value and replace return
statements with assignments to this variable.

Add a list of return statements to semantic::Function.

Bug: tint:509
Change-Id: I6bc08fcac7858b48f0eff62199d5011665284220
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44804
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-03-17 14:24:04 +00:00 committed by Commit Bot service account
parent 417b82291b
commit 4ffd3e2ea5
9 changed files with 237 additions and 21 deletions

View File

@ -381,12 +381,12 @@ class CloneContext {
using CloneableList = std::vector<Cloneable*>;
/// A map of object in #src to their cloned equivalent in #dst
std::unordered_map<Cloneable*, Cloneable*> cloned_;
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
/// A map of object in #src to the list of cloned objects in #dst.
/// Clone(const std::vector<T*>& v) will use this to insert the map-value list
/// into the target vector/ before cloning and inserting the map-key.
std::unordered_map<Cloneable*, CloneableList> insert_before_;
std::unordered_map<const Cloneable*, CloneableList> insert_before_;
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;

View File

@ -321,6 +321,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
});
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
current_function_->return_statements.push_back(r);
return Expression(r->value());
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
@ -1215,7 +1216,7 @@ void Resolver::CreateSemanticNodes() const {
auto* sem_func = builder_->create<semantic::Function>(
info->declaration, remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars),
remap_vars(info->local_referenced_module_vars), info->return_statements,
ancestor_entry_points[func->symbol()]);
func_info_to_sem_func.emplace(info, sem_func);
sem.Add(func, sem_func);

View File

@ -38,6 +38,7 @@ class ConstructorExpression;
class Function;
class IdentifierExpression;
class MemberAccessorExpression;
class ReturnStatement;
class UnaryOpExpression;
class Variable;
} // namespace ast
@ -92,6 +93,7 @@ class Resolver {
ast::Function* const declaration;
UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars;
std::vector<const ast::ReturnStatement*> return_statements;
// List of transitive calls this function makes
UniqueVector<FunctionInfo*> transitive_calls;

View File

@ -858,6 +858,30 @@ TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) {
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
}
TEST_F(ResolverTest, Function_ReturnStatements) {
auto* var = Var("foo", ty.f32(), ast::StorageClass::kFunction);
auto* ret_1 = create<ast::ReturnStatement>(Expr(1.f));
auto* ret_foo = create<ast::ReturnStatement>(Expr("foo"));
auto* func = Func("my_func", ast::VariableList{}, ty.f32(),
ast::StatementList{
create<ast::VariableDeclStatement>(var),
If(Expr(true), Block(ret_1)),
ret_foo,
},
ast::DecorationList{});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
}
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* strct = create<ast::Struct>(
ast::StructMemberList{Member("first_member", ty.i32()),

View File

@ -29,6 +29,7 @@ class BuiltinDecoration;
class Function;
class GroupDecoration;
class LocationDecoration;
class ReturnStatement;
} // namespace ast
namespace semantic {
@ -53,11 +54,13 @@ class Function : public Castable<Function, CallTarget> {
/// @param declaration the ast::Function
/// @param referenced_module_vars the referenced module variables
/// @param local_referenced_module_vars the locally referenced module
/// @param return_statements the function return statements
/// variables
/// @param ancestor_entry_points the ancestor entry points
Function(ast::Function* declaration,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points);
/// Destructor
@ -76,6 +79,10 @@ class Function : public Castable<Function, CallTarget> {
const std::vector<const Variable*>& LocalReferencedModuleVariables() const {
return local_referenced_module_vars_;
}
/// @returns the return statements
const std::vector<const ast::ReturnStatement*> ReturnStatements() const {
return return_statements_;
}
/// @returns the ancestor entry points
const std::vector<Symbol>& AncestorEntryPoints() const {
return ancestor_entry_points_;
@ -148,6 +155,7 @@ class Function : public Castable<Function, CallTarget> {
ast::Function* const declaration_;
std::vector<const Variable*> const referenced_module_vars_;
std::vector<const Variable*> const local_referenced_module_vars_;
std::vector<const ast::ReturnStatement*> const return_statements_;
std::vector<Symbol> const ancestor_entry_points_;
};

View File

@ -57,11 +57,13 @@ std::tuple<ast::BindingDecoration*, ast::GroupDecoration*> GetBindingAndGroup(
Function::Function(ast::Function* declaration,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points)
: Base(declaration->return_type(), GetParameters(declaration)),
declaration_(declaration),
referenced_module_vars_(std::move(referenced_module_vars)),
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
return_statements_(std::move(return_statements)),
ancestor_entry_points_(std::move(ancestor_entry_points)) {}
Function::~Function() = default;

View File

@ -17,7 +17,9 @@
#include <string>
#include <utility>
#include "src/ast/return_statement.h"
#include "src/program_builder.h"
#include "src/semantic/function.h"
#include "src/semantic/variable.h"
namespace tint {
@ -102,6 +104,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
continue;
}
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* param : func->params()) {
// TODO(jrprice): Handle structures by moving the declaration and
// construction to the function body.
@ -126,21 +130,37 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
}
}
// TODO(jrprice): Hoist the return type out to a global variable, and
// replace return statements with variable assignments.
if (!func->return_type()->Is<type::Void>()) {
TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
<< "entry point return values are not yet supported";
continue;
// TODO(jrprice): Handle structures by creating a variable for each member
// and replacing return statements with extracts+stores.
if (func->return_type()->UnwrapAll()->Is<type::Struct>()) {
TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
<< "structures as entry point return values are not yet supported";
continue;
}
// Create a new symbol for the global variable.
auto var_symbol = ctx.dst->Symbols().New();
// Create the global variable.
auto* var = ctx.dst->Var(var_symbol, ctx.Clone(func->return_type()),
ast::StorageClass::kOutput, nullptr,
ctx.Clone(func->return_type_decorations()));
ctx.InsertBefore(func, var);
// Replace all return statements with stores to the global variable.
for (auto* ret : sem_func->ReturnStatements()) {
ctx.InsertBefore(
ret, ctx.dst->create<ast::AssignmentStatement>(
ctx.dst->Expr(var_symbol), ctx.Clone(ret->value())));
ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
}
}
// Rewrite the function header to remove the parameters.
// TODO(jrprice): Change return type to void when return values are handled.
// Rewrite the function header to remove the parameters and return value.
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.dst->ty.void_(), ctx.Clone(func->body()),
ctx.Clone(func->decorations()), ast::DecorationList{});
ctx.Replace(func, new_func);
}
}

View File

@ -86,6 +86,97 @@ fn frag_main() -> void {
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnBuiltin) {
auto* src = R"(
[[stage(vertex)]]
fn vert_main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(1.0, 2.0, 3.0, 0.0);
}
)";
auto* expect = R"(
[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
[[stage(vertex)]]
fn vert_main() -> void {
tint_symbol_1 = vec4<f32>(1.0, 2.0, 3.0, 0.0);
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation) {
auto* src = R"(
[[stage(fragment)]]
fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
if (loc_in > 10u) {
return 0.5;
}
return 1.0;
}
)";
auto* expect = R"(
[[location(0)]] var<in> tint_symbol_1 : u32;
[[location(0)]] var<out> tint_symbol_2 : f32;
[[stage(fragment)]]
fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) {
tint_symbol_2 = 0.5;
return;
}
tint_symbol_2 = 1.0;
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation_TypeAlias) {
auto* src = R"(
type myf32 = f32;
[[stage(fragment)]]
fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
if (loc_in > 10u) {
return 0.5;
}
return 1.0;
}
)";
auto* expect = R"(
type myf32 = f32;
[[location(0)]] var<in> tint_symbol_1 : u32;
[[location(0)]] var<out> tint_symbol_2 : myf32;
[[stage(fragment)]]
fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) {
tint_symbol_2 = 0.5;
return;
}
tint_symbol_2 = 1.0;
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
auto* src = R"(
[[builtin(sample_index)]] var<in> sample_index : u32;
@ -164,27 +255,26 @@ fn main() -> void {
// Test that different transforms within the sanitizer interact correctly.
TEST_F(SpirvTest, MultipleTransforms) {
// TODO(jrprice): Make `mask_out` a return value when supported.
auto* src = R"(
[[builtin(sample_mask_out)]] var<out> mask_out : u32;
[[stage(fragment)]]
fn main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask_in)]] mask_in : u32) -> void {
mask_out = mask_in;
[[builtin(sample_mask_in)]] mask_in : u32)
-> [[builtin(sample_mask_out)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
[[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
[[builtin(sample_mask_out)]] var<out> tint_symbol_3 : array<u32, 1>;
[[stage(fragment)]]
fn main() -> void {
mask_out[0] = tint_symbol_2[0];
tint_symbol_3[0] = tint_symbol_2[0];
return;
}
)";

View File

@ -18,6 +18,7 @@
#include "src/ast/builtin.h"
#include "src/ast/builtin_decoration.h"
#include "src/ast/location_decoration.h"
#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/storage_class.h"
#include "src/ast/variable.h"
@ -96,6 +97,74 @@ OpFunctionEnd
)");
}
TEST_F(BuilderTest, EntryPoint_ReturnValue) {
// [[stage(fragment)]]
// fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
// if (loc_in > 10) {
// return 0.5;
// }
// return 1.0;
// }
auto* f32 = ty.f32();
auto* u32 = ty.u32();
auto* loc_in = Var("loc_in", u32, ast::StorageClass::kFunction, nullptr,
{create<ast::LocationDecoration>(0)});
auto* cond = create<ast::BinaryExpression>(ast::BinaryOp::kGreaterThan,
Expr("loc_in"), Expr(10u));
Func("frag_main", ast::VariableList{loc_in}, f32,
ast::StatementList{
If(cond, Block(create<ast::ReturnStatement>(Expr(0.5f)))),
create<ast::ReturnStatement>(Expr(1.0f)),
},
ast::DecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
},
ast::DecorationList{create<ast::LocationDecoration>(0)});
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
// Test that the return value gets hoisted out to a global variable with the
// Output storage class, and the return statements are replaced with stores.
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %10 "frag_main" %1 %4
OpExecutionMode %10 OriginUpperLeft
OpName %1 "tint_symbol_1"
OpName %4 "tint_symbol_2"
OpName %10 "frag_main"
OpDecorate %1 Location 0
OpDecorate %4 Location 0
%3 = OpTypeInt 32 0
%2 = OpTypePointer Input %3
%1 = OpVariable %2 Input
%6 = OpTypeFloat 32
%5 = OpTypePointer Output %6
%7 = OpConstantNull %6
%4 = OpVariable %5 Output %7
%9 = OpTypeVoid
%8 = OpTypeFunction %9
%13 = OpConstant %3 10
%15 = OpTypeBool
%18 = OpConstant %6 0.5
%19 = OpConstant %6 1
%10 = OpFunction %9 None %8
%11 = OpLabel
%12 = OpLoad %3 %1
%14 = OpUGreaterThan %15 %12 %13
OpSelectionMerge %16 None
OpBranchConditional %14 %17 %16
%17 = OpLabel
OpStore %4 %18
OpReturn
%16 = OpLabel
OpStore %4 %19
OpReturn
OpFunctionEnd
)");
}
} // namespace
} // namespace spirv
} // namespace writer