diff --git a/src/transform/robustness.cc b/src/transform/robustness.cc index 9aa5a3e76c..6bd7678242 100644 --- a/src/transform/robustness.cc +++ b/src/transform/robustness.cc @@ -22,9 +22,11 @@ #include "src/sem/block_statement.h" #include "src/sem/call.h" #include "src/sem/expression.h" +#include "src/sem/reference_type.h" #include "src/sem/statement.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness); +TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config); namespace tint { namespace transform { @@ -34,6 +36,9 @@ struct Robustness::State { /// The clone context CloneContext& ctx; + /// Set of storage classes to not apply the transform to + std::unordered_set omitted_classes; + /// Applies the transformation state to `ctx`. void Transform() { ctx.ReplaceAll( @@ -46,7 +51,14 @@ struct Robustness::State { /// @return the clamped replacement expression, or nullptr if `expr` should be /// cloned without changes. ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) { - auto* ret_type = ctx.src->Sem().Get(expr->array)->Type()->UnwrapRef(); + auto* ret_type = ctx.src->Sem().Get(expr->array)->Type(); + + auto* ref = ret_type->As(); + if (ref && omitted_classes.count(ref->StorageClass()) != 0) { + return nullptr; + } + + auto* ret_unwrapped = ret_type->UnwrapRef(); ProgramBuilder& b = *ctx.dst; using u32 = ProgramBuilder::u32; @@ -62,12 +74,12 @@ struct Robustness::State { Value size; // size of the array, vector or matrix size.is_signed = false; // size is always unsigned - if (auto* vec = ret_type->As()) { + if (auto* vec = ret_unwrapped->As()) { size.u32 = vec->Width(); - } else if (auto* arr = ret_type->As()) { + } else if (auto* arr = ret_unwrapped->As()) { size.u32 = arr->Count(); - } else if (auto* mat = ret_type->As()) { + } else if (auto* mat = ret_unwrapped->As()) { // The row accessor would have been an embedded array accessor and already // handled, so we just need to do columns here. size.u32 = mat->columns(); @@ -76,7 +88,7 @@ struct Robustness::State { } if (size.u32 == 0) { - if (!ret_type->Is()) { + if (!ret_unwrapped->Is()) { b.Diagnostics().add_error(diag::System::Transform, "invalid 0 sized non-array", expr->source); return nullptr; @@ -268,11 +280,34 @@ struct Robustness::State { } }; +Robustness::Config::Config() = default; +Robustness::Config::Config(const Config&) = default; +Robustness::Config::~Config() = default; +Robustness::Config& Robustness::Config::operator=(const Config&) = default; + Robustness::Robustness() = default; Robustness::~Robustness() = default; -void Robustness::Run(CloneContext& ctx, const DataMap&, DataMap&) { - State state{ctx}; +void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { + Config cfg; + if (auto* cfg_data = inputs.Get()) { + cfg = *cfg_data; + } + + std::unordered_set omitted_classes; + for (auto sc : cfg.omitted_classes) { + switch (sc) { + case StorageClass::kUniform: + omitted_classes.insert(ast::StorageClass::kUniform); + break; + case StorageClass::kStorage: + omitted_classes.insert(ast::StorageClass::kStorage); + break; + } + } + + State state{ctx, std::move(omitted_classes)}; + state.Transform(); ctx.Clone(); } diff --git a/src/transform/robustness.h b/src/transform/robustness.h index fcade3aff5..1333e5c8e2 100644 --- a/src/transform/robustness.h +++ b/src/transform/robustness.h @@ -15,6 +15,8 @@ #ifndef SRC_TRANSFORM_ROBUSTNESS_H_ #define SRC_TRANSFORM_ROBUSTNESS_H_ +#include + #include "src/transform/transform.h" // Forward declarations @@ -34,6 +36,32 @@ namespace transform { /// (array length - 1). class Robustness : public Castable { public: + /// Storage class to be skipped in the transform + enum class StorageClass { + kUniform, + kStorage, + }; + + /// Configuration options for the transform + struct Config : public Castable { + /// Constructor + Config(); + + /// Copy constructor + Config(const Config&); + + /// Destructor + ~Config() override; + + /// Assignment operator + /// @returns this Config + Config& operator=(const Config&); + + /// Storage classes to omit from apply the transform to. + /// This allows for optimizing on hardware that provide safe accesses. + std::unordered_set omitted_classes; + }; + /// Constructor Robustness(); /// Destructor diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc index a6612bee68..a1cf043839 100644 --- a/src/transform/robustness_test.cc +++ b/src/transform/robustness_test.cc @@ -818,6 +818,331 @@ fn f() { EXPECT_EQ(expect, str(got)); } +const char* kOmitSourceShader = R"( +[[block]] +struct S { + a : array; + b : array; +}; +[[group(0), binding(0)]] var s : S; + +type UArr = [[stride(16)]] array; +[[block]] struct U { + a : UArr; +}; +[[group(1), binding(0)]] var u : U; + +fn f() { + // Signed + var i32_sa1 : f32 = s.a[4]; + var i32_sa2 : f32 = s.a[1]; + var i32_sa3 : f32 = s.a[0]; + var i32_sa4 : f32 = s.a[-1]; + var i32_sa5 : f32 = s.a[-4]; + + var i32_sb1 : f32 = s.b[4]; + var i32_sb2 : f32 = s.b[1]; + var i32_sb3 : f32 = s.b[0]; + var i32_sb4 : f32 = s.b[-1]; + var i32_sb5 : f32 = s.b[-4]; + + var i32_ua1 : f32 = u.a[4]; + var i32_ua2 : f32 = u.a[1]; + var i32_ua3 : f32 = u.a[0]; + var i32_ua4 : f32 = u.a[-1]; + var i32_ua5 : f32 = u.a[-4]; + + // Unsigned + var u32_sa1 : f32 = s.a[0u]; + var u32_sa2 : f32 = s.a[1u]; + var u32_sa3 : f32 = s.a[3u]; + var u32_sa4 : f32 = s.a[4u]; + var u32_sa5 : f32 = s.a[10u]; + var u32_sa6 : f32 = s.a[100u]; + + var u32_sb1 : f32 = s.b[0u]; + var u32_sb2 : f32 = s.b[1u]; + var u32_sb3 : f32 = s.b[3u]; + var u32_sb4 : f32 = s.b[4u]; + var u32_sb5 : f32 = s.b[10u]; + var u32_sb6 : f32 = s.b[100u]; + + var u32_ua1 : f32 = u.a[0u]; + var u32_ua2 : f32 = u.a[1u]; + var u32_ua3 : f32 = u.a[3u]; + var u32_ua4 : f32 = u.a[4u]; + var u32_ua5 : f32 = u.a[10u]; + var u32_ua6 : f32 = u.a[100u]; +} +)"; + +TEST_F(RobustnessTest, OmitNone) { + auto* expect = R"( +[[block]] +struct S { + a : array; + b : array; +}; + +[[group(0), binding(0)]] var s : S; + +type UArr = [[stride(16)]] array; + +[[block]] +struct U { + a : UArr; +}; + +[[group(1), binding(0)]] var u : U; + +fn f() { + var i32_sa1 : f32 = s.a[3]; + var i32_sa2 : f32 = s.a[1]; + var i32_sa3 : f32 = s.a[0]; + var i32_sa4 : f32 = s.a[0]; + var i32_sa5 : f32 = s.a[0]; + var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var i32_ua1 : f32 = u.a[3]; + var i32_ua2 : f32 = u.a[1]; + var i32_ua3 : f32 = u.a[0]; + var i32_ua4 : f32 = u.a[0]; + var i32_ua5 : f32 = u.a[0]; + var u32_sa1 : f32 = s.a[0u]; + var u32_sa2 : f32 = s.a[1u]; + var u32_sa3 : f32 = s.a[3u]; + var u32_sa4 : f32 = s.a[3u]; + var u32_sa5 : f32 = s.a[3u]; + var u32_sa6 : f32 = s.a[3u]; + var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))]; + var u32_ua1 : f32 = u.a[0u]; + var u32_ua2 : f32 = u.a[1u]; + var u32_ua3 : f32 = u.a[3u]; + var u32_ua4 : f32 = u.a[3u]; + var u32_ua5 : f32 = u.a[3u]; + var u32_ua6 : f32 = u.a[3u]; +} +)"; + + Robustness::Config cfg; + DataMap data; + data.Add(cfg); + + auto got = Run(kOmitSourceShader, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(RobustnessTest, OmitStorage) { + auto* expect = R"( +[[block]] +struct S { + a : array; + b : array; +}; + +[[group(0), binding(0)]] var s : S; + +type UArr = [[stride(16)]] array; + +[[block]] +struct U { + a : UArr; +}; + +[[group(1), binding(0)]] var u : U; + +fn f() { + var i32_sa1 : f32 = s.a[4]; + var i32_sa2 : f32 = s.a[1]; + var i32_sa3 : f32 = s.a[0]; + var i32_sa4 : f32 = s.a[-1]; + var i32_sa5 : f32 = s.a[-4]; + var i32_sb1 : f32 = s.b[4]; + var i32_sb2 : f32 = s.b[1]; + var i32_sb3 : f32 = s.b[0]; + var i32_sb4 : f32 = s.b[-1]; + var i32_sb5 : f32 = s.b[-4]; + var i32_ua1 : f32 = u.a[3]; + var i32_ua2 : f32 = u.a[1]; + var i32_ua3 : f32 = u.a[0]; + var i32_ua4 : f32 = u.a[0]; + var i32_ua5 : f32 = u.a[0]; + var u32_sa1 : f32 = s.a[0u]; + var u32_sa2 : f32 = s.a[1u]; + var u32_sa3 : f32 = s.a[3u]; + var u32_sa4 : f32 = s.a[4u]; + var u32_sa5 : f32 = s.a[10u]; + var u32_sa6 : f32 = s.a[100u]; + var u32_sb1 : f32 = s.b[0u]; + var u32_sb2 : f32 = s.b[1u]; + var u32_sb3 : f32 = s.b[3u]; + var u32_sb4 : f32 = s.b[4u]; + var u32_sb5 : f32 = s.b[10u]; + var u32_sb6 : f32 = s.b[100u]; + var u32_ua1 : f32 = u.a[0u]; + var u32_ua2 : f32 = u.a[1u]; + var u32_ua3 : f32 = u.a[3u]; + var u32_ua4 : f32 = u.a[3u]; + var u32_ua5 : f32 = u.a[3u]; + var u32_ua6 : f32 = u.a[3u]; +} +)"; + + Robustness::Config cfg; + cfg.omitted_classes.insert(Robustness::StorageClass::kStorage); + + DataMap data; + data.Add(cfg); + + auto got = Run(kOmitSourceShader, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(RobustnessTest, OmitUniform) { + auto* expect = R"( +[[block]] +struct S { + a : array; + b : array; +}; + +[[group(0), binding(0)]] var s : S; + +type UArr = [[stride(16)]] array; + +[[block]] +struct U { + a : UArr; +}; + +[[group(1), binding(0)]] var u : U; + +fn f() { + var i32_sa1 : f32 = s.a[3]; + var i32_sa2 : f32 = s.a[1]; + var i32_sa3 : f32 = s.a[0]; + var i32_sa4 : f32 = s.a[0]; + var i32_sa5 : f32 = s.a[0]; + var i32_sb1 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb3 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb4 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var i32_sb5 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var i32_ua1 : f32 = u.a[4]; + var i32_ua2 : f32 = u.a[1]; + var i32_ua3 : f32 = u.a[0]; + var i32_ua4 : f32 = u.a[-1]; + var i32_ua5 : f32 = u.a[-4]; + var u32_sa1 : f32 = s.a[0u]; + var u32_sa2 : f32 = s.a[1u]; + var u32_sa3 : f32 = s.a[3u]; + var u32_sa4 : f32 = s.a[3u]; + var u32_sa5 : f32 = s.a[3u]; + var u32_sa6 : f32 = s.a[3u]; + var u32_sb1 : f32 = s.b[min(0u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb2 : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb3 : f32 = s.b[min(3u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb4 : f32 = s.b[min(4u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb5 : f32 = s.b[min(10u, (arrayLength(&(s.b)) - 1u))]; + var u32_sb6 : f32 = s.b[min(100u, (arrayLength(&(s.b)) - 1u))]; + var u32_ua1 : f32 = u.a[0u]; + var u32_ua2 : f32 = u.a[1u]; + var u32_ua3 : f32 = u.a[3u]; + var u32_ua4 : f32 = u.a[4u]; + var u32_ua5 : f32 = u.a[10u]; + var u32_ua6 : f32 = u.a[100u]; +} +)"; + + Robustness::Config cfg; + cfg.omitted_classes.insert(Robustness::StorageClass::kUniform); + + DataMap data; + data.Add(cfg); + + auto got = Run(kOmitSourceShader, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(RobustnessTest, OmitBoth) { + auto* expect = R"( +[[block]] +struct S { + a : array; + b : array; +}; + +[[group(0), binding(0)]] var s : S; + +type UArr = [[stride(16)]] array; + +[[block]] +struct U { + a : UArr; +}; + +[[group(1), binding(0)]] var u : U; + +fn f() { + var i32_sa1 : f32 = s.a[4]; + var i32_sa2 : f32 = s.a[1]; + var i32_sa3 : f32 = s.a[0]; + var i32_sa4 : f32 = s.a[-1]; + var i32_sa5 : f32 = s.a[-4]; + var i32_sb1 : f32 = s.b[4]; + var i32_sb2 : f32 = s.b[1]; + var i32_sb3 : f32 = s.b[0]; + var i32_sb4 : f32 = s.b[-1]; + var i32_sb5 : f32 = s.b[-4]; + var i32_ua1 : f32 = u.a[4]; + var i32_ua2 : f32 = u.a[1]; + var i32_ua3 : f32 = u.a[0]; + var i32_ua4 : f32 = u.a[-1]; + var i32_ua5 : f32 = u.a[-4]; + var u32_sa1 : f32 = s.a[0u]; + var u32_sa2 : f32 = s.a[1u]; + var u32_sa3 : f32 = s.a[3u]; + var u32_sa4 : f32 = s.a[4u]; + var u32_sa5 : f32 = s.a[10u]; + var u32_sa6 : f32 = s.a[100u]; + var u32_sb1 : f32 = s.b[0u]; + var u32_sb2 : f32 = s.b[1u]; + var u32_sb3 : f32 = s.b[3u]; + var u32_sb4 : f32 = s.b[4u]; + var u32_sb5 : f32 = s.b[10u]; + var u32_sb6 : f32 = s.b[100u]; + var u32_ua1 : f32 = u.a[0u]; + var u32_ua2 : f32 = u.a[1u]; + var u32_ua3 : f32 = u.a[3u]; + var u32_ua4 : f32 = u.a[4u]; + var u32_ua5 : f32 = u.a[10u]; + var u32_ua6 : f32 = u.a[100u]; +} +)"; + + Robustness::Config cfg; + cfg.omitted_classes.insert(Robustness::StorageClass::kStorage); + cfg.omitted_classes.insert(Robustness::StorageClass::kUniform); + + DataMap data; + data.Add(cfg); + + auto got = Run(kOmitSourceShader, data); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint