spirv: Handle sample_mask in shader IO transform

This is easy to do while we are processing builtins in the main
transform now that we use wrapper functions.

This is step towards removing the sanitizers completely.

Change-Id: If5472ce552e3cce1e5905916eeffa8fef90461c9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63585
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-09-07 18:59:21 +00:00
parent 1b9ed7de4a
commit 676ec7cf99
5 changed files with 67 additions and 197 deletions

View File

@ -67,6 +67,12 @@ bool IsShaderIODecoration(const ast::Decoration* deco) {
ast::InvariantDecoration, ast::LocationDecoration>(); ast::InvariantDecoration, ast::LocationDecoration>();
} }
// Returns true if `decos` contains a `sample_mask` builtin.
bool HasSampleMask(const ast::DecorationList& decos) {
auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos);
return builtin && builtin->value() == ast::Builtin::kSampleMask;
}
} // namespace } // namespace
/// State holds the current transform state for a single entry point. /// State holds the current transform state for a single entry point.
@ -166,9 +172,16 @@ struct CanonicalizeEntryPointIO::State {
// Create the global variable and use its value for the shader input. // Create the global variable and use its value for the shader input.
auto var = ctx.dst->Symbols().New(name); auto var = ctx.dst->Symbols().New(name);
ast::Expression* value = ctx.dst->Expr(var);
if (HasSampleMask(attributes)) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then load the first element.
type = ctx.dst->ty.array(type, 1);
value = ctx.dst->IndexAccessor(value, 0);
}
ctx.dst->Global(var, type, ast::StorageClass::kInput, ctx.dst->Global(var, type, ast::StorageClass::kInput,
std::move(attributes)); std::move(attributes));
return ctx.dst->Expr(var); return value;
} else if (cfg.shader_style == ShaderStyle::kMsl && } else if (cfg.shader_style == ShaderStyle::kMsl &&
ast::HasDecoration<ast::BuiltinDecoration>(attributes)) { ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
// If this input is a builtin and we are targeting MSL, then add it to the // If this input is a builtin and we are targeting MSL, then add it to the
@ -303,9 +316,7 @@ struct CanonicalizeEntryPointIO::State {
void AddFixedSampleMask() { void AddFixedSampleMask() {
// Check the existing output values for a sample mask builtin. // Check the existing output values for a sample mask builtin.
for (auto& outval : wrapper_output_values) { for (auto& outval : wrapper_output_values) {
auto* builtin = if (HasSampleMask(outval.attributes)) {
ast::GetDecoration<ast::BuiltinDecoration>(outval.attributes);
if (builtin && builtin->value() == ast::Builtin::kSampleMask) {
// Combine the authored sample mask with the fixed mask. // Combine the authored sample mask with the fixed mask.
outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask); outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
return; return;
@ -390,7 +401,7 @@ struct CanonicalizeEntryPointIO::State {
} }
/// Create and assign the wrapper function's output variables. /// Create and assign the wrapper function's output variables.
void CreateOutputVariables() { void CreateSpirvOutputVariables() {
for (auto& outval : wrapper_output_values) { for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` storage class. // Disable validation for use of the `output` storage class.
ast::DecorationList attributes = std::move(outval.attributes); ast::DecorationList attributes = std::move(outval.attributes);
@ -400,9 +411,17 @@ struct CanonicalizeEntryPointIO::State {
// Create the global variable and assign it the output value. // Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name); auto name = ctx.dst->Symbols().New(outval.name);
ctx.dst->Global(name, outval.type, ast::StorageClass::kOutput, auto* type = outval.type;
ast::Expression* lhs = ctx.dst->Expr(name);
if (HasSampleMask(attributes)) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then store to the first element.
type = ctx.dst->ty.array(type, 1);
lhs = ctx.dst->IndexAccessor(lhs, 0);
}
ctx.dst->Global(name, type, ast::StorageClass::kOutput,
std::move(attributes)); std::move(attributes));
wrapper_body.push_back(ctx.dst->Assign(name, outval.value)); wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
} }
} }
@ -498,7 +517,7 @@ struct CanonicalizeEntryPointIO::State {
// Produce the entry point outputs, if necessary. // Produce the entry point outputs, if necessary.
if (!wrapper_output_values.empty()) { if (!wrapper_output_values.empty()) {
if (cfg.shader_style == ShaderStyle::kSpirv) { if (cfg.shader_style == ShaderStyle::kSpirv) {
CreateOutputVariables(); CreateSpirvOutputVariables();
} else { } else {
auto* output_struct = CreateOutputStruct(); auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] { wrapper_ret_type = [&, output_struct] {

View File

@ -491,7 +491,7 @@ fn frag_main() -> FragOutput {
[[builtin(frag_depth), internal(disable_validation__ignore_storage_class)]] var<out> depth_1 : f32; [[builtin(frag_depth), internal(disable_validation__ignore_storage_class)]] var<out> depth_1 : f32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> mask_1 : u32; [[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> mask_1 : array<u32, 1>;
struct FragOutput { struct FragOutput {
color : vec4<f32>; color : vec4<f32>;
@ -512,7 +512,7 @@ fn frag_main() {
let inner_result = frag_main_inner(); let inner_result = frag_main_inner();
color_1 = inner_result.color; color_1 = inner_result.color;
depth_1 = inner_result.depth; depth_1 = inner_result.depth;
mask_1 = inner_result.mask; mask_1[0] = inner_result.mask;
} }
)"; )";
@ -2163,6 +2163,42 @@ fn vert_main() -> tint_symbol {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) {
auto* src = R"(
[[stage(fragment)]]
fn main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask)]] mask_in : u32
) -> [[builtin(sample_mask)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1>;
fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn main() {
let inner_result = main_inner(sample_index_1, mask_in_1[0]);
value[0] = inner_result;
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
auto got = Run<CanonicalizeEntryPointIO>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace } // namespace
} // namespace transform } // namespace transform
} // namespace tint } // namespace tint

View File

@ -68,61 +68,13 @@ Output Spirv::Run(const Program* in, const DataMap& data) {
ProgramBuilder builder; ProgramBuilder builder;
CloneContext ctx(&builder, &transformedInput.program); CloneContext ctx(&builder, &transformedInput.program);
HandleSampleMaskBuiltins(ctx); // TODO(jrprice): Move the sanitizer into the backend.
ctx.Clone(); ctx.Clone();
builder.SetTransformApplied(this); builder.SetTransformApplied(this);
return Output{Program(std::move(builder))}; return Output{Program(std::move(builder))};
} }
void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
// Find global variables decorated with [[builtin(sample_mask)]] and
// change their type from `u32` to `array<u32, 1>`, as required by Vulkan.
//
// Before:
// ```
// [[builtin(sample_mask)]] var<out> mask_out : u32;
// fn main() {
// mask_out = 1u;
// }
// ```
// After:
// ```
// [[builtin(sample_mask)]] var<out> mask_out : array<u32, 1>;
// fn main() {
// mask_out[0] = 1u;
// }
// ```
for (auto* var : ctx.src->AST().GlobalVariables()) {
for (auto* deco : var->decorations()) {
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
if (builtin->value() != ast::Builtin::kSampleMask) {
continue;
}
// Use the same name as the old variable.
auto var_name = ctx.Clone(var->symbol());
// Use `array<u32, 1>` for the new variable.
auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
// Create the new variable.
auto* var_arr = ctx.dst->Var(var->source(), var_name, type,
var->declared_storage_class(), nullptr,
ctx.Clone(var->decorations()));
// Replace the variable with the arrayed version.
ctx.Replace(var, var_arr);
// Replace all uses of the old variable with `var_arr[0]`.
for (auto* user : ctx.src->Sem().Get(var)->Users()) {
auto* new_ident = ctx.dst->IndexAccessor(
ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0));
ctx.Replace<ast::Expression>(user->Declaration(), new_ident);
}
}
}
}
}
Spirv::Config::Config(bool emit_vps, bool disable_wi) Spirv::Config::Config(bool emit_vps, bool disable_wi)
: emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {} : emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {}

View File

@ -67,10 +67,6 @@ class Spirv : public Castable<Spirv, Transform> {
/// @param data optional extra transform-specific input data /// @param data optional extra transform-specific input data
/// @returns the transformation result /// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) override;
private:
/// Change type of sample mask builtin variables to single element arrays.
void HandleSampleMaskBuiltins(CloneContext& ctx) const;
}; };
} // namespace transform } // namespace transform

View File

@ -22,140 +22,7 @@ namespace {
using SpirvTest = TransformTest; using SpirvTest = TransformTest;
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) { // TODO(jrprice): Remove this file when we remove the sanitizers.
auto* src = R"(
[[stage(fragment)]]
fn main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask)]] mask_in : u32
) -> [[builtin(sample_mask)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1u>;
fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn main() {
let inner_result = main_inner(sample_index_1, mask_in_1[0]);
value[0] = inner_result;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) {
auto* src = R"(
fn filter(mask: u32) -> u32 {
return (mask & 3u);
}
fn set_mask(input : u32) -> u32 {
return input;
}
[[stage(fragment)]]
fn main([[builtin(sample_mask)]] mask_in : u32
) -> [[builtin(sample_mask)]] u32 {
return set_mask(filter(mask_in));
}
)";
auto* expect = R"(
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1u>;
fn filter(mask : u32) -> u32 {
return (mask & 3u);
}
fn set_mask(input : u32) -> u32 {
return input;
}
fn main_inner(mask_in : u32) -> u32 {
return set_mask(filter(mask_in));
}
[[stage(fragment)]]
fn main() {
let inner_result = main_inner(mask_in_1[0]);
value[0] = inner_result;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
// Test that different transforms within the sanitizer interact correctly.
TEST_F(SpirvTest, MultipleTransforms) {
auto* src = R"(
[[stage(vertex)]]
fn vert_main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
[[stage(fragment)]]
fn frag_main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask)]] mask_in : u32)
-> [[builtin(sample_mask)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> value : vec4<f32>;
[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> vertex_point_size : f32;
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value_1 : array<u32, 1u>;
fn vert_main_inner() -> vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn vert_main() {
let inner_result = vert_main_inner();
value = inner_result;
vertex_point_size = 1.0;
}
fn frag_main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn frag_main() {
let inner_result_1 = frag_main_inner(sample_index_1, mask_in_1[0]);
value_1[0] = inner_result_1;
}
)";
DataMap data;
data.Add<Spirv::Config>(true);
auto got = Run<Spirv>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace } // namespace
} // namespace transform } // namespace transform