mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-05-16 20:31:20 +00:00
tint/transform: Fix NPE in ZeroInitWorkgroupMemory.
If an array uses an override expression, then we'd raise an error, but then attempt to dereference a nullptr. Bug: chromium:1392853 Change-Id: Ib1d538bc491923b628b32f2398f8b2ace24c3bc3 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112561 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
25c0bdf2a9
commit
7423496da6
@ -100,6 +100,9 @@ struct ZeroInitWorkgroupMemory::State {
|
|||||||
uint32_t num_iterations = 0;
|
uint32_t num_iterations = 0;
|
||||||
/// All array indices used by this expression
|
/// All array indices used by this expression
|
||||||
ArrayIndices array_indices;
|
ArrayIndices array_indices;
|
||||||
|
|
||||||
|
/// @returns true if the expr is not null (null usually indicates a failure)
|
||||||
|
operator bool() const { return expr != nullptr; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Statement holds information about a statement that will zero workgroup
|
/// Statement holds information about a statement that will zero workgroup
|
||||||
@ -137,10 +140,13 @@ struct ZeroInitWorkgroupMemory::State {
|
|||||||
auto* func = sem.Get(fn);
|
auto* func = sem.Get(fn);
|
||||||
for (auto* var : func->TransitivelyReferencedGlobals()) {
|
for (auto* var : func->TransitivelyReferencedGlobals()) {
|
||||||
if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) {
|
if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) {
|
||||||
BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) {
|
auto get_expr = [&](uint32_t num_values) {
|
||||||
auto var_name = ctx.Clone(var->Declaration()->symbol);
|
auto var_name = ctx.Clone(var->Declaration()->symbol);
|
||||||
return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
|
return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
|
||||||
});
|
};
|
||||||
|
if (!BuildZeroingStatements(var->Type()->UnwrapRef(), get_expr)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,41 +289,54 @@ struct ZeroInitWorkgroupMemory::State {
|
|||||||
/// initialize the workgroup storage expression of type `ty`.
|
/// initialize the workgroup storage expression of type `ty`.
|
||||||
/// @param ty the expression type
|
/// @param ty the expression type
|
||||||
/// @param get_expr a function that builds the AST nodes for the expression.
|
/// @param get_expr a function that builds the AST nodes for the expression.
|
||||||
void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) {
|
/// @returns true on success, false on failure
|
||||||
|
[[nodiscard]] bool BuildZeroingStatements(const sem::Type* ty,
|
||||||
|
const BuildZeroingExpr& get_expr) {
|
||||||
if (CanTriviallyZero(ty)) {
|
if (CanTriviallyZero(ty)) {
|
||||||
auto var = get_expr(1u);
|
auto var = get_expr(1u);
|
||||||
|
if (!var) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
|
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
|
||||||
statements.emplace_back(
|
statements.emplace_back(
|
||||||
Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
|
Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
|
||||||
return;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* atomic = ty->As<sem::Atomic>()) {
|
if (auto* atomic = ty->As<sem::Atomic>()) {
|
||||||
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
|
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
|
||||||
auto expr = get_expr(1u);
|
auto expr = get_expr(1u);
|
||||||
|
if (!expr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
|
auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
|
||||||
statements.emplace_back(
|
statements.emplace_back(
|
||||||
Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
|
Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
|
||||||
return;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* str = ty->As<sem::Struct>()) {
|
if (auto* str = ty->As<sem::Struct>()) {
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
auto name = ctx.Clone(member->Declaration()->symbol);
|
auto name = ctx.Clone(member->Declaration()->symbol);
|
||||||
BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
|
auto get_member = [&](uint32_t num_values) {
|
||||||
auto s = get_expr(num_values);
|
auto s = get_expr(num_values);
|
||||||
|
if (!s) {
|
||||||
|
return Expression{}; // error
|
||||||
|
}
|
||||||
return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
|
return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
|
||||||
s.array_indices};
|
s.array_indices};
|
||||||
});
|
};
|
||||||
|
if (!BuildZeroingStatements(member->Type(), get_member)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* arr = ty->As<sem::Array>()) {
|
if (auto* arr = ty->As<sem::Array>()) {
|
||||||
BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
|
auto get_el = [&](uint32_t num_values) {
|
||||||
// num_values is the number of values to zero for the element type.
|
// num_values is the number of values to zero for the element type.
|
||||||
// The number of iterations required to zero the array and its elements
|
// The number of iterations required to zero the array and its elements is:
|
||||||
// is:
|
|
||||||
// `num_values * arr->Count()`
|
// `num_values * arr->Count()`
|
||||||
// The index for this array is:
|
// The index for this array is:
|
||||||
// `(idx % modulo) / division`
|
// `(idx % modulo) / division`
|
||||||
@ -325,22 +344,26 @@ struct ZeroInitWorkgroupMemory::State {
|
|||||||
if (!count) {
|
if (!count) {
|
||||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||||
sem::Array::kErrExpectedConstantCount);
|
sem::Array::kErrExpectedConstantCount);
|
||||||
return Expression{};
|
return Expression{}; // error
|
||||||
}
|
}
|
||||||
auto modulo = num_values * count.value();
|
auto modulo = num_values * count.value();
|
||||||
auto division = num_values;
|
auto division = num_values;
|
||||||
auto a = get_expr(modulo);
|
auto a = get_expr(modulo);
|
||||||
|
if (!a) {
|
||||||
|
return Expression{}; // error
|
||||||
|
}
|
||||||
auto array_indices = a.array_indices;
|
auto array_indices = a.array_indices;
|
||||||
array_indices.Add(ArrayIndex{modulo, division});
|
array_indices.Add(ArrayIndex{modulo, division});
|
||||||
auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
|
auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
|
||||||
[&] { return b.Symbols().New("i"); });
|
[&] { return b.Symbols().New("i"); });
|
||||||
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
|
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
|
||||||
});
|
};
|
||||||
return;
|
return BuildZeroingStatements(arr->ElemType(), get_el);
|
||||||
}
|
}
|
||||||
|
|
||||||
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
||||||
<< "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
|
<< "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// DeclareArrayIndices returns a list of statements that contain the `let`
|
/// DeclareArrayIndices returns a list of statements that contain the `let`
|
||||||
|
@ -1363,5 +1363,28 @@ struct S {
|
|||||||
EXPECT_EQ(expect, str(got));
|
EXPECT_EQ(expect, str(got));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ZeroInitWorkgroupMemoryTest, ArrayWithOverrideCount) {
|
||||||
|
auto* src =
|
||||||
|
R"(override O = 123;
|
||||||
|
type A = array<i32, O*2>;
|
||||||
|
|
||||||
|
var<workgroup> W : A;
|
||||||
|
|
||||||
|
@compute @workgroup_size(1)
|
||||||
|
fn main() {
|
||||||
|
let p : ptr<workgroup, A> = &W;
|
||||||
|
(*p)[0] = 42;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto* expect =
|
||||||
|
R"(error: array size is an override-expression, when expected a constant-expression.
|
||||||
|
Was the SubstituteOverride transform run?)";
|
||||||
|
|
||||||
|
auto got = Run<ZeroInitWorkgroupMemory>(src);
|
||||||
|
|
||||||
|
EXPECT_EQ(expect, str(got));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tint::transform
|
} // namespace tint::transform
|
||||||
|
Loading…
x
Reference in New Issue
Block a user