[spirv-writer] Handle non-struct entry point return values
Generate a global variable for the return value and replace return statements with assignments to this variable. Add a list of return statements to semantic::Function. Bug: tint:509 Change-Id: I6bc08fcac7858b48f0eff62199d5011665284220 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44804 Commit-Queue: James Price <jrprice@google.com> Auto-Submit: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
417b82291b
commit
4ffd3e2ea5
|
@ -381,12 +381,12 @@ class CloneContext {
|
|||
using CloneableList = std::vector<Cloneable*>;
|
||||
|
||||
/// A map of object in #src to their cloned equivalent in #dst
|
||||
std::unordered_map<Cloneable*, Cloneable*> cloned_;
|
||||
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
|
||||
|
||||
/// A map of object in #src to the list of cloned objects in #dst.
|
||||
/// Clone(const std::vector<T*>& v) will use this to insert the map-value list
|
||||
/// into the target vector/ before cloning and inserting the map-key.
|
||||
std::unordered_map<Cloneable*, CloneableList> insert_before_;
|
||||
std::unordered_map<const Cloneable*, CloneableList> insert_before_;
|
||||
|
||||
/// Cloneable transform functions registered with ReplaceAll()
|
||||
std::vector<CloneableTransform> transforms_;
|
||||
|
|
|
@ -321,6 +321,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
|
|||
});
|
||||
}
|
||||
if (auto* r = stmt->As<ast::ReturnStatement>()) {
|
||||
current_function_->return_statements.push_back(r);
|
||||
return Expression(r->value());
|
||||
}
|
||||
if (auto* s = stmt->As<ast::SwitchStatement>()) {
|
||||
|
@ -1215,7 +1216,7 @@ void Resolver::CreateSemanticNodes() const {
|
|||
|
||||
auto* sem_func = builder_->create<semantic::Function>(
|
||||
info->declaration, remap_vars(info->referenced_module_vars),
|
||||
remap_vars(info->local_referenced_module_vars),
|
||||
remap_vars(info->local_referenced_module_vars), info->return_statements,
|
||||
ancestor_entry_points[func->symbol()]);
|
||||
func_info_to_sem_func.emplace(info, sem_func);
|
||||
sem.Add(func, sem_func);
|
||||
|
|
|
@ -38,6 +38,7 @@ class ConstructorExpression;
|
|||
class Function;
|
||||
class IdentifierExpression;
|
||||
class MemberAccessorExpression;
|
||||
class ReturnStatement;
|
||||
class UnaryOpExpression;
|
||||
class Variable;
|
||||
} // namespace ast
|
||||
|
@ -92,6 +93,7 @@ class Resolver {
|
|||
ast::Function* const declaration;
|
||||
UniqueVector<VariableInfo*> referenced_module_vars;
|
||||
UniqueVector<VariableInfo*> local_referenced_module_vars;
|
||||
std::vector<const ast::ReturnStatement*> return_statements;
|
||||
|
||||
// List of transitive calls this function makes
|
||||
UniqueVector<FunctionInfo*> transitive_calls;
|
||||
|
|
|
@ -858,6 +858,30 @@ TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) {
|
|||
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Function_ReturnStatements) {
|
||||
auto* var = Var("foo", ty.f32(), ast::StorageClass::kFunction);
|
||||
|
||||
auto* ret_1 = create<ast::ReturnStatement>(Expr(1.f));
|
||||
auto* ret_foo = create<ast::ReturnStatement>(Expr("foo"));
|
||||
|
||||
auto* func = Func("my_func", ast::VariableList{}, ty.f32(),
|
||||
ast::StatementList{
|
||||
create<ast::VariableDeclStatement>(var),
|
||||
If(Expr(true), Block(ret_1)),
|
||||
ret_foo,
|
||||
},
|
||||
ast::DecorationList{});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
auto* func_sem = Sem().Get(func);
|
||||
ASSERT_NE(func_sem, nullptr);
|
||||
|
||||
EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
|
||||
EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
|
||||
EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
|
||||
auto* strct = create<ast::Struct>(
|
||||
ast::StructMemberList{Member("first_member", ty.i32()),
|
||||
|
|
|
@ -29,6 +29,7 @@ class BuiltinDecoration;
|
|||
class Function;
|
||||
class GroupDecoration;
|
||||
class LocationDecoration;
|
||||
class ReturnStatement;
|
||||
} // namespace ast
|
||||
|
||||
namespace semantic {
|
||||
|
@ -53,11 +54,13 @@ class Function : public Castable<Function, CallTarget> {
|
|||
/// @param declaration the ast::Function
|
||||
/// @param referenced_module_vars the referenced module variables
|
||||
/// @param local_referenced_module_vars the locally referenced module
|
||||
/// @param return_statements the function return statements
|
||||
/// variables
|
||||
/// @param ancestor_entry_points the ancestor entry points
|
||||
Function(ast::Function* declaration,
|
||||
std::vector<const Variable*> referenced_module_vars,
|
||||
std::vector<const Variable*> local_referenced_module_vars,
|
||||
std::vector<const ast::ReturnStatement*> return_statements,
|
||||
std::vector<Symbol> ancestor_entry_points);
|
||||
|
||||
/// Destructor
|
||||
|
@ -76,6 +79,10 @@ class Function : public Castable<Function, CallTarget> {
|
|||
const std::vector<const Variable*>& LocalReferencedModuleVariables() const {
|
||||
return local_referenced_module_vars_;
|
||||
}
|
||||
/// @returns the return statements
|
||||
const std::vector<const ast::ReturnStatement*> ReturnStatements() const {
|
||||
return return_statements_;
|
||||
}
|
||||
/// @returns the ancestor entry points
|
||||
const std::vector<Symbol>& AncestorEntryPoints() const {
|
||||
return ancestor_entry_points_;
|
||||
|
@ -148,6 +155,7 @@ class Function : public Castable<Function, CallTarget> {
|
|||
ast::Function* const declaration_;
|
||||
std::vector<const Variable*> const referenced_module_vars_;
|
||||
std::vector<const Variable*> const local_referenced_module_vars_;
|
||||
std::vector<const ast::ReturnStatement*> const return_statements_;
|
||||
std::vector<Symbol> const ancestor_entry_points_;
|
||||
};
|
||||
|
||||
|
|
|
@ -57,11 +57,13 @@ std::tuple<ast::BindingDecoration*, ast::GroupDecoration*> GetBindingAndGroup(
|
|||
Function::Function(ast::Function* declaration,
|
||||
std::vector<const Variable*> referenced_module_vars,
|
||||
std::vector<const Variable*> local_referenced_module_vars,
|
||||
std::vector<const ast::ReturnStatement*> return_statements,
|
||||
std::vector<Symbol> ancestor_entry_points)
|
||||
: Base(declaration->return_type(), GetParameters(declaration)),
|
||||
declaration_(declaration),
|
||||
referenced_module_vars_(std::move(referenced_module_vars)),
|
||||
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
|
||||
return_statements_(std::move(return_statements)),
|
||||
ancestor_entry_points_(std::move(ancestor_entry_points)) {}
|
||||
|
||||
Function::~Function() = default;
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "src/ast/return_statement.h"
|
||||
#include "src/program_builder.h"
|
||||
#include "src/semantic/function.h"
|
||||
#include "src/semantic/variable.h"
|
||||
|
||||
namespace tint {
|
||||
|
@ -102,6 +104,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto* sem_func = ctx.src->Sem().Get(func);
|
||||
|
||||
for (auto* param : func->params()) {
|
||||
// TODO(jrprice): Handle structures by moving the declaration and
|
||||
// construction to the function body.
|
||||
|
@ -126,21 +130,37 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(jrprice): Hoist the return type out to a global variable, and
|
||||
// replace return statements with variable assignments.
|
||||
if (!func->return_type()->Is<type::Void>()) {
|
||||
TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
|
||||
<< "entry point return values are not yet supported";
|
||||
continue;
|
||||
// TODO(jrprice): Handle structures by creating a variable for each member
|
||||
// and replacing return statements with extracts+stores.
|
||||
if (func->return_type()->UnwrapAll()->Is<type::Struct>()) {
|
||||
TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
|
||||
<< "structures as entry point return values are not yet supported";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create a new symbol for the global variable.
|
||||
auto var_symbol = ctx.dst->Symbols().New();
|
||||
// Create the global variable.
|
||||
auto* var = ctx.dst->Var(var_symbol, ctx.Clone(func->return_type()),
|
||||
ast::StorageClass::kOutput, nullptr,
|
||||
ctx.Clone(func->return_type_decorations()));
|
||||
ctx.InsertBefore(func, var);
|
||||
|
||||
// Replace all return statements with stores to the global variable.
|
||||
for (auto* ret : sem_func->ReturnStatements()) {
|
||||
ctx.InsertBefore(
|
||||
ret, ctx.dst->create<ast::AssignmentStatement>(
|
||||
ctx.dst->Expr(var_symbol), ctx.Clone(ret->value())));
|
||||
ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrite the function header to remove the parameters.
|
||||
// TODO(jrprice): Change return type to void when return values are handled.
|
||||
// Rewrite the function header to remove the parameters and return value.
|
||||
auto* new_func = ctx.dst->create<ast::Function>(
|
||||
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
|
||||
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
|
||||
ctx.Clone(func->decorations()),
|
||||
ctx.Clone(func->return_type_decorations()));
|
||||
ctx.dst->ty.void_(), ctx.Clone(func->body()),
|
||||
ctx.Clone(func->decorations()), ast::DecorationList{});
|
||||
ctx.Replace(func, new_func);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -86,6 +86,97 @@ fn frag_main() -> void {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnBuiltin) {
|
||||
auto* src = R"(
|
||||
[[stage(vertex)]]
|
||||
fn vert_main() -> [[builtin(position)]] vec4<f32> {
|
||||
return vec4<f32>(1.0, 2.0, 3.0, 0.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn vert_main() -> void {
|
||||
tint_symbol_1 = vec4<f32>(1.0, 2.0, 3.0, 0.0);
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Spirv>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation) {
|
||||
auto* src = R"(
|
||||
[[stage(fragment)]]
|
||||
fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
|
||||
if (loc_in > 10u) {
|
||||
return 0.5;
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[location(0)]] var<in> tint_symbol_1 : u32;
|
||||
|
||||
[[location(0)]] var<out> tint_symbol_2 : f32;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn frag_main() -> void {
|
||||
if ((tint_symbol_1 > 10u)) {
|
||||
tint_symbol_2 = 0.5;
|
||||
return;
|
||||
}
|
||||
tint_symbol_2 = 1.0;
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Spirv>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation_TypeAlias) {
|
||||
auto* src = R"(
|
||||
type myf32 = f32;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
|
||||
if (loc_in > 10u) {
|
||||
return 0.5;
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
type myf32 = f32;
|
||||
|
||||
[[location(0)]] var<in> tint_symbol_1 : u32;
|
||||
|
||||
[[location(0)]] var<out> tint_symbol_2 : myf32;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn frag_main() -> void {
|
||||
if ((tint_symbol_1 > 10u)) {
|
||||
tint_symbol_2 = 0.5;
|
||||
return;
|
||||
}
|
||||
tint_symbol_2 = 1.0;
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Spirv>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
|
||||
auto* src = R"(
|
||||
[[builtin(sample_index)]] var<in> sample_index : u32;
|
||||
|
@ -164,27 +255,26 @@ fn main() -> void {
|
|||
|
||||
// Test that different transforms within the sanitizer interact correctly.
|
||||
TEST_F(SpirvTest, MultipleTransforms) {
|
||||
// TODO(jrprice): Make `mask_out` a return value when supported.
|
||||
auto* src = R"(
|
||||
[[builtin(sample_mask_out)]] var<out> mask_out : u32;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn main([[builtin(sample_index)]] sample_index : u32,
|
||||
[[builtin(sample_mask_in)]] mask_in : u32) -> void {
|
||||
mask_out = mask_in;
|
||||
[[builtin(sample_mask_in)]] mask_in : u32)
|
||||
-> [[builtin(sample_mask_out)]] u32 {
|
||||
return mask_in;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
|
||||
|
||||
[[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
|
||||
|
||||
[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
|
||||
|
||||
[[builtin(sample_mask_out)]] var<out> tint_symbol_3 : array<u32, 1>;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn main() -> void {
|
||||
mask_out[0] = tint_symbol_2[0];
|
||||
tint_symbol_3[0] = tint_symbol_2[0];
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "src/ast/builtin.h"
|
||||
#include "src/ast/builtin_decoration.h"
|
||||
#include "src/ast/location_decoration.h"
|
||||
#include "src/ast/return_statement.h"
|
||||
#include "src/ast/stage_decoration.h"
|
||||
#include "src/ast/storage_class.h"
|
||||
#include "src/ast/variable.h"
|
||||
|
@ -96,6 +97,74 @@ OpFunctionEnd
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, EntryPoint_ReturnValue) {
|
||||
// [[stage(fragment)]]
|
||||
// fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
|
||||
// if (loc_in > 10) {
|
||||
// return 0.5;
|
||||
// }
|
||||
// return 1.0;
|
||||
// }
|
||||
auto* f32 = ty.f32();
|
||||
auto* u32 = ty.u32();
|
||||
auto* loc_in = Var("loc_in", u32, ast::StorageClass::kFunction, nullptr,
|
||||
{create<ast::LocationDecoration>(0)});
|
||||
auto* cond = create<ast::BinaryExpression>(ast::BinaryOp::kGreaterThan,
|
||||
Expr("loc_in"), Expr(10u));
|
||||
Func("frag_main", ast::VariableList{loc_in}, f32,
|
||||
ast::StatementList{
|
||||
If(cond, Block(create<ast::ReturnStatement>(Expr(0.5f)))),
|
||||
create<ast::ReturnStatement>(Expr(1.0f)),
|
||||
},
|
||||
ast::DecorationList{
|
||||
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
|
||||
},
|
||||
ast::DecorationList{create<ast::LocationDecoration>(0)});
|
||||
|
||||
spirv::Builder& b = SanitizeAndBuild();
|
||||
|
||||
ASSERT_TRUE(b.Build());
|
||||
|
||||
// Test that the return value gets hoisted out to a global variable with the
|
||||
// Output storage class, and the return statements are replaced with stores.
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %10 "frag_main" %1 %4
|
||||
OpExecutionMode %10 OriginUpperLeft
|
||||
OpName %1 "tint_symbol_1"
|
||||
OpName %4 "tint_symbol_2"
|
||||
OpName %10 "frag_main"
|
||||
OpDecorate %1 Location 0
|
||||
OpDecorate %4 Location 0
|
||||
%3 = OpTypeInt 32 0
|
||||
%2 = OpTypePointer Input %3
|
||||
%1 = OpVariable %2 Input
|
||||
%6 = OpTypeFloat 32
|
||||
%5 = OpTypePointer Output %6
|
||||
%7 = OpConstantNull %6
|
||||
%4 = OpVariable %5 Output %7
|
||||
%9 = OpTypeVoid
|
||||
%8 = OpTypeFunction %9
|
||||
%13 = OpConstant %3 10
|
||||
%15 = OpTypeBool
|
||||
%18 = OpConstant %6 0.5
|
||||
%19 = OpConstant %6 1
|
||||
%10 = OpFunction %9 None %8
|
||||
%11 = OpLabel
|
||||
%12 = OpLoad %3 %1
|
||||
%14 = OpUGreaterThan %15 %12 %13
|
||||
OpSelectionMerge %16 None
|
||||
OpBranchConditional %14 %17 %16
|
||||
%17 = OpLabel
|
||||
OpStore %4 %18
|
||||
OpReturn
|
||||
%16 = OpLabel
|
||||
OpStore %4 %19
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
|
Loading…
Reference in New Issue