spirv: Handle sample_mask in shader IO transform

This is easy to do while we are processing builtins in the main
transform now that we use wrapper functions.

This is step towards removing the sanitizers completely.

Change-Id: If5472ce552e3cce1e5905916eeffa8fef90461c9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63585
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-09-07 18:59:21 +00:00
parent 1b9ed7de4a
commit 676ec7cf99
5 changed files with 67 additions and 197 deletions

View File

@ -67,6 +67,12 @@ bool IsShaderIODecoration(const ast::Decoration* deco) {
ast::InvariantDecoration, ast::LocationDecoration>();
}
// Returns true if `decos` contains a `sample_mask` builtin.
bool HasSampleMask(const ast::DecorationList& decos) {
auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos);
return builtin && builtin->value() == ast::Builtin::kSampleMask;
}
} // namespace
/// State holds the current transform state for a single entry point.
@ -166,9 +172,16 @@ struct CanonicalizeEntryPointIO::State {
// Create the global variable and use its value for the shader input.
auto var = ctx.dst->Symbols().New(name);
ast::Expression* value = ctx.dst->Expr(var);
if (HasSampleMask(attributes)) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then load the first element.
type = ctx.dst->ty.array(type, 1);
value = ctx.dst->IndexAccessor(value, 0);
}
ctx.dst->Global(var, type, ast::StorageClass::kInput,
std::move(attributes));
return ctx.dst->Expr(var);
return value;
} else if (cfg.shader_style == ShaderStyle::kMsl &&
ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
// If this input is a builtin and we are targeting MSL, then add it to the
@ -303,9 +316,7 @@ struct CanonicalizeEntryPointIO::State {
void AddFixedSampleMask() {
// Check the existing output values for a sample mask builtin.
for (auto& outval : wrapper_output_values) {
auto* builtin =
ast::GetDecoration<ast::BuiltinDecoration>(outval.attributes);
if (builtin && builtin->value() == ast::Builtin::kSampleMask) {
if (HasSampleMask(outval.attributes)) {
// Combine the authored sample mask with the fixed mask.
outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
return;
@ -390,7 +401,7 @@ struct CanonicalizeEntryPointIO::State {
}
/// Create and assign the wrapper function's output variables.
void CreateOutputVariables() {
void CreateSpirvOutputVariables() {
for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` storage class.
ast::DecorationList attributes = std::move(outval.attributes);
@ -400,9 +411,17 @@ struct CanonicalizeEntryPointIO::State {
// Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name);
ctx.dst->Global(name, outval.type, ast::StorageClass::kOutput,
auto* type = outval.type;
ast::Expression* lhs = ctx.dst->Expr(name);
if (HasSampleMask(attributes)) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then store to the first element.
type = ctx.dst->ty.array(type, 1);
lhs = ctx.dst->IndexAccessor(lhs, 0);
}
ctx.dst->Global(name, type, ast::StorageClass::kOutput,
std::move(attributes));
wrapper_body.push_back(ctx.dst->Assign(name, outval.value));
wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
}
}
@ -498,7 +517,7 @@ struct CanonicalizeEntryPointIO::State {
// Produce the entry point outputs, if necessary.
if (!wrapper_output_values.empty()) {
if (cfg.shader_style == ShaderStyle::kSpirv) {
CreateOutputVariables();
CreateSpirvOutputVariables();
} else {
auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] {

View File

@ -491,7 +491,7 @@ fn frag_main() -> FragOutput {
[[builtin(frag_depth), internal(disable_validation__ignore_storage_class)]] var<out> depth_1 : f32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> mask_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> mask_1 : array<u32, 1>;
struct FragOutput {
color : vec4<f32>;
@ -512,7 +512,7 @@ fn frag_main() {
let inner_result = frag_main_inner();
color_1 = inner_result.color;
depth_1 = inner_result.depth;
mask_1 = inner_result.mask;
mask_1[0] = inner_result.mask;
}
)";
@ -2163,6 +2163,42 @@ fn vert_main() -> tint_symbol {
EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) {
auto* src = R"(
[[stage(fragment)]]
fn main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask)]] mask_in : u32
) -> [[builtin(sample_mask)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1>;
fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn main() {
let inner_result = main_inner(sample_index_1, mask_in_1[0]);
value[0] = inner_result;
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
auto got = Run<CanonicalizeEntryPointIO>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@ -68,61 +68,13 @@ Output Spirv::Run(const Program* in, const DataMap& data) {
ProgramBuilder builder;
CloneContext ctx(&builder, &transformedInput.program);
HandleSampleMaskBuiltins(ctx);
// TODO(jrprice): Move the sanitizer into the backend.
ctx.Clone();
builder.SetTransformApplied(this);
return Output{Program(std::move(builder))};
}
void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
// Find global variables decorated with [[builtin(sample_mask)]] and
// change their type from `u32` to `array<u32, 1>`, as required by Vulkan.
//
// Before:
// ```
// [[builtin(sample_mask)]] var<out> mask_out : u32;
// fn main() {
// mask_out = 1u;
// }
// ```
// After:
// ```
// [[builtin(sample_mask)]] var<out> mask_out : array<u32, 1>;
// fn main() {
// mask_out[0] = 1u;
// }
// ```
for (auto* var : ctx.src->AST().GlobalVariables()) {
for (auto* deco : var->decorations()) {
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
if (builtin->value() != ast::Builtin::kSampleMask) {
continue;
}
// Use the same name as the old variable.
auto var_name = ctx.Clone(var->symbol());
// Use `array<u32, 1>` for the new variable.
auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
// Create the new variable.
auto* var_arr = ctx.dst->Var(var->source(), var_name, type,
var->declared_storage_class(), nullptr,
ctx.Clone(var->decorations()));
// Replace the variable with the arrayed version.
ctx.Replace(var, var_arr);
// Replace all uses of the old variable with `var_arr[0]`.
for (auto* user : ctx.src->Sem().Get(var)->Users()) {
auto* new_ident = ctx.dst->IndexAccessor(
ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0));
ctx.Replace<ast::Expression>(user->Declaration(), new_ident);
}
}
}
}
}
Spirv::Config::Config(bool emit_vps, bool disable_wi)
: emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {}

View File

@ -67,10 +67,6 @@ class Spirv : public Castable<Spirv, Transform> {
/// @param data optional extra transform-specific input data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override;
private:
/// Change type of sample mask builtin variables to single element arrays.
void HandleSampleMaskBuiltins(CloneContext& ctx) const;
};
} // namespace transform

View File

@ -22,140 +22,7 @@ namespace {
using SpirvTest = TransformTest;
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
auto* src = R"(
[[stage(fragment)]]
fn main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask)]] mask_in : u32
) -> [[builtin(sample_mask)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1u>;
fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn main() {
let inner_result = main_inner(sample_index_1, mask_in_1[0]);
value[0] = inner_result;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) {
auto* src = R"(
fn filter(mask: u32) -> u32 {
return (mask & 3u);
}
fn set_mask(input : u32) -> u32 {
return input;
}
[[stage(fragment)]]
fn main([[builtin(sample_mask)]] mask_in : u32
) -> [[builtin(sample_mask)]] u32 {
return set_mask(filter(mask_in));
}
)";
auto* expect = R"(
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1u>;
fn filter(mask : u32) -> u32 {
return (mask & 3u);
}
fn set_mask(input : u32) -> u32 {
return input;
}
fn main_inner(mask_in : u32) -> u32 {
return set_mask(filter(mask_in));
}
[[stage(fragment)]]
fn main() {
let inner_result = main_inner(mask_in_1[0]);
value[0] = inner_result;
}
)";
auto got = Run<Spirv>(src);
EXPECT_EQ(expect, str(got));
}
// Test that different transforms within the sanitizer interact correctly.
TEST_F(SpirvTest, MultipleTransforms) {
auto* src = R"(
[[stage(vertex)]]
fn vert_main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
[[stage(fragment)]]
fn frag_main([[builtin(sample_index)]] sample_index : u32,
[[builtin(sample_mask)]] mask_in : u32)
-> [[builtin(sample_mask)]] u32 {
return mask_in;
}
)";
auto* expect = R"(
[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> value : vec4<f32>;
[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> vertex_point_size : f32;
[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value_1 : array<u32, 1u>;
fn vert_main_inner() -> vec4<f32> {
return vec4<f32>();
}
[[stage(vertex)]]
fn vert_main() {
let inner_result = vert_main_inner();
value = inner_result;
vertex_point_size = 1.0;
}
fn frag_main_inner(sample_index : u32, mask_in : u32) -> u32 {
return mask_in;
}
[[stage(fragment)]]
fn frag_main() {
let inner_result_1 = frag_main_inner(sample_index_1, mask_in_1[0]);
value_1[0] = inner_result_1;
}
)";
DataMap data;
data.Add<Spirv::Config>(true);
auto got = Run<Spirv>(src, data);
EXPECT_EQ(expect, str(got));
}
// TODO(jrprice): Remove this file when we remove the sanitizers.
} // namespace
} // namespace transform