tint/resolver: Materialize arguments to @workgroup_size
Bug: tint:1504 Change-Id: I69b448e62a4ebd684f6832f76fd28d8a31892a1a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91847 Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
b59de924db
commit
b8ac933909
|
@ -2546,10 +2546,20 @@ class ProgramBuilder {
|
|||
}
|
||||
|
||||
/// Creates an ast::WorkgroupAttribute
|
||||
/// @param source the source information
|
||||
/// @param x the x dimension expression
|
||||
/// @param y the y dimension expression
|
||||
/// @returns the workgroup attribute pointer
|
||||
template <typename EXPR_X, typename EXPR_Y>
|
||||
const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x, EXPR_Y&& y) {
|
||||
return WorkgroupSize(source, std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
|
||||
}
|
||||
|
||||
/// Creates an ast::WorkgroupAttribute
|
||||
/// @param x the x dimension expression
|
||||
/// @param y the y dimension expression
|
||||
/// @returns the workgroup attribute pointer
|
||||
template <typename EXPR_X, typename EXPR_Y, typename = DisableIfSource<EXPR_X>>
|
||||
const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
|
||||
return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
|
||||
}
|
||||
|
@ -2575,7 +2585,7 @@ class ProgramBuilder {
|
|||
/// @param y the y dimension expression
|
||||
/// @param z the z dimension expression
|
||||
/// @returns the workgroup attribute pointer
|
||||
template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z>
|
||||
template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z, typename = DisableIfSource<EXPR_X>>
|
||||
const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
|
||||
return create<ast::WorkgroupAttribute>(source_, Expr(std::forward<EXPR_X>(x)),
|
||||
Expr(std::forward<EXPR_Y>(y)),
|
||||
|
|
|
@ -429,9 +429,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
|
|||
// fn main() {}
|
||||
auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
|
||||
auto* y = GlobalConst("y", ty.u32(), Expr(8_u));
|
||||
auto* func = Func(
|
||||
"main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Expr("x"), Expr("y"), Expr(16_u))});
|
||||
auto* func = Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize("x", "y", 16_u)});
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
|
@ -447,43 +446,68 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
|
|||
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
|
||||
// @stage(compute) @workgroup_size(1i, 2i, 3i)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_i, 3_i)});
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
|
||||
// @stage(compute) @workgroup_size(1u, 2u, 3u)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Source{{12, 34}}, Expr(1_u), Expr(2_u), Expr(3_u))});
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_u, 3_u)});
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) {
|
||||
// @stage(compute) @workgroup_size(1u, 2u, 3_i)
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32_AInt) {
|
||||
// @stage(compute) @workgroup_size(1, 2i, 3)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Expr(1_u), Expr(2_u), Expr(Source{{12, 34}}, 3_i))});
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_a, 2_i, 3_a)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size arguments must be of the same type, "
|
||||
"either i32 or u32");
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) {
|
||||
// @stage(compute) @workgroup_size(1_i, 2u, 3_i)
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
|
||||
// @stage(compute) @workgroup_size(1u, 2, 3u)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, 2_u), Expr(3_i))});
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_u)});
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
|
||||
// @stage(compute) @workgroup_size(1u, 2, 3_i)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_i)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size arguments must be of the same type, "
|
||||
"either i32 or u32");
|
||||
"12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_I32) {
|
||||
// @stage(compute) @workgroup_size(1_i, 2u, 3)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_u, 3_a)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
|
||||
|
@ -492,13 +516,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
|
|||
// fn main() {}
|
||||
GlobalConst("x", ty.u32(), Expr(64_u));
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, "x"))});
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, "x")});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size arguments must be of the same type, "
|
||||
"either i32 or u32");
|
||||
"12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
|
||||
|
@ -509,13 +531,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
|
|||
GlobalConst("x", ty.u32(), Expr(64_u));
|
||||
GlobalConst("y", ty.i32(), Expr(32_i));
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Expr("x"), Expr(Source{{12, 34}}, "y"))});
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y")});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size arguments must be of the same type, "
|
||||
"either i32 or u32");
|
||||
"12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
|
||||
}
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
|
||||
// let x = 4u;
|
||||
|
@ -525,13 +545,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
|
|||
GlobalConst("x", ty.u32(), Expr(4_u));
|
||||
GlobalConst("y", ty.u32(), Expr(8_u));
|
||||
Func("main", {}, ty.void_(), {},
|
||||
{Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Expr("x"), Expr("y"), Expr(Source{{12, 34}}, 16_i))});
|
||||
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y", 16_i)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size arguments must be of the same type, "
|
||||
"either i32 or u32");
|
||||
"12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
|
||||
|
|
|
@ -134,6 +134,11 @@ enum class Method {
|
|||
// default: {}
|
||||
// }
|
||||
kSwitchCaseWithAbstractCase,
|
||||
|
||||
// @workgroup_size(target_expr, abstract_expr, 123)
|
||||
// @stage(compute)
|
||||
// fn f() {}
|
||||
kWorkgroupSize
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, Method m) {
|
||||
|
@ -162,6 +167,8 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
|
|||
return o << "switch-cond-with-abstract";
|
||||
case Method::kSwitchCaseWithAbstractCase:
|
||||
return o << "switch-case-with-abstract";
|
||||
case Method::kWorkgroupSize:
|
||||
return o << "workgroup-size";
|
||||
}
|
||||
return o << "<unknown>";
|
||||
}
|
||||
|
@ -286,6 +293,11 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) {
|
|||
Case(abstract_expr->As<ast::IntLiteralExpression>()), //
|
||||
DefaultCase()));
|
||||
break;
|
||||
case Method::kWorkgroupSize:
|
||||
Func("f", {}, ty.void_(), {},
|
||||
{WorkgroupSize(target_expr(), abstract_expr, Expr(123_a)),
|
||||
Stage(ast::PipelineStage::kCompute)});
|
||||
break;
|
||||
}
|
||||
|
||||
auto check_types_and_values = [&](const sem::Expression* expr) {
|
||||
|
@ -461,6 +473,19 @@ INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
|
|||
Types<u32, AInt>(AInt(kLowestU32), kLowestU32), //
|
||||
})));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
|
||||
MaterializeAbstractNumericToConcreteType,
|
||||
testing::Combine(testing::Values(Expectation::kMaterialize),
|
||||
testing::Values(Method::kWorkgroupSize),
|
||||
testing::ValuesIn(std::vector<Data>{
|
||||
Types<i32, AInt>(1_a, 1.0), //
|
||||
Types<i32, AInt>(10_a, 10.0), //
|
||||
Types<i32, AInt>(65535_a, 65535.0), //
|
||||
Types<u32, AInt>(1_a, 1.0), //
|
||||
Types<u32, AInt>(10_a, 10.0), //
|
||||
Types<u32, AInt>(65535_a, 65535.0), //
|
||||
})));
|
||||
|
||||
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
|
||||
INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
|
||||
MaterializeAbstractNumericToConcreteType,
|
||||
|
@ -558,6 +583,11 @@ enum class Method {
|
|||
// default: {}
|
||||
// }
|
||||
kSwitch,
|
||||
|
||||
// @workgroup_size(abstract_expr)
|
||||
// @stage(compute)
|
||||
// fn f() {}
|
||||
kWorkgroupSize
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, Method m) {
|
||||
|
@ -576,6 +606,8 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
|
|||
return o << "array-length";
|
||||
case Method::kSwitch:
|
||||
return o << "switch";
|
||||
case Method::kWorkgroupSize:
|
||||
return o << "workgroup-size";
|
||||
}
|
||||
return o << "<unknown>";
|
||||
}
|
||||
|
@ -656,6 +688,10 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
|
|||
Case(abstract_expr()->As<ast::IntLiteralExpression>()),
|
||||
DefaultCase()));
|
||||
break;
|
||||
case Method::kWorkgroupSize:
|
||||
Func("f", {}, ty.void_(), {},
|
||||
{WorkgroupSize(abstract_expr()), Stage(ast::PipelineStage::kCompute)});
|
||||
break;
|
||||
}
|
||||
|
||||
auto check_types_and_values = [&](const sem::Expression* expr) {
|
||||
|
@ -734,11 +770,6 @@ constexpr Method kMatrixMethods[] = {
|
|||
Method::kVar,
|
||||
};
|
||||
|
||||
/// Methods that support materialization for switch cases
|
||||
constexpr Method kSwitchMethods[] = {
|
||||
Method::kSwitch,
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
MaterializeScalar,
|
||||
MaterializeAbstractNumericToDefaultType,
|
||||
|
@ -798,13 +829,23 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
|
||||
MaterializeAbstractNumericToDefaultType,
|
||||
testing::Combine(testing::Values(Expectation::kMaterialize),
|
||||
testing::ValuesIn(kSwitchMethods),
|
||||
testing::Values(Method::kSwitch),
|
||||
testing::ValuesIn(std::vector<Data>{
|
||||
Types<i32, AInt>(0_a, 0.0), //
|
||||
Types<i32, AInt>(AInt(kHighestI32), kHighestI32), //
|
||||
Types<i32, AInt>(AInt(kLowestI32), kLowestI32), //
|
||||
})));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
|
||||
MaterializeAbstractNumericToDefaultType,
|
||||
testing::Combine(testing::Values(Expectation::kMaterialize),
|
||||
testing::Values(Method::kWorkgroupSize),
|
||||
testing::ValuesIn(std::vector<Data>{
|
||||
Types<i32, AInt>(1_a, 1.0), //
|
||||
Types<i32, AInt>(10_a, 10.0), //
|
||||
Types<i32, AInt>(65535_a, 65535.0), //
|
||||
})));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ScalarValueCannotBeRepresented,
|
||||
MaterializeAbstractNumericToDefaultType,
|
||||
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
|
||||
|
@ -840,7 +881,16 @@ INSTANTIATE_TEST_SUITE_P(MatrixValueCannotBeRepresented,
|
|||
INSTANTIATE_TEST_SUITE_P(SwitchValueCannotBeRepresented,
|
||||
MaterializeAbstractNumericToDefaultType,
|
||||
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
|
||||
testing::ValuesIn(kSwitchMethods),
|
||||
testing::Values(Method::kSwitch),
|
||||
testing::ValuesIn(std::vector<Data>{
|
||||
Types<i32, AInt>(0_a, kHighestI32 + 1), //
|
||||
Types<i32, AInt>(0_a, kLowestI32 - 1), //
|
||||
})));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(WorkgroupSizeValueCannotBeRepresented,
|
||||
MaterializeAbstractNumericToDefaultType,
|
||||
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
|
||||
testing::Values(Method::kWorkgroupSize),
|
||||
testing::ValuesIn(std::vector<Data>{
|
||||
Types<i32, AInt>(0_a, kHighestI32 + 1), //
|
||||
Types<i32, AInt>(0_a, kLowestI32 - 1), //
|
||||
|
|
|
@ -721,52 +721,61 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
|
|||
}
|
||||
|
||||
auto values = attr->Values();
|
||||
auto any_i32 = false;
|
||||
auto any_u32 = false;
|
||||
std::array<const sem::Expression*, 3> args = {};
|
||||
std::array<const sem::Type*, 3> arg_tys = {};
|
||||
size_t arg_count = 0;
|
||||
|
||||
constexpr const char* kErrBadType =
|
||||
"workgroup_size argument must be either literal or module-scope constant of type i32 "
|
||||
"or u32";
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
// Each argument to this attribute can either be a literal, an
|
||||
// identifier for a module-scope constants, or nullptr if not specified.
|
||||
|
||||
auto* expr = values[i];
|
||||
// Each argument to this attribute can either be a literal, an identifier for a module-scope
|
||||
// constants, or nullptr if not specified.
|
||||
auto* value = values[i];
|
||||
if (!value) {
|
||||
break;
|
||||
}
|
||||
const auto* expr = Expression(value);
|
||||
if (!expr) {
|
||||
// Not specified, just use the default.
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* expr_sem = Expression(expr);
|
||||
if (!expr_sem) {
|
||||
auto* ty = expr->Type();
|
||||
if (!ty->IsAnyOf<sem::I32, sem::U32, sem::AbstractInt>()) {
|
||||
AddError(kErrBadType, value->source);
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr const char* kErrBadType =
|
||||
"workgroup_size argument must be either literal or module-scope "
|
||||
"constant of type i32 or u32";
|
||||
constexpr const char* kErrInconsistentType =
|
||||
"workgroup_size arguments must be of the same type, either i32 "
|
||||
"or u32";
|
||||
args[i] = expr;
|
||||
arg_tys[i] = ty;
|
||||
arg_count++;
|
||||
}
|
||||
|
||||
auto* ty = sem_.TypeOf(expr);
|
||||
bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
|
||||
bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
|
||||
if (!is_i32 && !is_u32) {
|
||||
AddError(kErrBadType, expr->source);
|
||||
return false;
|
||||
}
|
||||
auto* common_ty = sem::Type::Common(arg_tys.data(), arg_count);
|
||||
if (!common_ty) {
|
||||
AddError("workgroup_size arguments must be of the same type, either i32 or u32",
|
||||
attr->source);
|
||||
return false;
|
||||
}
|
||||
|
||||
any_i32 = any_i32 || is_i32;
|
||||
any_u32 = any_u32 || is_u32;
|
||||
if (any_i32 && any_u32) {
|
||||
AddError(kErrInconsistentType, expr->source);
|
||||
// If all arguments are abstract-integers, then materialize to i32.
|
||||
if (common_ty->Is<sem::AbstractInt>()) {
|
||||
common_ty = builder_->create<sem::I32>();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < arg_count; i++) {
|
||||
auto* materialized = Materialize(args[i], common_ty);
|
||||
if (!materialized) {
|
||||
return false;
|
||||
}
|
||||
|
||||
sem::Constant value;
|
||||
|
||||
if (auto* user = sem_.Get(expr)->As<sem::VariableUser>()) {
|
||||
if (auto* user = args[i]->As<sem::VariableUser>()) {
|
||||
// We have an variable of a module-scope constant.
|
||||
auto* decl = user->Variable()->Declaration();
|
||||
if (!decl->is_const) {
|
||||
AddError(kErrBadType, expr->source);
|
||||
AddError(kErrBadType, values[i]->source);
|
||||
return false;
|
||||
}
|
||||
// Capture the constant if it is pipeline-overridable.
|
||||
|
@ -781,8 +790,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
|
|||
ws[i].value = 0;
|
||||
continue;
|
||||
}
|
||||
} else if (expr->Is<ast::LiteralExpression>()) {
|
||||
value = sem_.Get(expr)->ConstantValue();
|
||||
} else if (values[i]->Is<ast::LiteralExpression>()) {
|
||||
value = materialized->ConstantValue();
|
||||
} else {
|
||||
AddError(
|
||||
"workgroup_size argument must be either a literal or a "
|
||||
|
|
Loading…
Reference in New Issue