tint/resolver: Materialize arguments to @workgroup_size

Bug: tint:1504
Change-Id: I69b448e62a4ebd684f6832f76fd28d8a31892a1a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91847
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-28 10:34:06 +00:00 committed by Dawn LUCI CQ
parent b59de924db
commit b8ac933909
4 changed files with 158 additions and 71 deletions

View File

@ -2546,10 +2546,20 @@ class ProgramBuilder {
} }
/// Creates an ast::WorkgroupAttribute /// Creates an ast::WorkgroupAttribute
/// @param source the source information
/// @param x the x dimension expression /// @param x the x dimension expression
/// @param y the y dimension expression /// @param y the y dimension expression
/// @returns the workgroup attribute pointer /// @returns the workgroup attribute pointer
template <typename EXPR_X, typename EXPR_Y> template <typename EXPR_X, typename EXPR_Y>
const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x, EXPR_Y&& y) {
return WorkgroupSize(source, std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
}
/// Creates an ast::WorkgroupAttribute
/// @param x the x dimension expression
/// @param y the y dimension expression
/// @returns the workgroup attribute pointer
template <typename EXPR_X, typename EXPR_Y, typename = DisableIfSource<EXPR_X>>
const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) { const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr); return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
} }
@ -2575,7 +2585,7 @@ class ProgramBuilder {
/// @param y the y dimension expression /// @param y the y dimension expression
/// @param z the z dimension expression /// @param z the z dimension expression
/// @returns the workgroup attribute pointer /// @returns the workgroup attribute pointer
template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z> template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z, typename = DisableIfSource<EXPR_X>>
const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) { const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
return create<ast::WorkgroupAttribute>(source_, Expr(std::forward<EXPR_X>(x)), return create<ast::WorkgroupAttribute>(source_, Expr(std::forward<EXPR_X>(x)),
Expr(std::forward<EXPR_Y>(y)), Expr(std::forward<EXPR_Y>(y)),

View File

@ -429,9 +429,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
// fn main() {} // fn main() {}
auto* x = GlobalConst("x", ty.u32(), Expr(4_u)); auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
auto* y = GlobalConst("y", ty.u32(), Expr(8_u)); auto* y = GlobalConst("y", ty.u32(), Expr(8_u));
auto* func = Func( auto* func = Func("main", {}, ty.void_(), {},
"main", {}, ty.void_(), {}, {Stage(ast::PipelineStage::kCompute), WorkgroupSize("x", "y", 16_u)});
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Expr("x"), Expr("y"), Expr(16_u))});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -447,43 +446,68 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y)); EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
// @stage(compute) @workgroup_size(1i, 2i, 3i)
// fn main() {}
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_i, 3_i)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
// @stage(compute) @workgroup_size(1u, 2u, 3u) // @stage(compute) @workgroup_size(1u, 2u, 3u)
// fn main() {} // fn main() {}
Func("main", {}, ty.void_(), {}, Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_u, 3_u)});
WorkgroupSize(Source{{12, 34}}, Expr(1_u), Expr(2_u), Expr(3_u))});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32_AInt) {
// @stage(compute) @workgroup_size(1u, 2u, 3_i) // @stage(compute) @workgroup_size(1, 2i, 3)
// fn main() {} // fn main() {}
Func("main", {}, ty.void_(), {}, Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_a, 2_i, 3_a)});
WorkgroupSize(Expr(1_u), Expr(2_u), Expr(Source{{12, 34}}, 3_i))});
EXPECT_FALSE(r()->Resolve()); ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size arguments must be of the same type, "
"either i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
// @stage(compute) @workgroup_size(1_i, 2u, 3_i) // @stage(compute) @workgroup_size(1u, 2, 3u)
// fn main() {} // fn main() {}
Func("main", {}, ty.void_(), {}, Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_u)});
WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, 2_u), Expr(3_i))});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
// @stage(compute) @workgroup_size(1u, 2, 3_i)
// fn main() {}
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_i)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size arguments must be of the same type, " "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
"either i32 or u32"); }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_I32) {
// @stage(compute) @workgroup_size(1_i, 2u, 3)
// fn main() {}
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_u, 3_a)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
@ -492,13 +516,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
// fn main() {} // fn main() {}
GlobalConst("x", ty.u32(), Expr(64_u)); GlobalConst("x", ty.u32(), Expr(64_u));
Func("main", {}, ty.void_(), {}, Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, "x")});
WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, "x"))});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size arguments must be of the same type, " "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
"either i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
@ -509,13 +531,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
GlobalConst("x", ty.u32(), Expr(64_u)); GlobalConst("x", ty.u32(), Expr(64_u));
GlobalConst("y", ty.i32(), Expr(32_i)); GlobalConst("y", ty.i32(), Expr(32_i));
Func("main", {}, ty.void_(), {}, Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y")});
WorkgroupSize(Expr("x"), Expr(Source{{12, 34}}, "y"))});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size arguments must be of the same type, " "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
"either i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
// let x = 4u; // let x = 4u;
@ -525,13 +545,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
GlobalConst("x", ty.u32(), Expr(4_u)); GlobalConst("x", ty.u32(), Expr(4_u));
GlobalConst("y", ty.u32(), Expr(8_u)); GlobalConst("y", ty.u32(), Expr(8_u));
Func("main", {}, ty.void_(), {}, Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y", 16_i)});
WorkgroupSize(Expr("x"), Expr("y"), Expr(Source{{12, 34}}, 16_i))});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size arguments must be of the same type, " "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
"either i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {

View File

@ -134,6 +134,11 @@ enum class Method {
// default: {} // default: {}
// } // }
kSwitchCaseWithAbstractCase, kSwitchCaseWithAbstractCase,
// @workgroup_size(target_expr, abstract_expr, 123)
// @stage(compute)
// fn f() {}
kWorkgroupSize
}; };
static std::ostream& operator<<(std::ostream& o, Method m) { static std::ostream& operator<<(std::ostream& o, Method m) {
@ -162,6 +167,8 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
return o << "switch-cond-with-abstract"; return o << "switch-cond-with-abstract";
case Method::kSwitchCaseWithAbstractCase: case Method::kSwitchCaseWithAbstractCase:
return o << "switch-case-with-abstract"; return o << "switch-case-with-abstract";
case Method::kWorkgroupSize:
return o << "workgroup-size";
} }
return o << "<unknown>"; return o << "<unknown>";
} }
@ -286,6 +293,11 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) {
Case(abstract_expr->As<ast::IntLiteralExpression>()), // Case(abstract_expr->As<ast::IntLiteralExpression>()), //
DefaultCase())); DefaultCase()));
break; break;
case Method::kWorkgroupSize:
Func("f", {}, ty.void_(), {},
{WorkgroupSize(target_expr(), abstract_expr, Expr(123_a)),
Stage(ast::PipelineStage::kCompute)});
break;
} }
auto check_types_and_values = [&](const sem::Expression* expr) { auto check_types_and_values = [&](const sem::Expression* expr) {
@ -461,6 +473,19 @@ INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
Types<u32, AInt>(AInt(kLowestU32), kLowestU32), // Types<u32, AInt>(AInt(kLowestU32), kLowestU32), //
}))); })));
INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
MaterializeAbstractNumericToConcreteType,
testing::Combine(testing::Values(Expectation::kMaterialize),
testing::Values(Method::kWorkgroupSize),
testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(1_a, 1.0), //
Types<i32, AInt>(10_a, 10.0), //
Types<i32, AInt>(65535_a, 65535.0), //
Types<u32, AInt>(1_a, 1.0), //
Types<u32, AInt>(10_a, 10.0), //
Types<u32, AInt>(65535_a, 65535.0), //
})));
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops. // TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize, INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
MaterializeAbstractNumericToConcreteType, MaterializeAbstractNumericToConcreteType,
@ -558,6 +583,11 @@ enum class Method {
// default: {} // default: {}
// } // }
kSwitch, kSwitch,
// @workgroup_size(abstract_expr)
// @stage(compute)
// fn f() {}
kWorkgroupSize
}; };
static std::ostream& operator<<(std::ostream& o, Method m) { static std::ostream& operator<<(std::ostream& o, Method m) {
@ -576,6 +606,8 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
return o << "array-length"; return o << "array-length";
case Method::kSwitch: case Method::kSwitch:
return o << "switch"; return o << "switch";
case Method::kWorkgroupSize:
return o << "workgroup-size";
} }
return o << "<unknown>"; return o << "<unknown>";
} }
@ -656,6 +688,10 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
Case(abstract_expr()->As<ast::IntLiteralExpression>()), Case(abstract_expr()->As<ast::IntLiteralExpression>()),
DefaultCase())); DefaultCase()));
break; break;
case Method::kWorkgroupSize:
Func("f", {}, ty.void_(), {},
{WorkgroupSize(abstract_expr()), Stage(ast::PipelineStage::kCompute)});
break;
} }
auto check_types_and_values = [&](const sem::Expression* expr) { auto check_types_and_values = [&](const sem::Expression* expr) {
@ -734,11 +770,6 @@ constexpr Method kMatrixMethods[] = {
Method::kVar, Method::kVar,
}; };
/// Methods that support materialization for switch cases
constexpr Method kSwitchMethods[] = {
Method::kSwitch,
};
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
MaterializeScalar, MaterializeScalar,
MaterializeAbstractNumericToDefaultType, MaterializeAbstractNumericToDefaultType,
@ -798,13 +829,23 @@ INSTANTIATE_TEST_SUITE_P(
INSTANTIATE_TEST_SUITE_P(MaterializeSwitch, INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
MaterializeAbstractNumericToDefaultType, MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kMaterialize), testing::Combine(testing::Values(Expectation::kMaterialize),
testing::ValuesIn(kSwitchMethods), testing::Values(Method::kSwitch),
testing::ValuesIn(std::vector<Data>{ testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(0_a, 0.0), // Types<i32, AInt>(0_a, 0.0), //
Types<i32, AInt>(AInt(kHighestI32), kHighestI32), // Types<i32, AInt>(AInt(kHighestI32), kHighestI32), //
Types<i32, AInt>(AInt(kLowestI32), kLowestI32), // Types<i32, AInt>(AInt(kLowestI32), kLowestI32), //
}))); })));
INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kMaterialize),
testing::Values(Method::kWorkgroupSize),
testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(1_a, 1.0), //
Types<i32, AInt>(10_a, 10.0), //
Types<i32, AInt>(65535_a, 65535.0), //
})));
INSTANTIATE_TEST_SUITE_P(ScalarValueCannotBeRepresented, INSTANTIATE_TEST_SUITE_P(ScalarValueCannotBeRepresented,
MaterializeAbstractNumericToDefaultType, MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
@ -840,7 +881,16 @@ INSTANTIATE_TEST_SUITE_P(MatrixValueCannotBeRepresented,
INSTANTIATE_TEST_SUITE_P(SwitchValueCannotBeRepresented, INSTANTIATE_TEST_SUITE_P(SwitchValueCannotBeRepresented,
MaterializeAbstractNumericToDefaultType, MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
testing::ValuesIn(kSwitchMethods), testing::Values(Method::kSwitch),
testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(0_a, kHighestI32 + 1), //
Types<i32, AInt>(0_a, kLowestI32 - 1), //
})));
INSTANTIATE_TEST_SUITE_P(WorkgroupSizeValueCannotBeRepresented,
MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
testing::Values(Method::kWorkgroupSize),
testing::ValuesIn(std::vector<Data>{ testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(0_a, kHighestI32 + 1), // Types<i32, AInt>(0_a, kHighestI32 + 1), //
Types<i32, AInt>(0_a, kLowestI32 - 1), // Types<i32, AInt>(0_a, kLowestI32 - 1), //

View File

@ -721,52 +721,61 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
} }
auto values = attr->Values(); auto values = attr->Values();
auto any_i32 = false; std::array<const sem::Expression*, 3> args = {};
auto any_u32 = false; std::array<const sem::Type*, 3> arg_tys = {};
for (int i = 0; i < 3; i++) { size_t arg_count = 0;
// Each argument to this attribute can either be a literal, an
// identifier for a module-scope constants, or nullptr if not specified.
auto* expr = values[i];
if (!expr) {
// Not specified, just use the default.
continue;
}
auto* expr_sem = Expression(expr);
if (!expr_sem) {
return false;
}
constexpr const char* kErrBadType = constexpr const char* kErrBadType =
"workgroup_size argument must be either literal or module-scope " "workgroup_size argument must be either literal or module-scope constant of type i32 "
"constant of type i32 or u32";
constexpr const char* kErrInconsistentType =
"workgroup_size arguments must be of the same type, either i32 "
"or u32"; "or u32";
auto* ty = sem_.TypeOf(expr); for (int i = 0; i < 3; i++) {
bool is_i32 = ty->UnwrapRef()->Is<sem::I32>(); // Each argument to this attribute can either be a literal, an identifier for a module-scope
bool is_u32 = ty->UnwrapRef()->Is<sem::U32>(); // constants, or nullptr if not specified.
if (!is_i32 && !is_u32) { auto* value = values[i];
AddError(kErrBadType, expr->source); if (!value) {
break;
}
const auto* expr = Expression(value);
if (!expr) {
return false;
}
auto* ty = expr->Type();
if (!ty->IsAnyOf<sem::I32, sem::U32, sem::AbstractInt>()) {
AddError(kErrBadType, value->source);
return false; return false;
} }
any_i32 = any_i32 || is_i32; args[i] = expr;
any_u32 = any_u32 || is_u32; arg_tys[i] = ty;
if (any_i32 && any_u32) { arg_count++;
AddError(kErrInconsistentType, expr->source); }
auto* common_ty = sem::Type::Common(arg_tys.data(), arg_count);
if (!common_ty) {
AddError("workgroup_size arguments must be of the same type, either i32 or u32",
attr->source);
return false;
}
// If all arguments are abstract-integers, then materialize to i32.
if (common_ty->Is<sem::AbstractInt>()) {
common_ty = builder_->create<sem::I32>();
}
for (size_t i = 0; i < arg_count; i++) {
auto* materialized = Materialize(args[i], common_ty);
if (!materialized) {
return false; return false;
} }
sem::Constant value; sem::Constant value;
if (auto* user = sem_.Get(expr)->As<sem::VariableUser>()) { if (auto* user = args[i]->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant. // We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration(); auto* decl = user->Variable()->Declaration();
if (!decl->is_const) { if (!decl->is_const) {
AddError(kErrBadType, expr->source); AddError(kErrBadType, values[i]->source);
return false; return false;
} }
// Capture the constant if it is pipeline-overridable. // Capture the constant if it is pipeline-overridable.
@ -781,8 +790,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
ws[i].value = 0; ws[i].value = 0;
continue; continue;
} }
} else if (expr->Is<ast::LiteralExpression>()) { } else if (values[i]->Is<ast::LiteralExpression>()) {
value = sem_.Get(expr)->ConstantValue(); value = materialized->ConstantValue();
} else { } else {
AddError( AddError(
"workgroup_size argument must be either a literal or a " "workgroup_size argument must be either a literal or a "