tint/transform: Fix NPE in ZeroInitWorkgroupMemory.
If an array uses an override expression, then we'd raise an error, but then attempt to dereference a nullptr. Bug: chromium:1392853 Change-Id: Ib1d538bc491923b628b32f2398f8b2ace24c3bc3 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112561 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
25c0bdf2a9
commit
7423496da6
|
@ -100,6 +100,9 @@ struct ZeroInitWorkgroupMemory::State {
|
|||
uint32_t num_iterations = 0;
|
||||
/// All array indices used by this expression
|
||||
ArrayIndices array_indices;
|
||||
|
||||
/// @returns true if the expr is not null (null usually indicates a failure)
|
||||
operator bool() const { return expr != nullptr; }
|
||||
};
|
||||
|
||||
/// Statement holds information about a statement that will zero workgroup
|
||||
|
@ -137,10 +140,13 @@ struct ZeroInitWorkgroupMemory::State {
|
|||
auto* func = sem.Get(fn);
|
||||
for (auto* var : func->TransitivelyReferencedGlobals()) {
|
||||
if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) {
|
||||
BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) {
|
||||
auto get_expr = [&](uint32_t num_values) {
|
||||
auto var_name = ctx.Clone(var->Declaration()->symbol);
|
||||
return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
|
||||
});
|
||||
};
|
||||
if (!BuildZeroingStatements(var->Type()->UnwrapRef(), get_expr)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -283,41 +289,54 @@ struct ZeroInitWorkgroupMemory::State {
|
|||
/// initialize the workgroup storage expression of type `ty`.
|
||||
/// @param ty the expression type
|
||||
/// @param get_expr a function that builds the AST nodes for the expression.
|
||||
void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) {
|
||||
/// @returns true on success, false on failure
|
||||
[[nodiscard]] bool BuildZeroingStatements(const sem::Type* ty,
|
||||
const BuildZeroingExpr& get_expr) {
|
||||
if (CanTriviallyZero(ty)) {
|
||||
auto var = get_expr(1u);
|
||||
if (!var) {
|
||||
return false;
|
||||
}
|
||||
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
|
||||
statements.emplace_back(
|
||||
Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* atomic = ty->As<sem::Atomic>()) {
|
||||
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
|
||||
auto expr = get_expr(1u);
|
||||
if (!expr) {
|
||||
return false;
|
||||
}
|
||||
auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
|
||||
statements.emplace_back(
|
||||
Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* str = ty->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
auto name = ctx.Clone(member->Declaration()->symbol);
|
||||
BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
|
||||
auto get_member = [&](uint32_t num_values) {
|
||||
auto s = get_expr(num_values);
|
||||
if (!s) {
|
||||
return Expression{}; // error
|
||||
}
|
||||
return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
|
||||
s.array_indices};
|
||||
});
|
||||
};
|
||||
if (!BuildZeroingStatements(member->Type(), get_member)) {
|
||||
return false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* arr = ty->As<sem::Array>()) {
|
||||
BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
|
||||
auto get_el = [&](uint32_t num_values) {
|
||||
// num_values is the number of values to zero for the element type.
|
||||
// The number of iterations required to zero the array and its elements
|
||||
// is:
|
||||
// The number of iterations required to zero the array and its elements is:
|
||||
// `num_values * arr->Count()`
|
||||
// The index for this array is:
|
||||
// `(idx % modulo) / division`
|
||||
|
@ -325,22 +344,26 @@ struct ZeroInitWorkgroupMemory::State {
|
|||
if (!count) {
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||
sem::Array::kErrExpectedConstantCount);
|
||||
return Expression{};
|
||||
return Expression{}; // error
|
||||
}
|
||||
auto modulo = num_values * count.value();
|
||||
auto division = num_values;
|
||||
auto a = get_expr(modulo);
|
||||
if (!a) {
|
||||
return Expression{}; // error
|
||||
}
|
||||
auto array_indices = a.array_indices;
|
||||
array_indices.Add(ArrayIndex{modulo, division});
|
||||
auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
|
||||
[&] { return b.Symbols().New("i"); });
|
||||
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
|
||||
});
|
||||
return;
|
||||
};
|
||||
return BuildZeroingStatements(arr->ElemType(), get_el);
|
||||
}
|
||||
|
||||
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
||||
<< "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
|
||||
return false;
|
||||
}
|
||||
|
||||
/// DeclareArrayIndices returns a list of statements that contain the `let`
|
||||
|
|
|
@ -1363,5 +1363,28 @@ struct S {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ZeroInitWorkgroupMemoryTest, ArrayWithOverrideCount) {
|
||||
auto* src =
|
||||
R"(override O = 123;
|
||||
type A = array<i32, O*2>;
|
||||
|
||||
var<workgroup> W : A;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
let p : ptr<workgroup, A> = &W;
|
||||
(*p)[0] = 42;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
R"(error: array size is an override-expression, when expected a constant-expression.
|
||||
Was the SubstituteOverride transform run?)";
|
||||
|
||||
auto got = Run<ZeroInitWorkgroupMemory>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::transform
|
||||
|
|
Loading…
Reference in New Issue