spirv: Use generic transform to process shader IO

The refactored CanonicalizeEntryPointIO transform makes it much easier
to handle SPIR-V style IO as well, and doing this removes a lot of
duplicated code. Remove all of the SPIR-V transform code for shader IO
and vertex point size.

Bug: tint:920
Change-Id: Id1b97517619b4d2fd09b45d5aee848259f3dfa77
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/60840
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price
2021-08-05 17:34:19 +00:00
committed by Tint LUCI CQ
parent 11e172ab64
commit 11c6fcdb51
1981 changed files with 55207 additions and 57171 deletions

View File

@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "src/ast/disable_validation_decoration.h"
#include "src/program_builder.h"
#include "src/sem/function.h"
@@ -148,11 +149,30 @@ struct CanonicalizeEntryPointIO::State {
ast::Expression* AddInput(std::string name,
ast::Type* type,
ast::DecorationList attributes) {
if (cfg.builtin_style == BuiltinStyle::kParameter &&
ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
// If this input is a builtin and we are emitting those as parameters,
// then add it to the parameter list and pass it directly to the inner
// function.
if (cfg.shader_style == ShaderStyle::kSpirv) {
// Vulkan requires that integer user-defined fragment inputs are
// always decorated with `Flat`.
if (type->is_integer_scalar_or_vector() &&
ast::HasDecoration<ast::LocationDecoration>(attributes) &&
func_ast->pipeline_stage() == ast::PipelineStage::kFragment) {
attributes.push_back(ctx.dst->Interpolate(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
}
// Disable validation for use of the `input` storage class.
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass));
// Create the global variable and use its value for the shader input.
auto var = ctx.dst->Symbols().New(name);
ctx.dst->Global(var, type, ast::StorageClass::kInput,
std::move(attributes));
return ctx.dst->Expr(var);
} else if (cfg.shader_style == ShaderStyle::kMsl &&
ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
// If this input is a builtin and we are targeting MSL, then add it to the
// parameter list and pass it directly to the inner function.
wrapper_ep_parameters.push_back(
ctx.dst->Param(name, type, std::move(attributes)));
return ctx.dst->Expr(name);
@@ -173,6 +193,16 @@ struct CanonicalizeEntryPointIO::State {
ast::Type* type,
ast::DecorationList attributes,
ast::Expression* value) {
// Vulkan requires that integer user-defined vertex outputs are
// always decorated with `Flat`.
if (cfg.shader_style == ShaderStyle::kSpirv &&
type->is_integer_scalar_or_vector() &&
ast::HasDecoration<ast::LocationDecoration>(attributes) &&
func_ast->pipeline_stage() == ast::PipelineStage::kVertex) {
attributes.push_back(ctx.dst->Interpolate(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
}
OutputValue output;
output.name = name;
output.type = type;
@@ -359,6 +389,23 @@ struct CanonicalizeEntryPointIO::State {
return out_struct;
}
/// Create and assign the wrapper function's output variables.
void CreateOutputVariables() {
for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` storage class.
ast::DecorationList attributes = std::move(outval.attributes);
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass));
// Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name);
ctx.dst->Global(name, outval.type, ast::StorageClass::kOutput,
std::move(attributes));
wrapper_body.push_back(ctx.dst->Assign(name, outval.value));
}
}
// Recreate the original function without entry point attributes and call it.
/// @returns the inner function call expression
ast::CallExpression* CallInnerFunction() {
@@ -450,10 +497,14 @@ struct CanonicalizeEntryPointIO::State {
// Produce the entry point outputs, if necessary.
if (!wrapper_output_values.empty()) {
auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] {
return ctx.dst->ty.type_name(output_struct->name());
};
if (cfg.shader_style == ShaderStyle::kSpirv) {
CreateOutputVariables();
} else {
auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] {
return ctx.dst->ty.type_name(output_struct->name());
};
}
}
// Create the wrapper entry point function.
@@ -505,10 +556,10 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
ctx.Clone();
}
CanonicalizeEntryPointIO::Config::Config(BuiltinStyle builtins,
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
uint32_t sample_mask,
bool emit_point_size)
: builtin_style(builtins),
: shader_style(style),
fixed_sample_mask(sample_mask),
emit_vertex_point_size(emit_point_size) {}

View File

@@ -24,9 +24,9 @@ namespace transform {
/// interfaces into a form that the generators can handle. Each entry point
/// function is stripped of all shader IO attributes and wrapped in a function
/// that provides the shader interface.
/// The transform config determines how shader IO parameters will be exposed.
/// Entry point return values are always produced as a structure, and optionally
/// include additional builtins as per the transform config.
/// The transform config determines whether to use global variables, structures,
/// or parameters for the shader inputs and outputs, and optionally adds
/// additional builtins to the shader interface.
///
/// Before:
/// ```
@@ -83,21 +83,23 @@ namespace transform {
class CanonicalizeEntryPointIO
: public Castable<CanonicalizeEntryPointIO, Transform> {
public:
/// BuiltinStyle is an enumerator of different ways to emit builtins.
enum class BuiltinStyle {
/// Use non-struct function parameters for all builtins.
kParameter,
/// Use struct members for all builtins.
kStructMember,
/// ShaderStyle is an enumerator of different ways to emit shader IO.
enum class ShaderStyle {
/// Target SPIR-V (using global variables).
kSpirv,
/// Target MSL (using non-struct function parameters for builtins).
kMsl,
/// Target HLSL (using structures for all IO).
kHlsl,
};
/// Configuration options for the transform.
struct Config : public Castable<Config, Data> {
/// Constructor
/// @param builtins the approach to use for emitting builtins.
/// @param style the approach to use for emitting shader IO.
/// @param sample_mask an optional sample mask to combine with shader masks
/// @param emit_vertex_point_size `true` to generate a pointsize builtin
explicit Config(BuiltinStyle builtins,
explicit Config(ShaderStyle style,
uint32_t sample_mask = 0xFFFFFFFF,
bool emit_vertex_point_size = false);
@@ -107,8 +109,8 @@ class CanonicalizeEntryPointIO
/// Destructor
~Config() override;
/// The approach to use for emitting builtins.
BuiltinStyle const builtin_style;
/// The approach to use for emitting shader IO.
ShaderStyle const shader_style;
/// A fixed sample mask to combine into masks produced by fragment shaders.
uint32_t const fixed_sample_mask;

File diff suppressed because it is too large Load Diff

View File

@@ -73,7 +73,7 @@ Output Hlsl::Run(const Program* in, const DataMap& inputs) {
manager.Add<PadArrayElements>();
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::BuiltinStyle::kStructMember);
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
auto out = manager.Run(in, data);
if (!out.program.IsValid()) {
return out;

View File

@@ -64,7 +64,7 @@ Output Msl::Run(const Program* in, const DataMap& inputs) {
auto array_length_from_uniform_cfg = ArrayLengthFromUniform::Config(
sem::BindingPoint{0, buffer_size_ubo_index});
auto entry_point_io_cfg = CanonicalizeEntryPointIO::Config(
CanonicalizeEntryPointIO::BuiltinStyle::kParameter, fixed_sample_mask,
CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask,
emit_point_size);
// Use the SSBO binding numbers as the indices for the buffer size lookups.

View File

@@ -17,16 +17,10 @@
#include <string>
#include <utility>
#include "src/ast/call_statement.h"
#include "src/ast/disable_validation_decoration.h"
#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/program_builder.h"
#include "src/sem/block_statement.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/external_texture_transform.h"
#include "src/transform/fold_constants.h"
#include "src/transform/for_loop_to_loop.h"
@@ -48,6 +42,7 @@ Output Spirv::Run(const Program* in, const DataMap& data) {
auto* cfg = data.Get<Config>();
Manager manager;
DataMap internal_inputs;
if (!cfg || !cfg->disable_workgroup_init) {
manager.Add<ZeroInitWorkgroupMemory>();
}
@@ -56,174 +51,27 @@ Output Spirv::Run(const Program* in, const DataMap& data) {
manager.Add<FoldConstants>();
manager.Add<ExternalTextureTransform>();
manager.Add<ForLoopToLoop>(); // Must come after ZeroInitWorkgroupMemory
auto transformedInput = manager.Run(in, data);
manager.Add<CanonicalizeEntryPointIO>();
internal_inputs.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::Config(
CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF,
(cfg && cfg->emit_vertex_point_size)));
auto transformedInput = manager.Run(in, internal_inputs);
if (transformedInput.program.Diagnostics().contains_errors()) {
return transformedInput;
}
ProgramBuilder out;
CloneContext ctx(&out, &transformedInput.program);
HandleEntryPointIOTypes(ctx);
ProgramBuilder builder;
CloneContext ctx(&builder, &transformedInput.program);
HandleSampleMaskBuiltins(ctx);
AddEmptyEntryPoint(ctx);
ctx.Clone();
// TODO(jrprice): Look into combining these transforms into a single clone.
Program tmp(std::move(out));
ProgramBuilder out2;
CloneContext ctx2(&out2, &tmp);
HandleSampleMaskBuiltins(ctx2);
AddEmptyEntryPoint(ctx2);
if (cfg && cfg->emit_vertex_point_size) {
EmitVertexPointSize(ctx2);
}
ctx2.Clone();
out2.SetTransformApplied(this);
return Output{Program(std::move(out2))};
}
void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
// Hoist entry point parameters, return values, and struct members out to
// global variables. Declare and construct struct parameters in the function
// body. Replace entry point return statements with calls to a function that
// assigns the return value to the global output variables.
//
// Before:
// ```
// struct FragmentInput {
// [[builtin(sample_index)]] sample_index : u32;
// [[builtin(sample_mask)]] sample_mask : u32;
// };
// struct FragmentOutput {
// [[builtin(frag_depth)]] depth: f32;
// [[builtin(sample_mask)]] mask_out : u32;
// };
//
// [[stage(fragment)]]
// fn frag_main(
// [[builtin(position)]] coord : vec4<f32>,
// samples : FragmentInput
// ) -> FragmentOutput {
// var output : FragmentOutput = FragmentOutput(1.0,
// samples.sample_mask);
// return output;
// }
// ```
//
// After:
// ```
// struct FragmentInput {
// sample_index : u32;
// sample_mask : u32;
// };
// struct FragmentOutput {
// depth: f32;
// mask_out : u32;
// };
//
// [[builtin(position)]] var<in> coord : vec4<f32>,
// [[builtin(sample_index)]] var<in> sample_index : u32,
// [[builtin(sample_mask)]] var<in> sample_mask : u32,
// [[builtin(frag_depth)]] var<out> depth: f32;
// [[builtin(sample_mask)]] var<out> mask_out : u32;
//
// fn frag_main_ret(retval : FragmentOutput) {
// depth = reval.depth;
// mask_out = retval.mask_out;
// }
//
// [[stage(fragment)]]
// fn frag_main() {
// let samples : FragmentInput(sample_index, sample_mask);
// var output : FragmentOutput = FragmentOutput(1.0,
// samples.sample_mask);
// frag_main_ret(output);
// return;
// }
// ```
// Strip entry point IO decorations from struct declarations.
for (auto* ty : ctx.src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
// Build new list of struct members without entry point IO decorations.
ast::StructMemberList new_struct_members;
for (auto* member : struct_ty->members()) {
ast::DecorationList new_decorations = RemoveDecorations(
ctx, member->decorations(), [](const ast::Decoration* deco) {
return deco->IsAnyOf<
ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
});
new_struct_members.push_back(
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
}
// Redeclare the struct.
auto new_struct_name = ctx.Clone(struct_ty->name());
auto* new_struct =
ctx.dst->create<ast::Struct>(new_struct_name, new_struct_members,
ctx.Clone(struct_ty->decorations()));
ctx.Replace(struct_ty, new_struct);
}
}
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
for (auto* param : func->Parameters()) {
Symbol new_var = HoistToInputVariables(
ctx, func_ast, param->Type(), param->Declaration()->type(),
param->Declaration()->decorations());
// Replace all uses of the function parameter with the new variable.
for (auto* user : param->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(new_var));
}
}
if (!func->ReturnType()->Is<sem::Void>()) {
ast::StatementList stores;
auto store_value_symbol = ctx.dst->Sym();
HoistToOutputVariables(
ctx, func_ast, func->ReturnType(), func_ast->return_type(),
func_ast->return_type_decorations(), {}, store_value_symbol, stores);
// Create a function that writes a return value to all output variables.
auto* store_value = ctx.dst->Param(store_value_symbol,
ctx.Clone(func_ast->return_type()));
auto return_func_symbol = ctx.dst->Sym();
auto* return_func = ctx.dst->create<ast::Function>(
return_func_symbol, ast::VariableList{store_value},
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
ast::DecorationList{}, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
return_func);
// Replace all return statements with calls to the output function.
for (auto* ret : func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
ctx.InsertBefore(ret_sem->Block()->Declaration()->statements(), ret,
ctx.dst->create<ast::CallStatement>(call));
ctx.Replace(ret, ctx.dst->Return());
}
}
// Rewrite the function header to remove the parameters and return value.
auto name = ctx.Clone(func_ast->symbol());
auto* body = ctx.Clone(func_ast->body());
auto decos = ctx.Clone(func_ast->decorations());
auto* new_func = ctx.dst->create<ast::Function>(
func_ast->source(), name, ast::VariableList{}, ctx.dst->ty.void_(),
body, decos, ast::DecorationList{});
ctx.Replace(func_ast, new_func);
}
builder.SetTransformApplied(this);
return Output{Program(std::move(builder))};
}
void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
@@ -274,31 +122,6 @@ void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
}
}
void Spirv::EmitVertexPointSize(CloneContext& ctx) const {
// No-op if there are no vertex stages in the module.
if (!ctx.src->AST().Functions().HasStage(ast::PipelineStage::kVertex)) {
return;
}
// Create a module-scope pointsize builtin output variable.
Symbol pointsize = ctx.dst->Symbols().New("tint_pointsize");
ctx.dst->Global(
pointsize, ctx.dst->ty.f32(), ast::StorageClass::kOutput,
ast::DecorationList{
ctx.dst->Builtin(ast::Builtin::kPointSize),
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass)});
// Assign 1.0 to the global at the start of all vertex shader entry points.
ctx.ReplaceAll([&ctx, pointsize](ast::Function* func) -> ast::Function* {
if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
ctx.InsertFront(func->body()->statements(),
ctx.dst->Assign(pointsize, 1.0f));
}
return nullptr;
});
}
void Spirv::AddEmptyEntryPoint(CloneContext& ctx) const {
for (auto* func : ctx.src->AST().Functions()) {
if (func->IsEntryPoint()) {
@@ -310,125 +133,6 @@ void Spirv::AddEmptyEntryPoint(CloneContext& ctx) const {
ctx.dst->WorkgroupSize(1)});
}
Symbol Spirv::HoistToInputVariables(
CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations) const {
if (!ty->Is<sem::Struct>()) {
// Base case: create a global variable and return.
ast::DecorationList new_decorations =
RemoveDecorations(ctx, decorations, [](const ast::Decoration* deco) {
return !deco->IsAnyOf<
ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
});
new_decorations.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass));
if (ty->is_integer_scalar_or_vector() &&
ast::HasDecoration<ast::LocationDecoration>(new_decorations) &&
func->pipeline_stage() == ast::PipelineStage::kFragment) {
// Vulkan requires that integer user-defined fragment inputs are
// always decorated with `Flat`.
new_decorations.push_back(ctx.dst->Interpolate(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
}
auto global_var_symbol = ctx.dst->Sym();
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
ast::StorageClass::kInput, nullptr, new_decorations);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
return global_var_symbol;
}
// Recurse into struct members and build the initializer list.
std::vector<Symbol> init_value_names;
auto* struct_ty = ty->As<sem::Struct>();
for (auto* member : struct_ty->Members()) {
auto member_var = HoistToInputVariables(
ctx, func, member->Type(), member->Declaration()->type(),
member->Declaration()->decorations());
init_value_names.emplace_back(member_var);
}
auto func_var_symbol = ctx.dst->Sym();
if (func->body()->empty()) {
// The return value should never get used.
return func_var_symbol;
}
ast::ExpressionList init_values;
for (auto name : init_value_names) {
init_values.push_back(ctx.dst->Expr(name));
}
// Create a function-scope variable for the struct.
auto* initializer = ctx.dst->Construct(ctx.Clone(declared_ty), init_values);
auto* func_var =
ctx.dst->Const(func_var_symbol, ctx.Clone(declared_ty), initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_var));
return func_var_symbol;
}
void Spirv::HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
ast::StatementList& stores) const {
// Base case.
if (!ty->Is<sem::Struct>()) {
// Create a global variable.
ast::DecorationList new_decorations =
RemoveDecorations(ctx, decorations, [](const ast::Decoration* deco) {
return !deco->IsAnyOf<
ast::BuiltinDecoration, ast::InterpolateDecoration,
ast::InvariantDecoration, ast::LocationDecoration>();
});
new_decorations.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass));
if (ty->is_integer_scalar_or_vector() &&
ast::HasDecoration<ast::LocationDecoration>(new_decorations) &&
func->pipeline_stage() == ast::PipelineStage::kVertex) {
// Vulkan requires that integer user-defined vertex outputs are
// always decorated with `Flat`.
new_decorations.push_back(ctx.dst->Interpolate(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
}
auto global_var_symbol = ctx.dst->Sym();
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
ast::StorageClass::kOutput, nullptr, new_decorations);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
// Create the assignment instruction.
ast::Expression* rhs = ctx.dst->Expr(store_value);
for (auto member : member_accesses) {
rhs = ctx.dst->MemberAccessor(rhs, member);
}
stores.push_back(ctx.dst->Assign(ctx.dst->Expr(global_var_symbol), rhs));
return;
}
// Recurse into struct members.
auto* struct_ty = ty->As<sem::Struct>();
for (auto* member : struct_ty->Members()) {
member_accesses.push_back(ctx.Clone(member->Declaration()->symbol()));
HoistToOutputVariables(ctx, func, member->Type(),
member->Declaration()->type(),
member->Declaration()->decorations(),
member_accesses, store_value, stores);
member_accesses.pop_back();
}
}
Spirv::Config::Config(bool emit_vps, bool disable_wi)
: emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {}

View File

@@ -69,46 +69,10 @@ class Spirv : public Castable<Spirv, Transform> {
Output Run(const Program* program, const DataMap& data = {}) override;
private:
/// Hoist entry point parameters, return values, and struct members out to
/// global variables.
void HandleEntryPointIOTypes(CloneContext& ctx) const;
/// Change type of sample mask builtin variables to single element arrays.
void HandleSampleMaskBuiltins(CloneContext& ctx) const;
/// Add a PointSize builtin output to the module and set it to 1.0 from all
/// vertex stage entry points.
void EmitVertexPointSize(CloneContext& ctx) const;
/// Add an empty shader entry point if none exist in the module.
void AddEmptyEntryPoint(CloneContext& ctx) const;
/// Recursively create module-scope input variables for `ty` and add
/// function-scope variables for structs to `func`.
///
/// For non-structures, create a module-scope input variable.
/// For structures, recurse into members and then create a function-scope
/// variable initialized using the variables created for its members.
/// Return the symbol for the variable that was created.
Symbol HoistToInputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations) const;
/// Recursively create module-scope output variables for `ty` and build a list
/// of assignment instructions to write to them from `store_value`.
///
/// For non-structures, create a module-scope output variable and generate the
/// assignment instruction.
/// For structures, recurse into members, tracking the chain of member
/// accessors.
/// Returns the list of variable assignments in `stores`.
void HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
ast::StatementList& stores) const;
};
} // namespace transform

View File

@@ -22,738 +22,6 @@ namespace {
using SpirvTest = TransformTest;
TEST_F(SpirvTest, HandleEntryPointIOTypes_Parameters) {
auto* src = R"(
[[stage(fragment)]]
fn frag_main([[builtin(position)]] coord : vec4<f32>,
[[location(1)]] loc1 : f32) {
var col : f32 = (coord.x * loc1);
}
[[stage(compute), workgroup_size(8, 1, 1)]]
fn compute_main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
[[builtin(local_invocation_index)]] local_index : u32) {
var id_x : u32 = local_id.x;
}
)";
auto* expect = R"(
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : vec4<f32>;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_1 : f32;
[[stage(fragment)]]
fn frag_main() {
var col : f32 = (tint_symbol.x * tint_symbol_1);
}
[[builtin(local_invocation_id), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_2 : vec3<u32>;
[[builtin(local_invocation_index), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_3 : u32;
[[stage(compute), workgroup_size(8, 1, 1)]]
fn compute_main() {
var id_x : u32 = tint_symbol_2.x;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_Parameter_TypeAlias) {
auto* src = R"(
type myf32 = f32;
[[stage(fragment)]]
fn frag_main([[location(1)]] loc1 : myf32) {
}
)";
auto* expect = R"(
type myf32 = f32;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : myf32;
[[stage(fragment)]]
fn frag_main() {
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnBuiltin) {
auto* src = R"(
[[stage(vertex)]]
fn vert_main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(1.0, 2.0, 3.0, 0.0);
}
)";
auto* expect = R"(
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_2(tint_symbol : vec4<f32>) {
tint_symbol_1 = tint_symbol;
}
[[stage(vertex)]]
fn vert_main() {
tint_symbol_2(vec4<f32>(1.0, 2.0, 3.0, 0.0));
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation) {
auto* src = R"(
[[stage(fragment)]]
fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
if (loc_in > 10u) {
return 0.5;
}
return 1.0;
}
)";
auto* expect = R"(
[[location(0), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<in> tint_symbol : u32;
[[location(0), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_2 : f32;
fn tint_symbol_3(tint_symbol_1 : f32) {
tint_symbol_2 = tint_symbol_1;
}
[[stage(fragment)]]
fn frag_main() {
if ((tint_symbol > 10u)) {
tint_symbol_3(0.5);
return;
}
tint_symbol_3(1.0);
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation_TypeAlias) {
auto* src = R"(
type myf32 = f32;
[[stage(fragment)]]
fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
if (loc_in > 10u) {
return 0.5;
}
return 1.0;
}
)";
auto* expect = R"(
type myf32 = f32;
[[location(0), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<in> tint_symbol : u32;
[[location(0), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_2 : myf32;
fn tint_symbol_3(tint_symbol_1 : myf32) {
tint_symbol_2 = tint_symbol_1;
}
[[stage(fragment)]]
fn frag_main() {
if ((tint_symbol > 10u)) {
tint_symbol_3(0.5);
return;
}
tint_symbol_3(1.0);
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_StructParameters) {
auto* src = R"(
struct FragmentInput {
[[builtin(position)]] coord : vec4<f32>;
[[location(1)]] value : f32;
};
[[stage(fragment)]]
fn frag_main(inputs : FragmentInput) {
var col : f32 = inputs.coord.x * inputs.value;
}
)";
auto* expect = R"(
struct FragmentInput {
coord : vec4<f32>;
value : f32;
};
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : vec4<f32>;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_1 : f32;
[[stage(fragment)]]
fn frag_main() {
let tint_symbol_2 : FragmentInput = FragmentInput(tint_symbol, tint_symbol_1);
var col : f32 = (tint_symbol_2.coord.x * tint_symbol_2.value);
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_StructParameters_EmptyBody) {
auto* src = R"(
struct FragmentInput {
[[location(1)]] value : f32;
};
[[stage(fragment)]]
fn frag_main(inputs : FragmentInput) {
}
)";
auto* expect = R"(
struct FragmentInput {
value : f32;
};
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : f32;
[[stage(fragment)]]
fn frag_main() {
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnStruct) {
auto* src = R"(
struct VertexOutput {
[[builtin(position)]] pos : vec4<f32>;
[[location(1)]] value : f32;
};
[[stage(vertex)]]
fn vert_main() -> VertexOutput {
if (false) {
return VertexOutput();
}
var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
return VertexOutput(pos, 2.0);
}
)";
auto* expect = R"(
struct VertexOutput {
pos : vec4<f32>;
value : f32;
};
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_2 : f32;
fn tint_symbol_3(tint_symbol : VertexOutput) {
tint_symbol_1 = tint_symbol.pos;
tint_symbol_2 = tint_symbol.value;
}
[[stage(vertex)]]
fn vert_main() {
if (false) {
tint_symbol_3(VertexOutput());
return;
}
var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
tint_symbol_3(VertexOutput(pos, 2.0));
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_SharedStruct_SameShader) {
auto* src = R"(
struct Interface {
[[location(1)]] value : f32;
};
[[stage(fragment)]]
fn frag_main(inputs : Interface) -> Interface {
return inputs;
}
)";
auto* expect = R"(
struct Interface {
value : f32;
};
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : f32;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_3 : f32;
fn tint_symbol_4(tint_symbol_2 : Interface) {
tint_symbol_3 = tint_symbol_2.value;
}
[[stage(fragment)]]
fn frag_main() {
let tint_symbol_1 : Interface = Interface(tint_symbol);
tint_symbol_4(tint_symbol_1);
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_SharedStruct_DifferentShaders) {
auto* src = R"(
struct Interface {
[[builtin(position)]] pos : vec4<f32>;
[[location(1)]] value : f32;
};
[[stage(vertex)]]
fn vert_main() -> Interface {
return Interface(vec4<f32>(), 42.0);
}
[[stage(fragment)]]
fn frag_main(inputs : Interface) {
var x : f32 = inputs.value;
}
)";
auto* expect = R"(
struct Interface {
pos : vec4<f32>;
value : f32;
};
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_2 : f32;
fn tint_symbol_3(tint_symbol : Interface) {
tint_symbol_1 = tint_symbol.pos;
tint_symbol_2 = tint_symbol.value;
}
[[stage(vertex)]]
fn vert_main() {
tint_symbol_3(Interface(vec4<f32>(), 42.0));
return;
}
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_4 : vec4<f32>;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_5 : f32;
[[stage(fragment)]]
fn frag_main() {
let tint_symbol_6 : Interface = Interface(tint_symbol_4, tint_symbol_5);
var x : f32 = tint_symbol_6.value;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_InterpolateAttributes) {
auto* src = R"(
struct VertexOut {
[[builtin(position)]] pos : vec4<f32>;
[[location(1), interpolate(flat)]] loc1: f32;
[[location(2), interpolate(linear, sample)]] loc2 : f32;
[[location(3), interpolate(perspective, centroid)]] loc3 : f32;
};
struct FragmentIn {
[[location(1), interpolate(flat)]] loc1: f32;
[[location(2), interpolate(linear, sample)]] loc2 : f32;
};
[[stage(vertex)]]
fn vert_main() -> VertexOut {
return VertexOut();
}
[[stage(fragment)]]
fn frag_main(inputs : FragmentIn,
[[location(3), interpolate(perspective, centroid)]] loc3 : f32) {
let x = inputs.loc1 + inputs.loc2 + loc3;
}
)";
auto* expect = R"(
struct VertexOut {
pos : vec4<f32>;
loc1 : f32;
loc2 : f32;
loc3 : f32;
};
struct FragmentIn {
loc1 : f32;
loc2 : f32;
};
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
[[location(1), interpolate(flat), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_2 : f32;
[[location(2), interpolate(linear, sample), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_3 : f32;
[[location(3), interpolate(perspective, centroid), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_4 : f32;
fn tint_symbol_5(tint_symbol : VertexOut) {
tint_symbol_1 = tint_symbol.pos;
tint_symbol_2 = tint_symbol.loc1;
tint_symbol_3 = tint_symbol.loc2;
tint_symbol_4 = tint_symbol.loc3;
}
[[stage(vertex)]]
fn vert_main() {
tint_symbol_5(VertexOut());
return;
}
[[location(1), interpolate(flat), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_6 : f32;
[[location(2), interpolate(linear, sample), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_7 : f32;
[[location(3), interpolate(perspective, centroid), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_9 : f32;
[[stage(fragment)]]
fn frag_main() {
let tint_symbol_8 : FragmentIn = FragmentIn(tint_symbol_6, tint_symbol_7);
let x = ((tint_symbol_8.loc1 + tint_symbol_8.loc2) + tint_symbol_9);
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_InterpolateAttributes_Integers) {
// Test that we add a Flat attribute to integers that are vertex outputs and
// fragment inputs, but not vertex inputs or fragment outputs.
auto* src = R"(
struct VertexIn {
[[location(0)]] i : i32;
[[location(1)]] u : u32;
[[location(2)]] vi : vec4<i32>;
[[location(3)]] vu : vec4<u32>;
};
struct VertexOut {
[[location(0)]] i : i32;
[[location(1)]] u : u32;
[[location(2)]] vi : vec4<i32>;
[[location(3)]] vu : vec4<u32>;
[[builtin(position)]] pos : vec4<f32>;
};
struct FragmentInterface {
[[location(0)]] i : i32;
[[location(1)]] u : u32;
[[location(2)]] vi : vec4<i32>;
[[location(3)]] vu : vec4<u32>;
};
[[stage(vertex)]]
fn vert_main(in : VertexIn) -> VertexOut {
return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
}
[[stage(fragment)]]
fn frag_main(inputs : FragmentInterface) -> FragmentInterface {
return inputs;
}
)";
auto* expect = R"(
struct VertexIn {
i : i32;
u : u32;
vi : vec4<i32>;
vu : vec4<u32>;
};
struct VertexOut {
i : i32;
u : u32;
vi : vec4<i32>;
vu : vec4<u32>;
pos : vec4<f32>;
};
struct FragmentInterface {
i : i32;
u : u32;
vi : vec4<i32>;
vu : vec4<u32>;
};
[[location(0), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : i32;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_1 : u32;
[[location(2), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_2 : vec4<i32>;
[[location(3), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_3 : vec4<u32>;
[[location(0), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<out> tint_symbol_6 : i32;
[[location(1), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<out> tint_symbol_7 : u32;
[[location(2), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<out> tint_symbol_8 : vec4<i32>;
[[location(3), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<out> tint_symbol_9 : vec4<u32>;
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_10 : vec4<f32>;
fn tint_symbol_11(tint_symbol_5 : VertexOut) {
tint_symbol_6 = tint_symbol_5.i;
tint_symbol_7 = tint_symbol_5.u;
tint_symbol_8 = tint_symbol_5.vi;
tint_symbol_9 = tint_symbol_5.vu;
tint_symbol_10 = tint_symbol_5.pos;
}
[[stage(vertex)]]
fn vert_main() {
let tint_symbol_4 : VertexIn = VertexIn(tint_symbol, tint_symbol_1, tint_symbol_2, tint_symbol_3);
tint_symbol_11(VertexOut(tint_symbol_4.i, tint_symbol_4.u, tint_symbol_4.vi, tint_symbol_4.vu, vec4<f32>()));
return;
}
[[location(0), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<in> tint_symbol_12 : i32;
[[location(1), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<in> tint_symbol_13 : u32;
[[location(2), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<in> tint_symbol_14 : vec4<i32>;
[[location(3), internal(disable_validation__ignore_storage_class), interpolate(flat)]] var<in> tint_symbol_15 : vec4<u32>;
[[location(0), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_18 : i32;
[[location(1), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_19 : u32;
[[location(2), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_20 : vec4<i32>;
[[location(3), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_21 : vec4<u32>;
fn tint_symbol_22(tint_symbol_17 : FragmentInterface) {
tint_symbol_18 = tint_symbol_17.i;
tint_symbol_19 = tint_symbol_17.u;
tint_symbol_20 = tint_symbol_17.vi;
tint_symbol_21 = tint_symbol_17.vu;
}
[[stage(fragment)]]
fn frag_main() {
let tint_symbol_16 : FragmentInterface = FragmentInterface(tint_symbol_12, tint_symbol_13, tint_symbol_14, tint_symbol_15);
tint_symbol_22(tint_symbol_16);
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_InvariantAttributes) {
auto* src = R"(
struct VertexOut {
[[builtin(position), invariant]] pos : vec4<f32>;
};
[[stage(vertex)]]
fn main1() -> VertexOut {
return VertexOut();
}
[[stage(vertex)]]
fn main2() -> [[builtin(position), invariant]] vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = R"(
struct VertexOut {
pos : vec4<f32>;
};
[[builtin(position), invariant, internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_2(tint_symbol : VertexOut) {
tint_symbol_1 = tint_symbol.pos;
}
[[stage(vertex)]]
fn main1() {
tint_symbol_2(VertexOut());
return;
}
[[builtin(position), invariant, internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_4 : vec4<f32>;
fn tint_symbol_5(tint_symbol_3 : vec4<f32>) {
tint_symbol_4 = tint_symbol_3;
}
[[stage(vertex)]]
fn main2() {
tint_symbol_5(vec4<f32>());
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_StructLayoutDecorations) {
auto* src = R"(
[[block]]
struct FragmentInput {
[[size(16), location(1)]] value : f32;
[[builtin(position)]] [[align(32)]] coord : vec4<f32>;
[[location(0), interpolate(linear, sample)]] [[align(128)]] loc0 : f32;
};
struct FragmentOutput {
[[size(16), location(1), interpolate(flat)]] value : f32;
};
[[stage(fragment)]]
fn frag_main(inputs : FragmentInput) -> FragmentOutput {
return FragmentOutput(inputs.coord.x * inputs.value + inputs.loc0);
}
)";
auto* expect = R"(
[[block]]
struct FragmentInput {
[[size(16)]]
value : f32;
[[align(32)]]
coord : vec4<f32>;
[[align(128)]]
loc0 : f32;
};
struct FragmentOutput {
[[size(16)]]
value : f32;
};
[[location(1), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : f32;
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_1 : vec4<f32>;
[[location(0), interpolate(linear, sample), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_2 : f32;
[[location(1), interpolate(flat), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_5 : f32;
fn tint_symbol_6(tint_symbol_4 : FragmentOutput) {
tint_symbol_5 = tint_symbol_4.value;
}
[[stage(fragment)]]
fn frag_main() {
let tint_symbol_3 : FragmentInput = FragmentInput(tint_symbol, tint_symbol_1, tint_symbol_2);
tint_symbol_6(FragmentOutput(((tint_symbol_3.coord.x * tint_symbol_3.value) + tint_symbol_3.loc0)));
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleEntryPointIOTypes_WithPrivateGlobalVariable) {
// Test with a global variable to ensure that symbols are cloned correctly.
// crbug.com/tint/701
auto* src = R"(
var<private> x : f32;
struct VertexOutput {
[[builtin(position)]] Position : vec4<f32>;
};
[[stage(vertex)]]
fn main() -> VertexOutput {
return VertexOutput(vec4<f32>());
}
)";
auto* expect = R"(
var<private> x : f32;
struct VertexOutput {
Position : vec4<f32>;
};
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_2(tint_symbol : VertexOutput) {
tint_symbol_1 = tint_symbol.Position;
}
[[stage(vertex)]]
fn main() {
tint_symbol_2(VertexOutput(vec4<f32>()));
return;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
auto* src = R"(
[[stage(fragment)]]
@@ -765,20 +33,20 @@ fn main([[builtin(sample_index)]] sample_index : u32,
)";
auto* expect = R"(
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : u32;
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_1 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_3 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1>;
fn tint_symbol_4(tint_symbol_2 : u32) {
tint_symbol_3[0] = tint_symbol_2;
fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn main() {
tint_symbol_4(tint_symbol_1[0]);
return;
let inner_result = main_inner(sample_index_1, mask_in_1[0]);
value[0] = inner_result;
}
)";
@@ -805,6 +73,10 @@ fn main([[builtin(sample_mask)]] mask_in : u32
)";
auto* expect = R"(
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1>;
fn filter(mask : u32) -> u32 {
return (mask & 3u);
}
@@ -813,18 +85,14 @@ fn set_mask(input : u32) -> u32 {
return input;
}
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_2 : array<u32, 1>;
fn tint_symbol_3(tint_symbol_1 : u32) {
tint_symbol_2[0] = tint_symbol_1;
fn main_inner(mask_in : u32) -> u32 {
return set_mask(filter(mask_in));
}
[[stage(fragment)]]
fn main() {
tint_symbol_3(set_mask(filter(tint_symbol[0])));
return;
let inner_result = main_inner(mask_in_1[0]);
value[0] = inner_result;
}
)";
@@ -833,128 +101,6 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, EmitVertexPointSize_Basic) {
auto* src = R"(
fn non_entry_point() {
}
[[stage(vertex)]]
fn main() -> [[builtin(position)]] vec4<f32> {
non_entry_point();
return vec4<f32>();
}
)";
auto* expect = R"(
[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> tint_pointsize : f32;
fn non_entry_point() {
}
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_2(tint_symbol : vec4<f32>) {
tint_symbol_1 = tint_symbol;
}
[[stage(vertex)]]
fn main() {
tint_pointsize = 1.0;
non_entry_point();
tint_symbol_2(vec4<f32>());
return;
}
)";
DataMap data;
data.Add<Spirv::Config>(true);
auto got = Run<Spirv>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, EmitVertexPointSize_MultipleVertexShaders) {
auto* src = R"(
[[stage(vertex)]]
fn main1() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn main2() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn main3() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = R"(
[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> tint_pointsize : f32;
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_2(tint_symbol : vec4<f32>) {
tint_symbol_1 = tint_symbol;
}
[[stage(vertex)]]
fn main1() {
tint_pointsize = 1.0;
tint_symbol_2(vec4<f32>());
return;
}
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_4 : vec4<f32>;
fn tint_symbol_5(tint_symbol_3 : vec4<f32>) {
tint_symbol_4 = tint_symbol_3;
}
[[stage(vertex)]]
fn main2() {
tint_pointsize = 1.0;
tint_symbol_5(vec4<f32>());
return;
}
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_7 : vec4<f32>;
fn tint_symbol_8(tint_symbol_6 : vec4<f32>) {
tint_symbol_7 = tint_symbol_6;
}
[[stage(vertex)]]
fn main3() {
tint_pointsize = 1.0;
tint_symbol_8(vec4<f32>());
return;
}
)";
DataMap data;
data.Add<Spirv::Config>(true);
auto got = Run<Spirv>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, EmitVertexPointSize_NoVertexShaders) {
auto* src = R"(
[[stage(compute), workgroup_size(8, 1, 1)]]
fn main() {
}
)";
DataMap data;
data.Add<Spirv::Config>(true);
auto got = Run<Spirv>(src, data);
EXPECT_EQ(src, str(got));
}
TEST_F(SpirvTest, AddEmptyEntryPoint) {
auto* src = R"()";
@@ -986,35 +132,35 @@ fn frag_main([[builtin(sample_index)]] sample_index : u32,
)";
auto* expect = R"(
[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> tint_pointsize : f32;
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> value : vec4<f32>;
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_1 : vec4<f32>;
[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> vertex_point_size : f32;
fn tint_symbol_2(tint_symbol : vec4<f32>) {
tint_symbol_1 = tint_symbol;
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value_1 : array<u32, 1>;
fn vert_main_inner() -> vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn vert_main() {
tint_pointsize = 1.0;
tint_symbol_2(vec4<f32>());
return;
let inner_result = vert_main_inner();
value = inner_result;
vertex_point_size = 1.0;
}
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_3 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> tint_symbol_4 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> tint_symbol_6 : array<u32, 1>;
fn tint_symbol_7(tint_symbol_5 : u32) {
tint_symbol_6[0] = tint_symbol_5;
fn frag_main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn frag_main() {
tint_symbol_7(tint_symbol_4[0]);
return;
let inner_result_1 = frag_main_inner(sample_index_1, mask_in_1[0]);
value_1[0] = inner_result_1;
}
)";

View File

@@ -61,12 +61,15 @@ TEST_F(BuilderTest, EntryPoint_Parameters) {
// Input storage class, retaining their decorations.
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %9 "frag_main" %1 %5
OpExecutionMode %9 OriginUpperLeft
OpName %1 "tint_symbol"
OpName %5 "tint_symbol_1"
OpName %9 "frag_main"
OpName %17 "col"
OpEntryPoint Fragment %19 "frag_main" %1 %5
OpExecutionMode %19 OriginUpperLeft
OpName %1 "coord_1"
OpName %5 "loc1_1"
OpName %9 "frag_main_inner"
OpName %10 "coord"
OpName %11 "loc1"
OpName %15 "col"
OpName %19 "frag_main"
OpDecorate %1 BuiltIn FragCoord
OpDecorate %5 Location 1
%4 = OpTypeFloat 32
@@ -76,19 +79,25 @@ OpDecorate %5 Location 1
%6 = OpTypePointer Input %4
%5 = OpVariable %6 Input
%8 = OpTypeVoid
%7 = OpTypeFunction %8
%11 = OpTypeInt 32 0
%12 = OpConstant %11 0
%18 = OpTypePointer Function %4
%19 = OpConstantNull %4
%7 = OpTypeFunction %8 %3 %4
%16 = OpTypePointer Function %4
%17 = OpConstantNull %4
%18 = OpTypeFunction %8
%9 = OpFunction %8 None %7
%10 = OpLabel
%17 = OpVariable %18 Function %19
%13 = OpAccessChain %6 %1 %12
%14 = OpLoad %4 %13
%15 = OpLoad %4 %5
%16 = OpFMul %4 %14 %15
OpStore %17 %16
%10 = OpFunctionParameter %3
%11 = OpFunctionParameter %4
%12 = OpLabel
%15 = OpVariable %16 Function %17
%13 = OpCompositeExtract %4 %10 0
%14 = OpFMul %4 %13 %11
OpStore %15 %14
OpReturn
OpFunctionEnd
%19 = OpFunction %8 None %18
%20 = OpLabel
%22 = OpLoad %3 %1
%23 = OpLoad %4 %5
%21 = OpFunctionCall %8 %9 %22 %23
OpReturn
OpFunctionEnd
)");
@@ -125,13 +134,13 @@ TEST_F(BuilderTest, EntryPoint_ReturnValue) {
// Output storage class, and the return statements are replaced with stores.
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %14 "frag_main" %1 %4
OpExecutionMode %14 OriginUpperLeft
OpName %1 "tint_symbol"
OpName %4 "tint_symbol_2"
OpName %10 "tint_symbol_3"
OpName %11 "tint_symbol_1"
OpName %14 "frag_main"
OpEntryPoint Fragment %21 "frag_main" %1 %4
OpExecutionMode %21 OriginUpperLeft
OpName %1 "loc_in_1"
OpName %4 "value"
OpName %9 "frag_main_inner"
OpName %10 "loc_in"
OpName %21 "frag_main"
OpDecorate %1 Location 0
OpDecorate %1 Flat
OpDecorate %4 Location 0
@@ -142,30 +151,29 @@ OpDecorate %4 Location 0
%5 = OpTypePointer Output %6
%7 = OpConstantNull %6
%4 = OpVariable %5 Output %7
%9 = OpTypeVoid
%8 = OpTypeFunction %9 %6
%13 = OpTypeFunction %9
%17 = OpConstant %3 10
%19 = OpTypeBool
%23 = OpConstant %6 0.5
%25 = OpConstant %6 1
%10 = OpFunction %9 None %8
%11 = OpFunctionParameter %6
%12 = OpLabel
OpStore %4 %11
OpReturn
OpFunctionEnd
%14 = OpFunction %9 None %13
%8 = OpTypeFunction %6 %3
%12 = OpConstant %3 10
%14 = OpTypeBool
%17 = OpConstant %6 0.5
%18 = OpConstant %6 1
%20 = OpTypeVoid
%19 = OpTypeFunction %20
%9 = OpFunction %6 None %8
%10 = OpFunctionParameter %3
%11 = OpLabel
%13 = OpUGreaterThan %14 %10 %12
OpSelectionMerge %15 None
OpBranchConditional %13 %16 %15
%16 = OpLabel
OpReturnValue %17
%15 = OpLabel
%16 = OpLoad %3 %1
%18 = OpUGreaterThan %19 %16 %17
OpSelectionMerge %20 None
OpBranchConditional %18 %21 %20
%21 = OpLabel
%22 = OpFunctionCall %9 %10 %23
OpReturn
%20 = OpLabel
%24 = OpFunctionCall %9 %10 %25
OpReturnValue %18
OpFunctionEnd
%21 = OpFunction %20 None %19
%22 = OpLabel
%24 = OpLoad %3 %1
%23 = OpFunctionCall %6 %9 %24
OpStore %4 %23
OpReturn
OpFunctionEnd
)");
@@ -214,31 +222,30 @@ TEST_F(BuilderTest, EntryPoint_SharedStruct) {
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %23 "vert_main" %1 %5
OpEntryPoint Vertex %22 "vert_main" %1 %5
OpEntryPoint Fragment %32 "frag_main" %9 %11 %13
OpExecutionMode %32 OriginUpperLeft
OpExecutionMode %32 DepthReplacing
OpName %1 "tint_symbol_1"
OpName %5 "tint_symbol_2"
OpName %9 "tint_symbol_4"
OpName %11 "tint_symbol_5"
OpName %13 "tint_symbol_8"
OpName %16 "Interface"
OpMemberName %16 0 "value"
OpMemberName %16 1 "pos"
OpName %17 "tint_symbol_3"
OpName %18 "tint_symbol"
OpName %23 "vert_main"
OpName %29 "tint_symbol_9"
OpName %30 "tint_symbol_7"
OpName %1 "value_1"
OpName %5 "pos_1"
OpName %9 "value_2"
OpName %11 "pos_2"
OpName %13 "value_3"
OpName %15 "Interface"
OpMemberName %15 0 "value"
OpMemberName %15 1 "pos"
OpName %16 "vert_main_inner"
OpName %22 "vert_main"
OpName %28 "frag_main_inner"
OpName %29 "inputs"
OpName %32 "frag_main"
OpDecorate %1 Location 1
OpDecorate %5 BuiltIn Position
OpDecorate %9 Location 1
OpDecorate %11 BuiltIn FragCoord
OpDecorate %13 BuiltIn FragDepth
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 16
OpMemberDecorate %15 0 Offset 0
OpMemberDecorate %15 1 Offset 16
%3 = OpTypeFloat 32
%2 = OpTypePointer Output %3
%4 = OpConstantNull %3
@@ -252,40 +259,39 @@ OpMemberDecorate %16 1 Offset 16
%12 = OpTypePointer Input %7
%11 = OpVariable %12 Input
%13 = OpVariable %2 Output %4
%15 = OpTypeVoid
%16 = OpTypeStruct %3 %7
%14 = OpTypeFunction %15 %16
%22 = OpTypeFunction %15
%26 = OpConstant %3 42
%27 = OpConstantComposite %16 %26 %8
%28 = OpTypeFunction %15 %3
%17 = OpFunction %15 None %14
%18 = OpFunctionParameter %16
%19 = OpLabel
%20 = OpCompositeExtract %3 %18 0
OpStore %1 %20
%21 = OpCompositeExtract %7 %18 1
OpStore %5 %21
%15 = OpTypeStruct %3 %7
%14 = OpTypeFunction %15
%18 = OpConstant %3 42
%19 = OpConstantComposite %15 %18 %8
%21 = OpTypeVoid
%20 = OpTypeFunction %21
%27 = OpTypeFunction %3 %15
%16 = OpFunction %15 None %14
%17 = OpLabel
OpReturnValue %19
OpFunctionEnd
%22 = OpFunction %21 None %20
%23 = OpLabel
%24 = OpFunctionCall %15 %16
%25 = OpCompositeExtract %3 %24 0
OpStore %1 %25
%26 = OpCompositeExtract %7 %24 1
OpStore %5 %26
OpReturn
OpFunctionEnd
%23 = OpFunction %15 None %22
%24 = OpLabel
%25 = OpFunctionCall %15 %17 %27
OpReturn
%28 = OpFunction %3 None %27
%29 = OpFunctionParameter %15
%30 = OpLabel
%31 = OpCompositeExtract %3 %29 0
OpReturnValue %31
OpFunctionEnd
%29 = OpFunction %15 None %28
%30 = OpFunctionParameter %3
%31 = OpLabel
OpStore %13 %30
OpReturn
OpFunctionEnd
%32 = OpFunction %15 None %22
%32 = OpFunction %21 None %20
%33 = OpLabel
%34 = OpLoad %3 %9
%35 = OpLoad %7 %11
%36 = OpCompositeConstruct %16 %34 %35
%38 = OpCompositeExtract %3 %36 0
%37 = OpFunctionCall %15 %29 %38
%35 = OpLoad %3 %9
%36 = OpLoad %7 %11
%37 = OpCompositeConstruct %15 %35 %36
%34 = OpFunctionCall %3 %28 %37
OpStore %13 %34
OpReturn
OpFunctionEnd
)");