tint/transform: Fix NPE in ZeroInitWorkgroupMemory.

If an array uses an override expression, then we'd raise an error, but then attempt to dereference a nullptr.

Bug: chromium:1392853
Change-Id: Ib1d538bc491923b628b32f2398f8b2ace24c3bc3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112561
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-12-02 13:49:57 +00:00 committed by Dawn LUCI CQ
parent 25c0bdf2a9
commit 7423496da6
2 changed files with 60 additions and 14 deletions

View File

@ -100,6 +100,9 @@ struct ZeroInitWorkgroupMemory::State {
uint32_t num_iterations = 0;
/// All array indices used by this expression
ArrayIndices array_indices;
/// @returns true if the expr is not null (null usually indicates a failure)
operator bool() const { return expr != nullptr; }
};
/// Statement holds information about a statement that will zero workgroup
@ -137,10 +140,13 @@ struct ZeroInitWorkgroupMemory::State {
auto* func = sem.Get(fn);
for (auto* var : func->TransitivelyReferencedGlobals()) {
if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) {
BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) {
auto get_expr = [&](uint32_t num_values) {
auto var_name = ctx.Clone(var->Declaration()->symbol);
return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
});
};
if (!BuildZeroingStatements(var->Type()->UnwrapRef(), get_expr)) {
return;
}
}
}
@ -283,41 +289,54 @@ struct ZeroInitWorkgroupMemory::State {
/// initialize the workgroup storage expression of type `ty`.
/// @param ty the expression type
/// @param get_expr a function that builds the AST nodes for the expression.
void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) {
/// @returns true on success, false on failure
[[nodiscard]] bool BuildZeroingStatements(const sem::Type* ty,
const BuildZeroingExpr& get_expr) {
if (CanTriviallyZero(ty)) {
auto var = get_expr(1u);
if (!var) {
return false;
}
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
statements.emplace_back(
Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
return;
return true;
}
if (auto* atomic = ty->As<sem::Atomic>()) {
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
auto expr = get_expr(1u);
if (!expr) {
return false;
}
auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
statements.emplace_back(
Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
return;
return true;
}
if (auto* str = ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto name = ctx.Clone(member->Declaration()->symbol);
BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
auto get_member = [&](uint32_t num_values) {
auto s = get_expr(num_values);
if (!s) {
return Expression{}; // error
}
return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
s.array_indices};
});
};
if (!BuildZeroingStatements(member->Type(), get_member)) {
return false;
}
return;
}
return true;
}
if (auto* arr = ty->As<sem::Array>()) {
BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
auto get_el = [&](uint32_t num_values) {
// num_values is the number of values to zero for the element type.
// The number of iterations required to zero the array and its elements
// is:
// The number of iterations required to zero the array and its elements is:
// `num_values * arr->Count()`
// The index for this array is:
// `(idx % modulo) / division`
@ -325,22 +344,26 @@ struct ZeroInitWorkgroupMemory::State {
if (!count) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
return Expression{};
return Expression{}; // error
}
auto modulo = num_values * count.value();
auto division = num_values;
auto a = get_expr(modulo);
if (!a) {
return Expression{}; // error
}
auto array_indices = a.array_indices;
array_indices.Add(ArrayIndex{modulo, division});
auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
[&] { return b.Symbols().New("i"); });
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
});
return;
};
return BuildZeroingStatements(arr->ElemType(), get_el);
}
TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
return false;
}
/// DeclareArrayIndices returns a list of statements that contain the `let`

View File

@ -1363,5 +1363,28 @@ struct S {
EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ArrayWithOverrideCount) {
auto* src =
R"(override O = 123;
type A = array<i32, O*2>;
var<workgroup> W : A;
@compute @workgroup_size(1)
fn main() {
let p : ptr<workgroup, A> = &W;
(*p)[0] = 42;
}
)";
auto* expect =
R"(error: array size is an override-expression, when expected a constant-expression.
Was the SubstituteOverride transform run?)";
auto got = Run<ZeroInitWorkgroupMemory>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform