Add knob for omitting certain storage classes in Robustness transform

BUG=tint:779

Change-Id: Ibcedb998671dd2bf189cc795299ea92846196ade
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/66780
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ryan Harrison 2021-10-19 16:51:23 +00:00 committed by Tint LUCI CQ
parent 5f5d43ff51
commit 7d0fc07b20
3 changed files with 395 additions and 7 deletions

View File

@ -22,9 +22,11 @@
#include "src/sem/block_statement.h" #include "src/sem/block_statement.h"
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/reference_type.h"
#include "src/sem/statement.h" #include "src/sem/statement.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness); TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -34,6 +36,9 @@ struct Robustness::State {
/// The clone context /// The clone context
CloneContext& ctx; CloneContext& ctx;
/// Set of storage classes to not apply the transform to
std::unordered_set<ast::StorageClass> omitted_classes;
/// Applies the transformation state to `ctx`. /// Applies the transformation state to `ctx`.
void Transform() { void Transform() {
ctx.ReplaceAll( ctx.ReplaceAll(
@ -46,7 +51,14 @@ struct Robustness::State {
/// @return the clamped replacement expression, or nullptr if `expr` should be /// @return the clamped replacement expression, or nullptr if `expr` should be
/// cloned without changes. /// cloned without changes.
ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) { ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) {
auto* ret_type = ctx.src->Sem().Get(expr->array)->Type()->UnwrapRef(); auto* ret_type = ctx.src->Sem().Get(expr->array)->Type();
auto* ref = ret_type->As<sem::Reference>();
if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
return nullptr;
}
auto* ret_unwrapped = ret_type->UnwrapRef();
ProgramBuilder& b = *ctx.dst; ProgramBuilder& b = *ctx.dst;
using u32 = ProgramBuilder::u32; using u32 = ProgramBuilder::u32;
@ -62,12 +74,12 @@ struct Robustness::State {
Value size; // size of the array, vector or matrix Value size; // size of the array, vector or matrix
size.is_signed = false; // size is always unsigned size.is_signed = false; // size is always unsigned
if (auto* vec = ret_type->As<sem::Vector>()) { if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
size.u32 = vec->Width(); size.u32 = vec->Width();
} else if (auto* arr = ret_type->As<sem::Array>()) { } else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
size.u32 = arr->Count(); size.u32 = arr->Count();
} else if (auto* mat = ret_type->As<sem::Matrix>()) { } else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
// The row accessor would have been an embedded array accessor and already // The row accessor would have been an embedded array accessor and already
// handled, so we just need to do columns here. // handled, so we just need to do columns here.
size.u32 = mat->columns(); size.u32 = mat->columns();
@ -76,7 +88,7 @@ struct Robustness::State {
} }
if (size.u32 == 0) { if (size.u32 == 0) {
if (!ret_type->Is<sem::Array>()) { if (!ret_unwrapped->Is<sem::Array>()) {
b.Diagnostics().add_error(diag::System::Transform, b.Diagnostics().add_error(diag::System::Transform,
"invalid 0 sized non-array", expr->source); "invalid 0 sized non-array", expr->source);
return nullptr; return nullptr;
@ -268,11 +280,34 @@ struct Robustness::State {
} }
}; };
Robustness::Config::Config() = default;
Robustness::Config::Config(const Config&) = default;
Robustness::Config::~Config() = default;
Robustness::Config& Robustness::Config::operator=(const Config&) = default;
Robustness::Robustness() = default; Robustness::Robustness() = default;
Robustness::~Robustness() = default; Robustness::~Robustness() = default;
void Robustness::Run(CloneContext& ctx, const DataMap&, DataMap&) { void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
State state{ctx}; Config cfg;
if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data;
}
std::unordered_set<ast::StorageClass> omitted_classes;
for (auto sc : cfg.omitted_classes) {
switch (sc) {
case StorageClass::kUniform:
omitted_classes.insert(ast::StorageClass::kUniform);
break;
case StorageClass::kStorage:
omitted_classes.insert(ast::StorageClass::kStorage);
break;
}
}
State state{ctx, std::move(omitted_classes)};
state.Transform(); state.Transform();
ctx.Clone(); ctx.Clone();
} }

View File

@ -15,6 +15,8 @@
#ifndef SRC_TRANSFORM_ROBUSTNESS_H_ #ifndef SRC_TRANSFORM_ROBUSTNESS_H_
#define SRC_TRANSFORM_ROBUSTNESS_H_ #define SRC_TRANSFORM_ROBUSTNESS_H_
#include <unordered_set>
#include "src/transform/transform.h" #include "src/transform/transform.h"
// Forward declarations // Forward declarations
@ -34,6 +36,32 @@ namespace transform {
/// (array length - 1). /// (array length - 1).
class Robustness : public Castable<Robustness, Transform> { class Robustness : public Castable<Robustness, Transform> {
public: public:
/// Storage class to be skipped in the transform
enum class StorageClass {
kUniform,
kStorage,
};
/// Configuration options for the transform
struct Config : public Castable<Config, Data> {
/// Constructor
Config();
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// Assignment operator
/// @returns this Config
Config& operator=(const Config&);
/// Storage classes to omit from apply the transform to.
/// This allows for optimizing on hardware that provide safe accesses.
std::unordered_set<StorageClass> omitted_classes;
};
/// Constructor /// Constructor
Robustness(); Robustness();
/// Destructor /// Destructor

View File

@ -818,6 +818,331 @@ fn f() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
const char* kOmitSourceShader = R"(
[[block]]
struct S {
a : array<f32, 4>;
b : array<f32>;
};
[[group(0), binding(0)]] var<storage, read> s : S;
type UArr = [[stride(16)]] array<f32, 4>;
[[block]] struct U {
a : UArr;
};
[[group(1), binding(0)]] var<uniform> u : U;
fn f() {
// Signed
var i32_sa1 : f32 = s.a[4];
var i32_sa2 : f32 = s.a[1];
var i32_sa3 : f32 = s.a[0];
var i32_sa4 : f32 = s.a[-1];
var i32_sa5 : f32 = s.a[-4];
var i32_sb1 : f32 = s.b[4];
var i32_sb2 : f32 = s.b[1];
var i32_sb3 : f32 = s.b[0];
var i32_sb4 : f32 = s.b[-1];
var i32_sb5 : f32 = s.b[-4];
var i32_ua1 : f32 = u.a[4];
var i32_ua2 : f32 = u.a[1];
var i32_ua3 : f32 = u.a[0];
var i32_ua4 : f32 = u.a[-1];
var i32_ua5 : f32 = u.a[-4];
// Unsigned
var u32_sa1 : f32 = s.a[0u];
var u32_sa2 : f32 = s.a[1u];
var u32_sa3 : f32 = s.a[3u];
var u32_sa4 : f32 = s.a[4u];
var u32_sa5 : f32 = s.a[10u];
var u32_sa6 : f32 = s.a[100u];
var u32_sb1 : f32 = s.b[0u];
var u32_sb2 : f32 = s.b[1u];
var u32_sb3 : f32 = s.b[3u];
var u32_sb4 : f32 = s.b[4u];
var u32_sb5 : f32 = s.b[10u];
var u32_sb6 : f32 = s.b[100u];
var u32_ua1 : f32 = u.a[0u];
var u32_ua2 : f32 = u.a[1u];
var u32_ua3 : f32 = u.a[3u];
var u32_ua4 : f32 = u.a[4u];
var u32_ua5 : f32 = u.a[10u];
var u32_ua6 : f32 = u.a[100u];
}
)";
TEST_F(RobustnessTest, OmitNone) {
auto* expect = R"(
[[block]]
struct S {
a : array<f32, 4>;
b : array<f32>;
};
[[group(0), binding(0)]] var<storage, read> s : S;
type UArr = [[stride(16)]] array<f32, 4>;
[[block]]
struct U {
a : UArr;
};
[[group(1), binding(0)]] var<uniform> u : U;
fn f() {
var i32_sa1 : f32 = s.a[3];
var i32_sa2 : f32 = s.a[1];
var i32_sa3 : f32 = s.a[0];
var i32_sa4 : f32 = s.a[0];
var i32_sa5 : f32 = s.a[0];
var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var i32_ua1 : f32 = u.a[3];
var i32_ua2 : f32 = u.a[1];
var i32_ua3 : f32 = u.a[0];
var i32_ua4 : f32 = u.a[0];
var i32_ua5 : f32 = u.a[0];
var u32_sa1 : f32 = s.a[0u];
var u32_sa2 : f32 = s.a[1u];
var u32_sa3 : f32 = s.a[3u];
var u32_sa4 : f32 = s.a[3u];
var u32_sa5 : f32 = s.a[3u];
var u32_sa6 : f32 = s.a[3u];
var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))];
var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))];
var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))];
var u32_ua1 : f32 = u.a[0u];
var u32_ua2 : f32 = u.a[1u];
var u32_ua3 : f32 = u.a[3u];
var u32_ua4 : f32 = u.a[3u];
var u32_ua5 : f32 = u.a[3u];
var u32_ua6 : f32 = u.a[3u];
}
)";
Robustness::Config cfg;
DataMap data;
data.Add<Robustness::Config>(cfg);
auto got = Run<Robustness>(kOmitSourceShader, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, OmitStorage) {
auto* expect = R"(
[[block]]
struct S {
a : array<f32, 4>;
b : array<f32>;
};
[[group(0), binding(0)]] var<storage, read> s : S;
type UArr = [[stride(16)]] array<f32, 4>;
[[block]]
struct U {
a : UArr;
};
[[group(1), binding(0)]] var<uniform> u : U;
fn f() {
var i32_sa1 : f32 = s.a[4];
var i32_sa2 : f32 = s.a[1];
var i32_sa3 : f32 = s.a[0];
var i32_sa4 : f32 = s.a[-1];
var i32_sa5 : f32 = s.a[-4];
var i32_sb1 : f32 = s.b[4];
var i32_sb2 : f32 = s.b[1];
var i32_sb3 : f32 = s.b[0];
var i32_sb4 : f32 = s.b[-1];
var i32_sb5 : f32 = s.b[-4];
var i32_ua1 : f32 = u.a[3];
var i32_ua2 : f32 = u.a[1];
var i32_ua3 : f32 = u.a[0];
var i32_ua4 : f32 = u.a[0];
var i32_ua5 : f32 = u.a[0];
var u32_sa1 : f32 = s.a[0u];
var u32_sa2 : f32 = s.a[1u];
var u32_sa3 : f32 = s.a[3u];
var u32_sa4 : f32 = s.a[4u];
var u32_sa5 : f32 = s.a[10u];
var u32_sa6 : f32 = s.a[100u];
var u32_sb1 : f32 = s.b[0u];
var u32_sb2 : f32 = s.b[1u];
var u32_sb3 : f32 = s.b[3u];
var u32_sb4 : f32 = s.b[4u];
var u32_sb5 : f32 = s.b[10u];
var u32_sb6 : f32 = s.b[100u];
var u32_ua1 : f32 = u.a[0u];
var u32_ua2 : f32 = u.a[1u];
var u32_ua3 : f32 = u.a[3u];
var u32_ua4 : f32 = u.a[3u];
var u32_ua5 : f32 = u.a[3u];
var u32_ua6 : f32 = u.a[3u];
}
)";
Robustness::Config cfg;
cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
DataMap data;
data.Add<Robustness::Config>(cfg);
auto got = Run<Robustness>(kOmitSourceShader, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, OmitUniform) {
auto* expect = R"(
[[block]]
struct S {
a : array<f32, 4>;
b : array<f32>;
};
[[group(0), binding(0)]] var<storage, read> s : S;
type UArr = [[stride(16)]] array<f32, 4>;
[[block]]
struct U {
a : UArr;
};
[[group(1), binding(0)]] var<uniform> u : U;
fn f() {
var i32_sa1 : f32 = s.a[3];
var i32_sa2 : f32 = s.a[1];
var i32_sa3 : f32 = s.a[0];
var i32_sa4 : f32 = s.a[0];
var i32_sa5 : f32 = s.a[0];
var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var i32_ua1 : f32 = u.a[4];
var i32_ua2 : f32 = u.a[1];
var i32_ua3 : f32 = u.a[0];
var i32_ua4 : f32 = u.a[-1];
var i32_ua5 : f32 = u.a[-4];
var u32_sa1 : f32 = s.a[0u];
var u32_sa2 : f32 = s.a[1u];
var u32_sa3 : f32 = s.a[3u];
var u32_sa4 : f32 = s.a[3u];
var u32_sa5 : f32 = s.a[3u];
var u32_sa6 : f32 = s.a[3u];
var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))];
var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))];
var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))];
var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))];
var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))];
var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))];
var u32_ua1 : f32 = u.a[0u];
var u32_ua2 : f32 = u.a[1u];
var u32_ua3 : f32 = u.a[3u];
var u32_ua4 : f32 = u.a[4u];
var u32_ua5 : f32 = u.a[10u];
var u32_ua6 : f32 = u.a[100u];
}
)";
Robustness::Config cfg;
cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
DataMap data;
data.Add<Robustness::Config>(cfg);
auto got = Run<Robustness>(kOmitSourceShader, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, OmitBoth) {
auto* expect = R"(
[[block]]
struct S {
a : array<f32, 4>;
b : array<f32>;
};
[[group(0), binding(0)]] var<storage, read> s : S;
type UArr = [[stride(16)]] array<f32, 4>;
[[block]]
struct U {
a : UArr;
};
[[group(1), binding(0)]] var<uniform> u : U;
fn f() {
var i32_sa1 : f32 = s.a[4];
var i32_sa2 : f32 = s.a[1];
var i32_sa3 : f32 = s.a[0];
var i32_sa4 : f32 = s.a[-1];
var i32_sa5 : f32 = s.a[-4];
var i32_sb1 : f32 = s.b[4];
var i32_sb2 : f32 = s.b[1];
var i32_sb3 : f32 = s.b[0];
var i32_sb4 : f32 = s.b[-1];
var i32_sb5 : f32 = s.b[-4];
var i32_ua1 : f32 = u.a[4];
var i32_ua2 : f32 = u.a[1];
var i32_ua3 : f32 = u.a[0];
var i32_ua4 : f32 = u.a[-1];
var i32_ua5 : f32 = u.a[-4];
var u32_sa1 : f32 = s.a[0u];
var u32_sa2 : f32 = s.a[1u];
var u32_sa3 : f32 = s.a[3u];
var u32_sa4 : f32 = s.a[4u];
var u32_sa5 : f32 = s.a[10u];
var u32_sa6 : f32 = s.a[100u];
var u32_sb1 : f32 = s.b[0u];
var u32_sb2 : f32 = s.b[1u];
var u32_sb3 : f32 = s.b[3u];
var u32_sb4 : f32 = s.b[4u];
var u32_sb5 : f32 = s.b[10u];
var u32_sb6 : f32 = s.b[100u];
var u32_ua1 : f32 = u.a[0u];
var u32_ua2 : f32 = u.a[1u];
var u32_ua3 : f32 = u.a[3u];
var u32_ua4 : f32 = u.a[4u];
var u32_ua5 : f32 = u.a[10u];
var u32_ua6 : f32 = u.a[100u];
}
)";
Robustness::Config cfg;
cfg.omitted_classes.insert(Robustness::StorageClass::kStorage);
cfg.omitted_classes.insert(Robustness::StorageClass::kUniform);
DataMap data;
data.Add<Robustness::Config>(cfg);
auto got = Run<Robustness>(kOmitSourceShader, data);
EXPECT_EQ(expect, str(got));
}
} // namespace } // namespace
} // namespace transform } // namespace transform
} // namespace tint } // namespace tint