transform/shader_io: Generate a wrapper function

This is a major reworking of this transform. The old transform code
was getting unwieldy, with part of the complication coming from the
handling of multiple return statements. By generating a wrapper
function instead, we can avoid a lot of this complexity.

The original entry point function is stripped of all shader IO
attributes (as well as `stage` and `workgroup_size`), but the body is
left unmodified. A new entry point wrapper function is introduced
which calls the original function, packing/unpacking the shader inputs
as necessary, and propagates the result to the corresponding shader
outputs.

The new code has been refactored to use a state object with the
different parts of the transform split into separate functions, which
makes it much more manageable.

Fixed: tint:1076
Bug: tint:920
Change-Id: I3490a0ea7a3509a4e198ce730e476516649d8d96
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/60521
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price
2021-08-04 22:15:28 +00:00
committed by Tint LUCI CQ
parent 3e92e9f8ba
commit a5d73ce965
3866 changed files with 49323 additions and 26508 deletions

View File

@@ -16,14 +16,12 @@
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/program_builder.h"
#include "src/sem/block_statement.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO);
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config);
@@ -62,8 +60,396 @@ bool StructMemberComparator(const ast::StructMember* a,
}
}
// Returns true if `deco` is a shader IO decoration.
bool IsShaderIODecoration(const ast::Decoration* deco) {
return deco->IsAnyOf<ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
}
} // namespace
/// State holds the current transform state for a single entry point.
struct CanonicalizeEntryPointIO::State {
/// OutputValue represents a shader result that the wrapper function produces.
struct OutputValue {
/// The name of the output value.
std::string name;
/// The type of the output value.
ast::Type* type;
/// The shader IO attributes.
ast::DecorationList attributes;
/// The value itself.
ast::Expression* value;
};
/// The clone context.
CloneContext& ctx;
/// The transform config.
CanonicalizeEntryPointIO::Config const cfg;
/// The entry point function (AST).
ast::Function* func_ast;
/// The entry point function (SEM).
sem::Function const* func_sem;
/// The new entry point wrapper function's parameters.
ast::VariableList wrapper_ep_parameters;
/// The members of the wrapper function's struct parameter.
ast::StructMemberList wrapper_struct_param_members;
/// The name of the wrapper function's struct parameter.
Symbol wrapper_struct_param_name;
/// The parameters that will be passed to the original function.
ast::ExpressionList inner_call_parameters;
/// The members of the wrapper function's struct return type.
ast::StructMemberList wrapper_struct_output_members;
/// The wrapper function output values.
std::vector<OutputValue> wrapper_output_values;
/// The body of the wrapper function.
ast::StatementList wrapper_body;
/// Constructor
/// @param context the clone context
/// @param config the transform config
/// @param function the entry point function
State(CloneContext& context,
const CanonicalizeEntryPointIO::Config& config,
ast::Function* function)
: ctx(context),
cfg(config),
func_ast(function),
func_sem(ctx.src->Sem().Get(function)) {}
/// Clones the shader IO decorations from `src`.
/// @param src the decorations to clone
/// @return the cloned decorations
ast::DecorationList CloneShaderIOAttributes(const ast::DecorationList& src) {
ast::DecorationList new_decorations;
for (auto* deco : src) {
if (IsShaderIODecoration(deco)) {
new_decorations.push_back(ctx.Clone(deco));
}
}
return new_decorations;
}
/// Create or return a symbol for the wrapper function's struct parameter.
/// @returns the symbol for the struct parameter
Symbol InputStructSymbol() {
if (!wrapper_struct_param_name.IsValid()) {
wrapper_struct_param_name = ctx.dst->Sym();
}
return wrapper_struct_param_name;
}
/// Add a shader input to the entry point.
/// @param name the name of the shader input
/// @param type the type of the shader input
/// @param attributes the attributes to apply to the shader input
/// @returns an expression which evaluates to the value of the shader input
ast::Expression* AddInput(std::string name,
ast::Type* type,
ast::DecorationList attributes) {
if (cfg.builtin_style == BuiltinStyle::kParameter &&
ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
// If this input is a builtin and we are emitting those as parameters,
// then add it to the parameter list and pass it directly to the inner
// function.
wrapper_ep_parameters.push_back(
ctx.dst->Param(name, type, std::move(attributes)));
return ctx.dst->Expr(name);
} else {
// Otherwise, move it to the new structure member list.
wrapper_struct_param_members.push_back(
ctx.dst->Member(name, type, std::move(attributes)));
return ctx.dst->MemberAccessor(InputStructSymbol(), name);
}
}
/// Add a shader output to the entry point.
/// @param name the name of the shader output
/// @param type the type of the shader output
/// @param attributes the attributes to apply to the shader output
/// @param value the value of the shader output
void AddOutput(std::string name,
ast::Type* type,
ast::DecorationList attributes,
ast::Expression* value) {
OutputValue output;
output.name = name;
output.type = type;
output.attributes = std::move(attributes);
output.value = value;
wrapper_output_values.push_back(output);
}
/// Process a non-struct parameter.
/// This creates a new object for the shader input, moving the shader IO
/// attributes to it. It also adds an expression to the list of parameters
/// that will be passed to the original function.
/// @param param the original function parameter
void ProcessNonStructParameter(const sem::Parameter* param) {
// Remove the shader IO attributes from the inner function parameter, and
// attach them to the new object instead.
ast::DecorationList attributes;
for (auto* deco : param->Declaration()->decorations()) {
if (IsShaderIODecoration(deco)) {
ctx.Remove(param->Declaration()->decorations(), deco);
attributes.push_back(ctx.Clone(deco));
}
}
auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol());
auto* type = ctx.Clone(param->Declaration()->type());
auto* input_expr = AddInput(name, type, std::move(attributes));
inner_call_parameters.push_back(input_expr);
}
/// Process a struct parameter.
/// This creates new objects for each struct member, moving the shader IO
/// attributes to them. It also creates the structure that will be passed to
/// the original function.
/// @param param the original function parameter
void ProcessStructParameter(const sem::Parameter* param) {
auto* str = param->Type()->As<sem::Struct>();
// Recreate struct members in the outer entry point and build an initializer
// list to pass them through to the inner function.
ast::ExpressionList inner_struct_values;
for (auto* member : str->Members()) {
if (member->Type()->Is<sem::Struct>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
continue;
}
auto* member_ast = member->Declaration();
auto name = ctx.src->Symbols().NameFor(member_ast->symbol());
auto* type = ctx.Clone(member_ast->type());
auto attributes = CloneShaderIOAttributes(member_ast->decorations());
auto* input_expr = AddInput(name, type, std::move(attributes));
inner_struct_values.push_back(input_expr);
}
// Construct the original structure using the new shader input objects.
inner_call_parameters.push_back(ctx.dst->Construct(
ctx.Clone(param->Declaration()->type()), inner_struct_values));
}
/// Process the entry point return type.
/// This generates a list of output values that are returned by the original
/// function.
/// @param inner_ret_type the original function return type
/// @param original_result the result object produced by the original function
void ProcessReturnType(const sem::Type* inner_ret_type,
Symbol original_result) {
if (auto* str = inner_ret_type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
if (member->Type()->Is<sem::Struct>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
continue;
}
auto* member_ast = member->Declaration();
auto name = ctx.src->Symbols().NameFor(member_ast->symbol());
auto* type = ctx.Clone(member_ast->type());
auto attributes = CloneShaderIOAttributes(member_ast->decorations());
// Extract the original structure member.
AddOutput(name, type, std::move(attributes),
ctx.dst->MemberAccessor(original_result, name));
}
} else if (!inner_ret_type->Is<sem::Void>()) {
auto* type = ctx.Clone(func_ast->return_type());
auto attributes =
CloneShaderIOAttributes(func_ast->return_type_decorations());
// Propagate the non-struct return value as is.
AddOutput("value", type, std::move(attributes),
ctx.dst->Expr(original_result));
}
}
/// Add a fixed sample mask to the wrapper function output.
/// If there is already a sample mask, bitwise-and it with the fixed mask.
/// Otherwise, create a new output value from the fixed mask.
void AddFixedSampleMask() {
// Check the existing output values for a sample mask builtin.
for (auto& outval : wrapper_output_values) {
auto* builtin =
ast::GetDecoration<ast::BuiltinDecoration>(outval.attributes);
if (builtin && builtin->value() == ast::Builtin::kSampleMask) {
// Combine the authored sample mask with the fixed mask.
outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
return;
}
}
// No existing sample mask builtin was found, so create a new output value
// using the fixed sample mask.
AddOutput("fixed_sample_mask", ctx.dst->ty.u32(),
{ctx.dst->Builtin(ast::Builtin::kSampleMask)},
ctx.dst->Expr(cfg.fixed_sample_mask));
}
/// Create the wrapper function's struct parameter and type objects.
void CreateInputStruct() {
// Sort the struct members to satisfy HLSL interfacing matching rules.
std::sort(wrapper_struct_param_members.begin(),
wrapper_struct_param_members.end(), StructMemberComparator);
// Create the new struct type.
auto struct_name = ctx.dst->Sym();
auto* in_struct = ctx.dst->create<ast::Struct>(
struct_name, wrapper_struct_param_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
// Create a new function parameter using this struct type.
auto* param =
ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
wrapper_ep_parameters.push_back(param);
}
/// Create and return the wrapper function's struct result object.
/// @returns the struct type
ast::Struct* CreateOutputStruct() {
ast::StatementList assignments;
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
// Create the struct members and their corresponding assignment statements.
std::unordered_set<std::string> member_names;
for (auto& outval : wrapper_output_values) {
// Use the original output name, unless that is already taken.
Symbol name;
if (member_names.count(outval.name)) {
name = ctx.dst->Symbols().New(outval.name);
} else {
name = ctx.dst->Symbols().Register(outval.name);
}
member_names.insert(ctx.dst->Symbols().NameFor(name));
wrapper_struct_output_members.push_back(
ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
assignments.push_back(ctx.dst->Assign(
ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
}
// Sort the struct members to satisfy HLSL interfacing matching rules.
std::sort(wrapper_struct_output_members.begin(),
wrapper_struct_output_members.end(), StructMemberComparator);
// Create the new struct type.
auto* out_struct = ctx.dst->create<ast::Struct>(
ctx.dst->Sym(), wrapper_struct_output_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
// Create the output struct object, assign its members, and return it.
auto* result_object =
ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name()));
wrapper_body.push_back(ctx.dst->Decl(result_object));
wrapper_body.insert(wrapper_body.end(), assignments.begin(),
assignments.end());
wrapper_body.push_back(ctx.dst->Return(wrapper_result));
return out_struct;
}
// Recreate the original function without entry point attributes and call it.
/// @returns the inner function call expression
ast::CallExpression* CallInnerFunction() {
// Add a suffix to the function name, as the wrapper function will take the
// original entry point name.
auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol());
auto inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
// Clone everything, dropping the function and return type attributes.
// The parameter attributes will have already been stripped during
// processing.
auto* inner_function = ctx.dst->create<ast::Function>(
inner_name, ctx.Clone(func_ast->params()),
ctx.Clone(func_ast->return_type()), ctx.Clone(func_ast->body()),
ast::DecorationList{}, ast::DecorationList{});
ctx.Replace(func_ast, inner_function);
// Call the function.
return ctx.dst->Call(inner_function->symbol(), inner_call_parameters);
}
/// Process the entry point function.
void Process() {
bool needs_fixed_sample_mask = false;
if (func_ast->pipeline_stage() == ast::PipelineStage::kFragment &&
cfg.fixed_sample_mask != 0xFFFFFFFF) {
needs_fixed_sample_mask = true;
}
// Exit early if there is no shader IO to handle.
if (func_sem->Parameters().size() == 0 &&
func_sem->ReturnType()->Is<sem::Void>() && !needs_fixed_sample_mask) {
return;
}
// Process the entry point parameters, collecting those that need to be
// aggregated into a single structure.
if (!func_sem->Parameters().empty()) {
for (auto* param : func_sem->Parameters()) {
if (param->Type()->Is<sem::Struct>()) {
ProcessStructParameter(param);
} else {
ProcessNonStructParameter(param);
}
}
// Create a structure parameter for the outer entry point if necessary.
if (!wrapper_struct_param_members.empty()) {
CreateInputStruct();
}
}
// Recreate the original function and call it.
auto* call_inner = CallInnerFunction();
// Process the return type, and start building the wrapper function body.
std::function<ast::Type*()> wrapper_ret_type = [&] {
return ctx.dst->ty.void_();
};
if (func_sem->ReturnType()->Is<sem::Void>()) {
// The function call is just a statement with no result.
wrapper_body.push_back(ctx.dst->create<ast::CallStatement>(call_inner));
} else {
// Capture the result of calling the original function.
auto* inner_result = ctx.dst->Const(
ctx.dst->Symbols().New("inner_result"), nullptr, call_inner);
wrapper_body.push_back(ctx.dst->Decl(inner_result));
// Process the original return type to determine the outputs that the
// outer function needs to produce.
ProcessReturnType(func_sem->ReturnType(), inner_result->symbol());
}
// Add a fixed sample mask, if necessary.
if (needs_fixed_sample_mask) {
AddFixedSampleMask();
}
// Produce the entry point outputs, if necessary.
if (!wrapper_output_values.empty()) {
auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] {
return ctx.dst->ty.type_name(output_struct->name());
};
}
// Create the wrapper entry point function.
// Take the name of the original entry point function.
auto name = ctx.Clone(func_ast->symbol());
auto* wrapper_func = ctx.dst->create<ast::Function>(
name, wrapper_ep_parameters, wrapper_ret_type(),
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->decorations()),
ast::DecorationList{});
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast,
wrapper_func);
}
};
void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) {
@@ -75,302 +461,27 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
return;
}
// Strip entry point IO decorations from struct declarations.
// TODO(jrprice): This code is duplicated with the SPIR-V transform.
// Remove entry point IO attributes from struct declarations.
// New structures will be created for each entry point, as necessary.
for (auto* ty : ctx.src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
// Build new list of struct members without entry point IO decorations.
ast::StructMemberList new_struct_members;
for (auto* member : struct_ty->members()) {
ast::DecorationList new_decorations = RemoveDecorations(
ctx, member->decorations(), [](const ast::Decoration* deco) {
return deco->IsAnyOf<
ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
});
new_struct_members.push_back(
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
for (auto* deco : member->decorations()) {
if (IsShaderIODecoration(deco)) {
ctx.Remove(member->decorations(), deco);
}
}
}
// Redeclare the struct.
auto new_struct_name = ctx.Clone(struct_ty->name());
auto* new_struct =
ctx.dst->create<ast::Struct>(new_struct_name, new_struct_members,
ctx.Clone(struct_ty->decorations()));
ctx.Replace(struct_ty, new_struct);
}
}
// Returns true if `decos` contains a `sample_mask` builtin.
auto has_sample_mask_builtin = [](const ast::DecorationList& decos) {
if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos)) {
return builtin->value() == ast::Builtin::kSampleMask;
}
return false;
};
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
bool needs_fixed_sample_mask =
func_ast->pipeline_stage() == ast::PipelineStage::kFragment &&
cfg->fixed_sample_mask != 0xFFFFFFFF;
ast::VariableList new_parameters;
if (!func->Parameters().empty()) {
// Collect all parameters and build a list of new struct members.
auto new_struct_param_symbol = ctx.dst->Sym();
ast::StructMemberList new_struct_members;
for (auto* param : func->Parameters()) {
if (cfg->builtin_style == BuiltinStyle::kParameter &&
ast::HasDecoration<ast::BuiltinDecoration>(
param->Declaration()->decorations())) {
// If this parameter is a builtin and we are emitting those as
// parameters, then just clone it as is.
new_parameters.push_back(
ctx.Clone(const_cast<ast::Variable*>(param->Declaration())));
continue;
}
auto param_name = ctx.Clone(param->Declaration()->symbol());
auto* param_ty = param->Type();
auto* param_declared_ty = param->Declaration()->type();
ast::Expression* func_const_initializer = nullptr;
if (auto* str = param_ty->As<sem::Struct>()) {
// Pull out all struct members and build initializer list.
ast::ExpressionList init_values;
for (auto* member : str->Members()) {
if (member->Type()->Is<sem::Struct>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "nested pipeline IO struct";
}
ast::DecorationList new_decorations = RemoveDecorations(
ctx, member->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<
ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
});
if (cfg->builtin_style == BuiltinStyle::kParameter &&
ast::HasDecoration<ast::BuiltinDecoration>(
member->Declaration()->decorations())) {
// If this struct member is a builtin and we are emitting those as
// parameters, then move it to the parameter list.
auto* member_ty = CreateASTTypeFor(ctx, member->Type());
auto new_param_name = ctx.dst->Sym();
new_parameters.push_back(
ctx.dst->Param(new_param_name, member_ty, new_decorations));
init_values.push_back(ctx.dst->Expr(new_param_name));
continue;
}
auto member_name = ctx.Clone(member->Declaration()->symbol());
auto* member_type = ctx.Clone(member->Declaration()->type());
new_struct_members.push_back(
ctx.dst->Member(member_name, member_type, new_decorations));
init_values.push_back(
ctx.dst->MemberAccessor(new_struct_param_symbol, member_name));
}
func_const_initializer =
ctx.dst->Construct(ctx.Clone(param_declared_ty), init_values);
} else {
new_struct_members.push_back(
ctx.dst->Member(param_name, ctx.Clone(param_declared_ty),
ctx.Clone(param->Declaration()->decorations())));
func_const_initializer =
ctx.dst->MemberAccessor(new_struct_param_symbol, param_name);
}
// Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(
param_name, ctx.Clone(param_declared_ty), func_const_initializer);
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
for (auto* user : param->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(param_name));
}
}
if (!new_struct_members.empty()) {
// Sort struct members to satisfy HLSL interfacing matching rules.
std::sort(new_struct_members.begin(), new_struct_members.end(),
StructMemberComparator);
// Create the new struct type.
auto in_struct_name = ctx.dst->Sym();
auto* in_struct = ctx.dst->create<ast::Struct>(
in_struct_name, new_struct_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
in_struct);
// Create a new function parameter using this struct type.
auto* struct_param = ctx.dst->Param(
new_struct_param_symbol, ctx.dst->ty.type_name(in_struct_name));
new_parameters.push_back(struct_param);
}
}
// Handle return type.
auto* ret_type = func->ReturnType();
std::function<ast::Type*()> new_ret_type;
if (ret_type->Is<sem::Void>() && !needs_fixed_sample_mask) {
new_ret_type = [&ctx] { return ctx.dst->ty.void_(); };
} else {
ast::StructMemberList new_struct_members;
bool has_authored_sample_mask = false;
if (auto* str = ret_type->As<sem::Struct>()) {
// Rebuild struct with only the entry point IO attributes.
for (auto* member : str->Members()) {
if (member->Type()->Is<sem::Struct>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "nested pipeline IO struct";
}
ast::DecorationList new_decorations = RemoveDecorations(
ctx, member->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<
ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
});
auto symbol = ctx.Clone(member->Declaration()->symbol());
auto* member_ty = ctx.Clone(member->Declaration()->type());
new_struct_members.push_back(
ctx.dst->Member(symbol, member_ty, new_decorations));
if (has_sample_mask_builtin(new_decorations)) {
has_authored_sample_mask = true;
}
}
} else if (!ret_type->Is<sem::Void>()) {
auto* member_ty = ctx.Clone(func->Declaration()->return_type());
auto decos = ctx.Clone(func_ast->return_type_decorations());
new_struct_members.push_back(
ctx.dst->Member("value", member_ty, std::move(decos)));
if (has_sample_mask_builtin(func_ast->return_type_decorations())) {
has_authored_sample_mask = true;
}
}
// If a sample mask builtin is required and the shader source did not
// contain one, create one now.
if (needs_fixed_sample_mask && !has_authored_sample_mask) {
new_struct_members.push_back(
ctx.dst->Member(ctx.dst->Sym(), ctx.dst->ty.u32(),
{ctx.dst->Builtin(ast::Builtin::kSampleMask)}));
}
// Sort struct members to satisfy HLSL interfacing matching rules.
std::sort(new_struct_members.begin(), new_struct_members.end(),
StructMemberComparator);
// Create the new struct type.
auto out_struct_name = ctx.dst->Sym();
auto* out_struct = ctx.dst->create<ast::Struct>(
out_struct_name, new_struct_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
out_struct);
new_ret_type = [out_struct_name, &ctx] {
return ctx.dst->ty.type_name(out_struct_name);
};
// Replace all return statements.
for (auto* ret : func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
// Reconstruct the return value using the newly created struct.
std::function<ast::Expression*()> new_ret_value = [&ctx, ret] {
return ctx.Clone(ret->value());
};
ast::ExpressionList ret_values;
if (ret_type->Is<sem::Struct>()) {
if (!ret->value()->Is<ast::IdentifierExpression>()) {
// Create a const to hold the return value expression to avoid
// re-evaluating it multiple times.
auto temp = ctx.dst->Sym();
auto* ty = CreateASTTypeFor(ctx, ret_type);
auto* temp_var =
ctx.dst->Decl(ctx.dst->Const(temp, ty, new_ret_value()));
ctx.InsertBefore(ret_sem->Block()->Declaration()->statements(), ret,
temp_var);
new_ret_value = [&ctx, temp] { return ctx.dst->Expr(temp); };
}
for (auto* member : new_struct_members) {
ast::Expression* expr = nullptr;
if (needs_fixed_sample_mask &&
has_sample_mask_builtin(member->decorations())) {
// Use the fixed sample mask, combining it with the authored value
// if there is one.
expr = ctx.dst->Expr(cfg->fixed_sample_mask);
if (has_authored_sample_mask) {
expr = ctx.dst->And(
ctx.dst->MemberAccessor(new_ret_value(), member->symbol()),
expr);
}
} else {
expr = ctx.dst->MemberAccessor(new_ret_value(), member->symbol());
}
ret_values.push_back(expr);
}
} else {
if (!ret_type->Is<sem::Void>()) {
ret_values.push_back(new_ret_value());
}
if (needs_fixed_sample_mask) {
// If the original return value was a sample mask, `and` it with the
// fixed mask and return the result.
// Otherwise, append the fixed mask to the list of return values,
// since it will be the last element of the output struct.
if (has_authored_sample_mask) {
ret_values[0] =
ctx.dst->And(ret_values[0], cfg->fixed_sample_mask);
} else {
ret_values.push_back(ctx.dst->Expr(cfg->fixed_sample_mask));
}
}
}
auto* new_ret =
ctx.dst->Return(ctx.dst->Construct(new_ret_type(), ret_values));
ctx.Replace(ret, new_ret);
}
if (needs_fixed_sample_mask && func->ReturnStatements().empty()) {
// There we no return statements but we need to return a fixed sample
// mask, so add a return statement that does this.
ctx.InsertBack(func_ast->body()->statements(),
ctx.dst->Return(ctx.dst->Construct(
new_ret_type(), cfg->fixed_sample_mask)));
}
}
// Rewrite the function header with the new parameters.
auto* new_func = ctx.dst->create<ast::Function>(
func_ast->source(), ctx.Clone(func_ast->symbol()), new_parameters,
new_ret_type(), ctx.Clone(func_ast->body()),
ctx.Clone(func_ast->decorations()), ast::DecorationList{});
ctx.Replace(func_ast, new_func);
State state(ctx, *cfg, func_ast);
state.Process();
}
ctx.Clone();

View File

@@ -21,10 +21,12 @@ namespace tint {
namespace transform {
/// CanonicalizeEntryPointIO is a transform used to rewrite shader entry point
/// interfaces into a form that the generators can handle. After the transform,
/// an entry point's parameters will be aggregated into a single struct, and its
/// return type will either be a struct or void. All structs in the module that
/// have entry point IO decorations will have exactly one pipeline stage usage.
/// interfaces into a form that the generators can handle. Each entry point
/// function is stripped of all shader IO attributes and wrapped in a function
/// that provides the shader interface.
/// The transform config determines how shader IO parameters will be exposed.
/// Entry point return values are always produced as a structure, and optionally
/// include additional builtins as per the transform config.
///
/// Before:
/// ```
@@ -36,12 +38,15 @@ namespace transform {
/// [[stage(fragment)]]
/// fn frag_main([[builtin(position)]] coord : vec4<f32>,
/// locations : Locations) -> [[location(0)]] f32 {
/// if (coord.w > 1.0) {
/// return 0.0;
/// }
/// var col : f32 = (coord.x * locations.loc1);
/// return col;
/// }
/// ```
///
/// After:
/// After (using structures for all parameters):
/// ```
/// struct Locations{
/// loc1 : f32;
@@ -58,14 +63,21 @@ namespace transform {
/// [[location(0)]] loc0 : f32;
/// };
///
/// fn frag_main_inner(coord : vec4<f32>,
/// locations : Locations) -> f32 {
/// if (coord.w > 1.0) {
/// return 0.0;
/// }
/// var col : f32 = (coord.x * locations.loc1);
/// return col;
/// }
///
/// [[stage(fragment)]]
/// fn frag_main(in : frag_main_in) -> frag_main_out {
/// const coord = in.coord;
/// const locations = Locations(in.loc1, in.loc2);
/// var col : f32 = (coord.x * locations.loc1);
/// var retval : frag_main_out;
/// retval.loc0 = col;
/// return retval;
/// let inner_retval = frag_main_inner(in.coord, Locations(in.loc1, in.loc2));
/// var wrapper_result : frag_main_out;
/// wrapper_result.loc0 = inner_retval;
/// return wrapper_result;
/// }
/// ```
class CanonicalizeEntryPointIO
@@ -111,6 +123,8 @@ class CanonicalizeEntryPointIO
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
struct State;
};
} // namespace transform

View File

@@ -34,6 +34,29 @@ TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) {
EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, NoShaderIO) {
// Test that we do not introduce wrapper functions when there is no shader IO
// to process.
auto* src = R"(
[[stage(fragment)]]
fn frag_main() {
}
[[stage(compute), workgroup_size(1)]]
fn comp_main() {
}
)";
auto* expect = src;
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::BuiltinStyle::kParameter);
auto got = Run<CanonicalizeEntryPointIO>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Parameters_BuiltinsAsParameters) {
auto* src = R"(
[[stage(fragment)]]
@@ -52,11 +75,13 @@ struct tint_symbol_1 {
loc2 : vec4<u32>;
};
fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
var col : f32 = (coord.x * loc1);
}
[[stage(fragment)]]
fn frag_main([[builtin(position)]] coord : vec4<f32>, tint_symbol : tint_symbol_1) {
let loc1 : f32 = tint_symbol.loc1;
let loc2 : vec4<u32> = tint_symbol.loc2;
var col : f32 = (coord.x * loc1);
frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, coord);
}
)";
@@ -88,12 +113,13 @@ struct tint_symbol_1 {
coord : vec4<f32>;
};
fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
var col : f32 = (coord.x * loc1);
}
[[stage(fragment)]]
fn frag_main(tint_symbol : tint_symbol_1) {
let loc1 : f32 = tint_symbol.loc1;
let loc2 : vec4<u32> = tint_symbol.loc2;
let coord : vec4<f32> = tint_symbol.coord;
var col : f32 = (coord.x * loc1);
frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, tint_symbol.coord);
}
)";
@@ -123,10 +149,13 @@ struct tint_symbol_1 {
loc1 : myf32;
};
fn frag_main_inner(loc1 : myf32) {
var x : myf32 = loc1;
}
[[stage(fragment)]]
fn frag_main(tint_symbol : tint_symbol_1) {
let loc1 : myf32 = tint_symbol.loc1;
var x : myf32 = loc1;
frag_main_inner(tint_symbol.loc1);
}
)";
@@ -156,10 +185,12 @@ struct tint_symbol_1 {
loc2 : vec4<u32>;
};
fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
}
[[stage(fragment)]]
fn frag_main([[builtin(position)]] coord : vec4<f32>, tint_symbol : tint_symbol_1) {
let loc1 : f32 = tint_symbol.loc1;
let loc2 : vec4<u32> = tint_symbol.loc2;
frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, coord);
}
)";
@@ -191,11 +222,12 @@ struct tint_symbol_1 {
coord : vec4<f32>;
};
fn frag_main_inner(loc1 : f32, loc2 : vec4<u32>, coord : vec4<f32>) {
}
[[stage(fragment)]]
fn frag_main(tint_symbol : tint_symbol_1) {
let loc1 : f32 = tint_symbol.loc1;
let loc2 : vec4<u32> = tint_symbol.loc2;
let coord : vec4<f32> = tint_symbol.coord;
frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, tint_symbol.coord);
}
)";
@@ -235,7 +267,7 @@ struct FragLocations {
loc2 : vec4<u32>;
};
struct tint_symbol_2 {
struct tint_symbol_1 {
[[location(0)]]
loc0 : f32;
[[location(1)]]
@@ -244,13 +276,14 @@ struct tint_symbol_2 {
loc2 : vec4<u32>;
};
[[stage(fragment)]]
fn frag_main([[builtin(position)]] tint_symbol_1 : vec4<f32>, tint_symbol : tint_symbol_2) {
let loc0 : f32 = tint_symbol.loc0;
let locations : FragLocations = FragLocations(tint_symbol.loc1, tint_symbol.loc2);
let builtins : FragBuiltins = FragBuiltins(tint_symbol_1);
fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
}
[[stage(fragment)]]
fn frag_main([[builtin(position)]] coord : vec4<f32>, tint_symbol : tint_symbol_1) {
frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(coord));
}
)";
DataMap data;
@@ -300,12 +333,13 @@ struct tint_symbol_1 {
coord : vec4<f32>;
};
fn frag_main_inner(loc0 : f32, locations : FragLocations, builtins : FragBuiltins) {
var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
}
[[stage(fragment)]]
fn frag_main(tint_symbol : tint_symbol_1) {
let loc0 : f32 = tint_symbol.loc0;
let locations : FragLocations = FragLocations(tint_symbol.loc1, tint_symbol.loc2);
let builtins : FragBuiltins = FragBuiltins(tint_symbol.coord);
var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(tint_symbol.coord));
}
)";
@@ -331,9 +365,16 @@ struct tint_symbol {
value : f32;
};
fn frag_main_inner() -> f32 {
return 1.0;
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol {
return tint_symbol(1.0);
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.value = inner_result;
return wrapper_result;
}
)";
@@ -379,13 +420,22 @@ struct tint_symbol {
mask : u32;
};
[[stage(fragment)]]
fn frag_main() -> tint_symbol {
fn frag_main_inner() -> FragOutput {
var output : FragOutput;
output.depth = 1.0;
output.mask = 7u;
output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
return tint_symbol(output.color, output.depth, output.mask);
return output;
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol {
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.color = inner_result.color;
wrapper_result.depth = inner_result.depth;
wrapper_result.mask = inner_result.mask;
return wrapper_result;
}
)";
@@ -436,10 +486,13 @@ struct tint_symbol_1 {
mul : f32;
};
fn frag_main1_inner(inputs : FragmentInput) {
var x : f32 = foo(inputs);
}
[[stage(fragment)]]
fn frag_main1(tint_symbol : tint_symbol_1) {
let inputs : FragmentInput = FragmentInput(tint_symbol.value, tint_symbol.mul);
var x : f32 = foo(inputs);
frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
}
struct tint_symbol_3 {
@@ -449,10 +502,13 @@ struct tint_symbol_3 {
mul : f32;
};
fn frag_main2_inner(inputs : FragmentInput) {
var x : f32 = foo(inputs);
}
[[stage(fragment)]]
fn frag_main2(tint_symbol_2 : tint_symbol_3) {
let inputs : FragmentInput = FragmentInput(tint_symbol_2.value, tint_symbol_2.mul);
var x : f32 = foo(inputs);
frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
}
)";
@@ -512,13 +568,16 @@ struct tint_symbol_1 {
col2 : f32;
};
[[stage(fragment)]]
fn frag_main1(tint_symbol : tint_symbol_1) {
let inputs : FragmentInput = FragmentInput(tint_symbol.col1, tint_symbol.col2);
fn frag_main1_inner(inputs : FragmentInput) {
global_inputs = inputs;
var r : f32 = foo();
var g : f32 = bar();
}
[[stage(fragment)]]
fn frag_main1(tint_symbol : tint_symbol_1) {
frag_main1_inner(FragmentInput(tint_symbol.col1, tint_symbol.col2));
}
)";
DataMap data;
@@ -593,12 +652,18 @@ struct tint_symbol_2 {
col2 : myf32;
};
fn frag_main_inner(inputs : MyFragmentInput) -> MyFragmentOutput {
var x : myf32 = foo(inputs);
return MyFragmentOutput(x, inputs.col2);
}
[[stage(fragment)]]
fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
let inputs : MyFragmentInput = MyFragmentInput(tint_symbol.col1, tint_symbol.col2);
var x : myf32 = foo(inputs);
let tint_symbol_3 : FragmentOutput = MyFragmentOutput(x, inputs.col2);
return tint_symbol_2(tint_symbol_3.col1, tint_symbol_3.col2);
let inner_result = frag_main_inner(MyFragmentInput(tint_symbol.col1, tint_symbol.col2));
var wrapper_result : tint_symbol_2;
wrapper_result.col1 = inner_result.col1;
wrapper_result.col2 = inner_result.col2;
return wrapper_result;
}
)";
@@ -660,13 +725,22 @@ struct tint_symbol {
pos : vec4<f32>;
};
[[stage(vertex)]]
fn vert_main() -> tint_symbol {
let tint_symbol_1 : VertexOut = VertexOut();
return tint_symbol(tint_symbol_1.loc1, tint_symbol_1.loc2, tint_symbol_1.loc3, tint_symbol_1.pos);
fn vert_main_inner() -> VertexOut {
return VertexOut();
}
struct tint_symbol_3 {
[[stage(vertex)]]
fn vert_main() -> tint_symbol {
let inner_result = vert_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.pos = inner_result.pos;
wrapper_result.loc1 = inner_result.loc1;
wrapper_result.loc2 = inner_result.loc2;
wrapper_result.loc3 = inner_result.loc3;
return wrapper_result;
}
struct tint_symbol_2 {
[[location(1), interpolate(flat)]]
loc1 : f32;
[[location(2), interpolate(linear, sample)]]
@@ -675,12 +749,14 @@ struct tint_symbol_3 {
loc3 : f32;
};
[[stage(fragment)]]
fn frag_main(tint_symbol_2 : tint_symbol_3) {
let inputs : FragmentIn = FragmentIn(tint_symbol_2.loc1, tint_symbol_2.loc2);
let loc3 : f32 = tint_symbol_2.loc3;
fn frag_main_inner(inputs : FragmentIn, loc3 : f32) {
let x = ((inputs.loc1 + inputs.loc2) + loc3);
}
[[stage(fragment)]]
fn frag_main(tint_symbol_1 : tint_symbol_2) {
frag_main_inner(FragmentIn(tint_symbol_1.loc1, tint_symbol_1.loc2), tint_symbol_1.loc3);
}
)";
DataMap data;
@@ -718,20 +794,33 @@ struct tint_symbol {
pos : vec4<f32>;
};
[[stage(vertex)]]
fn main1() -> tint_symbol {
let tint_symbol_1 : VertexOut = VertexOut();
return tint_symbol(tint_symbol_1.pos);
fn main1_inner() -> VertexOut {
return VertexOut();
}
struct tint_symbol_2 {
[[stage(vertex)]]
fn main1() -> tint_symbol {
let inner_result = main1_inner();
var wrapper_result : tint_symbol;
wrapper_result.pos = inner_result.pos;
return wrapper_result;
}
struct tint_symbol_1 {
[[builtin(position), invariant]]
value : vec4<f32>;
};
fn main2_inner() -> vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn main2() -> tint_symbol_2 {
return tint_symbol_2(vec4<f32>());
fn main2() -> tint_symbol_1 {
let inner_result_1 = main2_inner();
var wrapper_result_1 : tint_symbol_1;
wrapper_result_1.value = inner_result_1;
return wrapper_result_1;
}
)";
@@ -792,11 +881,16 @@ struct tint_symbol_2 {
value : f32;
};
fn frag_main_inner(inputs : FragmentInput) -> FragmentOutput {
return FragmentOutput(((inputs.coord.x * inputs.value) + inputs.loc0));
}
[[stage(fragment)]]
fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
let inputs : FragmentInput = FragmentInput(tint_symbol.value, tint_symbol.coord, tint_symbol.loc0);
let tint_symbol_3 : FragmentOutput = FragmentOutput(((inputs.coord.x * inputs.value) + inputs.loc0));
return tint_symbol_2(tint_symbol_3.value);
let inner_result = frag_main_inner(FragmentInput(tint_symbol.value, tint_symbol.coord, tint_symbol.loc0));
var wrapper_result : tint_symbol_2;
wrapper_result.value = inner_result.value;
return wrapper_result;
}
)";
@@ -865,13 +959,23 @@ struct tint_symbol {
pos : vec4<f32>;
};
[[stage(vertex)]]
fn vert_main() -> tint_symbol {
let tint_symbol_1 : VertexOutput = VertexOutput();
return tint_symbol(tint_symbol_1.a, tint_symbol_1.b, tint_symbol_1.c, tint_symbol_1.d, tint_symbol_1.pos);
fn vert_main_inner() -> VertexOutput {
return VertexOutput();
}
struct tint_symbol_3 {
[[stage(vertex)]]
fn vert_main() -> tint_symbol {
let inner_result = vert_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.b = inner_result.b;
wrapper_result.pos = inner_result.pos;
wrapper_result.d = inner_result.d;
wrapper_result.a = inner_result.a;
wrapper_result.c = inner_result.c;
return wrapper_result;
}
struct tint_symbol_2 {
[[location(0)]]
a : f32;
[[location(1)]]
@@ -886,12 +990,12 @@ struct tint_symbol_3 {
ff : bool;
};
fn frag_main_inner(ff : bool, c : i32, inputs : FragmentInputExtra, b : u32) {
}
[[stage(fragment)]]
fn frag_main(tint_symbol_2 : tint_symbol_3) {
let ff : bool = tint_symbol_2.ff;
let c : i32 = tint_symbol_2.c;
let inputs : FragmentInputExtra = FragmentInputExtra(tint_symbol_2.d, tint_symbol_2.pos, tint_symbol_2.a);
let b : u32 = tint_symbol_2.b;
fn frag_main(tint_symbol_1 : tint_symbol_2) {
frag_main_inner(tint_symbol_1.ff, tint_symbol_1.c, FragmentInputExtra(tint_symbol_1.d, tint_symbol_1.pos, tint_symbol_1.a), tint_symbol_1.b);
}
)";
@@ -916,9 +1020,12 @@ struct tint_symbol_2 {
col : f32;
};
fn tint_symbol_1_inner(col : f32) {
}
[[stage(fragment)]]
fn tint_symbol_1(tint_symbol : tint_symbol_2) {
let col : f32 = tint_symbol.col;
tint_symbol_1_inner(tint_symbol.col);
}
)";
@@ -938,14 +1045,20 @@ fn frag_main() {
)";
auto* expect = R"(
struct tint_symbol_1 {
struct tint_symbol {
[[builtin(sample_mask)]]
tint_symbol : u32;
fixed_sample_mask : u32;
};
fn frag_main_inner() {
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol_1 {
return tint_symbol_1(3u);
fn frag_main() -> tint_symbol {
frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.fixed_sample_mask = 3u;
return wrapper_result;
}
)";
@@ -966,14 +1079,21 @@ fn frag_main() {
)";
auto* expect = R"(
struct tint_symbol_1 {
struct tint_symbol {
[[builtin(sample_mask)]]
tint_symbol : u32;
fixed_sample_mask : u32;
};
fn frag_main_inner() {
return;
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol_1 {
return tint_symbol_1(3u);
fn frag_main() -> tint_symbol {
frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.fixed_sample_mask = 3u;
return wrapper_result;
}
)";
@@ -999,9 +1119,16 @@ struct tint_symbol {
value : u32;
};
fn frag_main_inner() -> u32 {
return 7u;
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol {
return tint_symbol((7u & 3u));
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.value = (inner_result & 3u);
return wrapper_result;
}
)";
@@ -1022,16 +1149,24 @@ fn frag_main() -> [[location(0)]] f32 {
)";
auto* expect = R"(
struct tint_symbol_1 {
struct tint_symbol {
[[location(0)]]
value : f32;
[[builtin(sample_mask)]]
tint_symbol : u32;
fixed_sample_mask : u32;
};
fn frag_main_inner() -> f32 {
return 1.0;
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol_1 {
return tint_symbol_1(1.0, 3u);
fn frag_main() -> tint_symbol {
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.value = inner_result;
wrapper_result.fixed_sample_mask = 3u;
return wrapper_result;
}
)";
@@ -1073,10 +1208,18 @@ struct tint_symbol {
mask : u32;
};
fn frag_main_inner() -> Output {
return Output(0.5, 7u, 1.0);
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol {
let tint_symbol_1 : Output = Output(0.5, 7u, 1.0);
return tint_symbol(tint_symbol_1.value, tint_symbol_1.depth, (tint_symbol_1.mask & 3u));
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.depth = inner_result.depth;
wrapper_result.mask = (inner_result.mask & 3u);
wrapper_result.value = inner_result.value;
return wrapper_result;
}
)";
@@ -1108,19 +1251,27 @@ struct Output {
value : f32;
};
struct tint_symbol_1 {
struct tint_symbol {
[[location(0)]]
value : f32;
[[builtin(frag_depth)]]
depth : f32;
[[builtin(sample_mask)]]
tint_symbol : u32;
fixed_sample_mask : u32;
};
fn frag_main_inner() -> Output {
return Output(0.5, 1.0);
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol_1 {
let tint_symbol_2 : Output = Output(0.5, 1.0);
return tint_symbol_1(tint_symbol_2.value, tint_symbol_2.depth, 3u);
fn frag_main() -> tint_symbol {
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.depth = inner_result.depth;
wrapper_result.value = inner_result.value;
wrapper_result.fixed_sample_mask = 3u;
return wrapper_result;
}
)";
@@ -1160,31 +1311,53 @@ struct tint_symbol {
value : u32;
};
[[stage(fragment)]]
fn frag_main1() -> tint_symbol {
return tint_symbol((7u & 3u));
fn frag_main1_inner() -> u32 {
return 7u;
}
struct tint_symbol_2 {
[[stage(fragment)]]
fn frag_main1() -> tint_symbol {
let inner_result = frag_main1_inner();
var wrapper_result : tint_symbol;
wrapper_result.value = (inner_result & 3u);
return wrapper_result;
}
struct tint_symbol_1 {
[[location(0)]]
value : f32;
[[builtin(sample_mask)]]
tint_symbol_1 : u32;
fixed_sample_mask : u32;
};
[[stage(fragment)]]
fn frag_main2() -> tint_symbol_2 {
return tint_symbol_2(1.0, 3u);
fn frag_main2_inner() -> f32 {
return 1.0;
}
struct tint_symbol_3 {
[[stage(fragment)]]
fn frag_main2() -> tint_symbol_1 {
let inner_result_1 = frag_main2_inner();
var wrapper_result_1 : tint_symbol_1;
wrapper_result_1.value = inner_result_1;
wrapper_result_1.fixed_sample_mask = 3u;
return wrapper_result_1;
}
struct tint_symbol_2 {
[[builtin(position)]]
value : vec4<f32>;
};
fn vert_main1_inner() -> vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn vert_main1() -> tint_symbol_3 {
return tint_symbol_3(vec4<f32>());
fn vert_main1() -> tint_symbol_2 {
let inner_result_2 = vert_main1_inner();
var wrapper_result_2 : tint_symbol_2;
wrapper_result_2.value = inner_result_2;
return wrapper_result_2;
}
[[stage(compute), workgroup_size(1)]]
@@ -1200,6 +1373,57 @@ fn comp_main1() {
EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_AvoidNameClash) {
auto* src = R"(
struct FragOut {
[[location(0)]] fixed_sample_mask : vec4<f32>;
[[location(1)]] fixed_sample_mask_1 : vec4<f32>;
};
[[stage(fragment)]]
fn frag_main() -> FragOut {
return FragOut();
}
)";
auto* expect = R"(
struct FragOut {
fixed_sample_mask : vec4<f32>;
fixed_sample_mask_1 : vec4<f32>;
};
struct tint_symbol {
[[location(0)]]
fixed_sample_mask : vec4<f32>;
[[location(1)]]
fixed_sample_mask_1 : vec4<f32>;
[[builtin(sample_mask)]]
fixed_sample_mask_2 : u32;
};
fn frag_main_inner() -> FragOut {
return FragOut();
}
[[stage(fragment)]]
fn frag_main() -> tint_symbol {
let inner_result = frag_main_inner();
var wrapper_result : tint_symbol;
wrapper_result.fixed_sample_mask = inner_result.fixed_sample_mask;
wrapper_result.fixed_sample_mask_1 = inner_result.fixed_sample_mask_1;
wrapper_result.fixed_sample_mask_2 = 3u;
return wrapper_result;
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::BuiltinStyle::kParameter, 0x03);
auto got = Run<CanonicalizeEntryPointIO>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -34,15 +34,19 @@ fn main() {
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_1 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_2 : f32;
fn main_inner(local_invocation_index : u32, tint_symbol : ptr<workgroup, f32>, tint_symbol_1 : ptr<private, f32>) {
{
tint_symbol_1 = f32();
*(tint_symbol) = f32();
}
workgroupBarrier();
tint_symbol_1 = tint_symbol_2;
*(tint_symbol) = *(tint_symbol_1);
}
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_2 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_3 : f32;
main_inner(local_invocation_index, &(tint_symbol_2), &(tint_symbol_3));
}
)";
@@ -80,26 +84,30 @@ fn main() {
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol_1 : ptr<private, f32>, tint_symbol_2 : ptr<workgroup, f32>) {
*(tint_symbol_1) = a;
*(tint_symbol_2) = b;
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
*(tint_symbol) = a;
*(tint_symbol_1) = b;
}
fn foo(a : f32, tint_symbol_3 : ptr<private, f32>, tint_symbol_4 : ptr<workgroup, f32>) {
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
bar(a, b, tint_symbol_3, tint_symbol_4);
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
fn main_inner(local_invocation_index : u32, tint_symbol_4 : ptr<workgroup, f32>, tint_symbol_5 : ptr<private, f32>) {
{
*(tint_symbol_4) = f32();
}
workgroupBarrier();
foo(1.0, tint_symbol_5, tint_symbol_4);
}
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_5 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_6 : f32;
{
tint_symbol_5 = f32();
}
workgroupBarrier();
foo(1.0, &(tint_symbol_6), &(tint_symbol_5));
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_6 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_7 : f32;
main_inner(local_invocation_index, &(tint_symbol_6), &(tint_symbol_7));
}
)";
@@ -148,16 +156,20 @@ fn main() {
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_1 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_2 : f32;
fn main_inner(local_invocation_index : u32, tint_symbol : ptr<workgroup, f32>, tint_symbol_1 : ptr<private, f32>) {
{
tint_symbol_1 = f32();
*(tint_symbol) = f32();
}
workgroupBarrier();
let x : f32 = (tint_symbol_2 + tint_symbol_1);
tint_symbol_2 = x;
let x : f32 = (*(tint_symbol_1) + *(tint_symbol));
*(tint_symbol_1) = x;
}
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_2 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_3 : f32;
main_inner(local_invocation_index, &(tint_symbol_2), &(tint_symbol_3));
}
)";

View File

@@ -130,10 +130,15 @@ struct tint_symbol_2 {
float value : SV_Target1;
};
float frag_main_inner(float foo) {
return foo;
}
tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
const float foo = tint_symbol.foo;
const tint_symbol_2 tint_symbol_3 = {foo};
return tint_symbol_3;
const float inner_result = frag_main_inner(tint_symbol.foo);
tint_symbol_2 wrapper_result = (tint_symbol_2)0;
wrapper_result.value = inner_result;
return wrapper_result;
}
)");
}
@@ -160,10 +165,15 @@ struct tint_symbol_2 {
float value : SV_Depth;
};
float frag_main_inner(float4 coord) {
return coord.x;
}
tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
const float4 coord = tint_symbol.coord;
const tint_symbol_2 tint_symbol_3 = {coord.x};
return tint_symbol_3;
const float inner_result = frag_main_inner(tint_symbol.coord);
tint_symbol_2 wrapper_result = (tint_symbol_2)0;
wrapper_result.value = inner_result;
return wrapper_result;
}
)");
}
@@ -218,23 +228,35 @@ struct tint_symbol {
float4 pos : SV_Position;
};
tint_symbol vert_main() {
const Interface tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f), 0.5f, 0.25f};
const tint_symbol tint_symbol_4 = {tint_symbol_1.col1, tint_symbol_1.col2, tint_symbol_1.pos};
return tint_symbol_4;
Interface vert_main_inner() {
const Interface tint_symbol_3 = {float4(0.0f, 0.0f, 0.0f, 0.0f), 0.5f, 0.25f};
return tint_symbol_3;
}
struct tint_symbol_3 {
tint_symbol vert_main() {
const Interface inner_result = vert_main_inner();
tint_symbol wrapper_result = (tint_symbol)0;
wrapper_result.pos = inner_result.pos;
wrapper_result.col1 = inner_result.col1;
wrapper_result.col2 = inner_result.col2;
return wrapper_result;
}
struct tint_symbol_2 {
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
float4 pos : SV_Position;
};
void frag_main(tint_symbol_3 tint_symbol_2) {
const Interface inputs = {tint_symbol_2.pos, tint_symbol_2.col1, tint_symbol_2.col2};
void frag_main_inner(Interface inputs) {
const float r = inputs.col1;
const float g = inputs.col2;
const float4 p = inputs.pos;
}
void frag_main(tint_symbol_2 tint_symbol_1) {
const Interface tint_symbol_4 = {tint_symbol_1.pos, tint_symbol_1.col1, tint_symbol_1.col2};
frag_main_inner(tint_symbol_4);
return;
}
)");
@@ -278,28 +300,38 @@ TEST_F(HlslGeneratorImplTest_Function,
};
VertexOutput foo(float x) {
const VertexOutput tint_symbol_4 = {float4(x, x, x, 1.0f)};
return tint_symbol_4;
const VertexOutput tint_symbol_2 = {float4(x, x, x, 1.0f)};
return tint_symbol_2;
}
struct tint_symbol {
float4 pos : SV_Position;
};
tint_symbol vert_main1() {
const VertexOutput tint_symbol_1 = foo(0.5f);
const tint_symbol tint_symbol_5 = {tint_symbol_1.pos};
return tint_symbol_5;
VertexOutput vert_main1_inner() {
return foo(0.5f);
}
struct tint_symbol_2 {
tint_symbol vert_main1() {
const VertexOutput inner_result = vert_main1_inner();
tint_symbol wrapper_result = (tint_symbol)0;
wrapper_result.pos = inner_result.pos;
return wrapper_result;
}
struct tint_symbol_1 {
float4 pos : SV_Position;
};
tint_symbol_2 vert_main2() {
const VertexOutput tint_symbol_3 = foo(0.25f);
const tint_symbol_2 tint_symbol_6 = {tint_symbol_3.pos};
return tint_symbol_6;
VertexOutput vert_main2_inner() {
return foo(0.25f);
}
tint_symbol_1 vert_main2() {
const VertexOutput inner_result_1 = vert_main2_inner();
tint_symbol_1 wrapper_result_1 = (tint_symbol_1)0;
wrapper_result_1.pos = inner_result_1.pos;
return wrapper_result_1;
}
)");
}

View File

@@ -111,10 +111,15 @@ struct tint_symbol_2 {
float value [[color(1)]];
};
float frag_main_inner(float foo) {
return foo;
}
fragment tint_symbol_2 frag_main(tint_symbol_1 tint_symbol [[stage_in]]) {
float const foo = tint_symbol.foo;
tint_symbol_2 const tint_symbol_3 = {.value=foo};
return tint_symbol_3;
float const inner_result = frag_main_inner(tint_symbol.foo);
tint_symbol_2 wrapper_result = {};
wrapper_result.value = inner_result;
return wrapper_result;
}
)");
@@ -137,13 +142,19 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOut_Builtins) {
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol_1 {
struct tint_symbol {
float value [[depth(any)]];
};
fragment tint_symbol_1 frag_main(float4 coord [[position]]) {
tint_symbol_1 const tint_symbol_2 = {.value=coord.x};
return tint_symbol_2;
float frag_main_inner(float4 coord) {
return coord.x;
}
fragment tint_symbol frag_main(float4 coord [[position]]) {
float const inner_result = frag_main_inner(coord);
tint_symbol wrapper_result = {};
wrapper_result.value = inner_result;
return wrapper_result;
}
)");
@@ -201,21 +212,33 @@ struct tint_symbol {
float col2 [[user(locn2)]];
float4 pos [[position]];
};
struct tint_symbol_4 {
struct tint_symbol_2 {
float col1 [[user(locn1)]];
float col2 [[user(locn2)]];
};
vertex tint_symbol vert_main() {
Interface const tint_symbol_1 = {.col1=0.5f, .col2=0.25f, .pos=float4()};
tint_symbol const tint_symbol_5 = {.col1=tint_symbol_1.col1, .col2=tint_symbol_1.col2, .pos=tint_symbol_1.pos};
return tint_symbol_5;
Interface vert_main_inner() {
Interface const tint_symbol_3 = {.col1=0.5f, .col2=0.25f, .pos=float4()};
return tint_symbol_3;
}
fragment void frag_main(float4 tint_symbol_3 [[position]], tint_symbol_4 tint_symbol_2 [[stage_in]]) {
Interface const colors = {.col1=tint_symbol_2.col1, .col2=tint_symbol_2.col2, .pos=tint_symbol_3};
vertex tint_symbol vert_main() {
Interface const inner_result = vert_main_inner();
tint_symbol wrapper_result = {};
wrapper_result.col1 = inner_result.col1;
wrapper_result.col2 = inner_result.col2;
wrapper_result.pos = inner_result.pos;
return wrapper_result;
}
void frag_main_inner(Interface colors) {
float const r = colors.col1;
float const g = colors.col2;
}
fragment void frag_main(float4 pos [[position]], tint_symbol_2 tint_symbol_1 [[stage_in]]) {
Interface const tint_symbol_4 = {.col1=tint_symbol_1.col1, .col2=tint_symbol_1.col2, .pos=pos};
frag_main_inner(tint_symbol_4);
return;
}
@@ -265,25 +288,35 @@ struct VertexOutput {
struct tint_symbol {
float4 pos [[position]];
};
struct tint_symbol_2 {
struct tint_symbol_1 {
float4 pos [[position]];
};
VertexOutput foo(float x) {
VertexOutput const tint_symbol_4 = {.pos=float4(x, x, x, 1.0f)};
return tint_symbol_4;
VertexOutput const tint_symbol_2 = {.pos=float4(x, x, x, 1.0f)};
return tint_symbol_2;
}
VertexOutput vert_main1_inner() {
return foo(0.5f);
}
vertex tint_symbol vert_main1() {
VertexOutput const tint_symbol_1 = foo(0.5f);
tint_symbol const tint_symbol_5 = {.pos=tint_symbol_1.pos};
return tint_symbol_5;
VertexOutput const inner_result = vert_main1_inner();
tint_symbol wrapper_result = {};
wrapper_result.pos = inner_result.pos;
return wrapper_result;
}
vertex tint_symbol_2 vert_main2() {
VertexOutput const tint_symbol_3 = foo(0.25f);
tint_symbol_2 const tint_symbol_6 = {.pos=tint_symbol_3.pos};
return tint_symbol_6;
VertexOutput vert_main2_inner() {
return foo(0.25f);
}
vertex tint_symbol_1 vert_main2() {
VertexOutput const inner_result_1 = vert_main2_inner();
tint_symbol_1 wrapper_result_1 = {};
wrapper_result_1.pos = inner_result_1.pos;
return wrapper_result_1;
}
)");