diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 38970a5338..5f47efb016 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -2546,10 +2546,20 @@ class ProgramBuilder { } /// Creates an ast::WorkgroupAttribute + /// @param source the source information /// @param x the x dimension expression /// @param y the y dimension expression /// @returns the workgroup attribute pointer template + const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x, EXPR_Y&& y) { + return WorkgroupSize(source, std::forward(x), std::forward(y), nullptr); + } + + /// Creates an ast::WorkgroupAttribute + /// @param x the x dimension expression + /// @param y the y dimension expression + /// @returns the workgroup attribute pointer + template > const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) { return WorkgroupSize(std::forward(x), std::forward(y), nullptr); } @@ -2575,7 +2585,7 @@ class ProgramBuilder { /// @param y the y dimension expression /// @param z the z dimension expression /// @returns the workgroup attribute pointer - template + template > const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) { return create(source_, Expr(std::forward(x)), Expr(std::forward(y)), diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index ce37ba55ca..317cadc265 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -429,9 +429,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) { // fn main() {} auto* x = GlobalConst("x", ty.u32(), Expr(4_u)); auto* y = GlobalConst("y", ty.u32(), Expr(8_u)); - auto* func = Func( - "main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Expr("x"), Expr("y"), Expr(16_u))}); + auto* func = Func("main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize("x", "y", 16_u)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -447,43 +446,68 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) { EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y)); } +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) { + // @stage(compute) @workgroup_size(1i, 2i, 3i) + // fn main() {} + + Func("main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_i, 3_i)}); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) { // @stage(compute) @workgroup_size(1u, 2u, 3u) // fn main() {} Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Source{{12, 34}}, Expr(1_u), Expr(2_u), Expr(3_u))}); + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_u, 3_u)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) { - // @stage(compute) @workgroup_size(1u, 2u, 3_i) +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32_AInt) { + // @stage(compute) @workgroup_size(1, 2i, 3) // fn main() {} Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Expr(1_u), Expr(2_u), Expr(Source{{12, 34}}, 3_i))}); + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_a, 2_i, 3_a)}); - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size arguments must be of the same type, " - "either i32 or u32"); + ASSERT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) { - // @stage(compute) @workgroup_size(1_i, 2u, 3_i) +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) { + // @stage(compute) @workgroup_size(1u, 2, 3u) // fn main() {} Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, 2_u), Expr(3_i))}); + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_u)}); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) { + // @stage(compute) @workgroup_size(1u, 2, 3_i) + // fn main() {} + + Func("main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_i)}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size arguments must be of the same type, " - "either i32 or u32"); + "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32"); +} + +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_I32) { + // @stage(compute) @workgroup_size(1_i, 2u, 3) + // fn main() {} + + Func("main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_u, 3_a)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) { @@ -492,13 +516,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) { // fn main() {} GlobalConst("x", ty.u32(), Expr(64_u)); Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, "x"))}); + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, "x")}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size arguments must be of the same type, " - "either i32 or u32"); + "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) { @@ -509,13 +531,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) { GlobalConst("x", ty.u32(), Expr(64_u)); GlobalConst("y", ty.i32(), Expr(32_i)); Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Expr("x"), Expr(Source{{12, 34}}, "y"))}); + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y")}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size arguments must be of the same type, " - "either i32 or u32"); + "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) { // let x = 4u; @@ -525,13 +545,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) { GlobalConst("x", ty.u32(), Expr(4_u)); GlobalConst("y", ty.u32(), Expr(8_u)); Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Expr("x"), Expr("y"), Expr(Source{{12, 34}}, 16_i))}); + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y", 16_i)}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size arguments must be of the same type, " - "either i32 or u32"); + "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) { diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index bcb1671856..8f4451e728 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -134,6 +134,11 @@ enum class Method { // default: {} // } kSwitchCaseWithAbstractCase, + + // @workgroup_size(target_expr, abstract_expr, 123) + // @stage(compute) + // fn f() {} + kWorkgroupSize }; static std::ostream& operator<<(std::ostream& o, Method m) { @@ -162,6 +167,8 @@ static std::ostream& operator<<(std::ostream& o, Method m) { return o << "switch-cond-with-abstract"; case Method::kSwitchCaseWithAbstractCase: return o << "switch-case-with-abstract"; + case Method::kWorkgroupSize: + return o << "workgroup-size"; } return o << ""; } @@ -286,6 +293,11 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) { Case(abstract_expr->As()), // DefaultCase())); break; + case Method::kWorkgroupSize: + Func("f", {}, ty.void_(), {}, + {WorkgroupSize(target_expr(), abstract_expr, Expr(123_a)), + Stage(ast::PipelineStage::kCompute)}); + break; } auto check_types_and_values = [&](const sem::Expression* expr) { @@ -461,6 +473,19 @@ INSTANTIATE_TEST_SUITE_P(MaterializeSwitch, Types(AInt(kLowestU32), kLowestU32), // }))); +INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize, + MaterializeAbstractNumericToConcreteType, + testing::Combine(testing::Values(Expectation::kMaterialize), + testing::Values(Method::kWorkgroupSize), + testing::ValuesIn(std::vector{ + Types(1_a, 1.0), // + Types(10_a, 10.0), // + Types(65535_a, 65535.0), // + Types(1_a, 1.0), // + Types(10_a, 10.0), // + Types(65535_a, 65535.0), // + }))); + // TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops. INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize, MaterializeAbstractNumericToConcreteType, @@ -558,6 +583,11 @@ enum class Method { // default: {} // } kSwitch, + + // @workgroup_size(abstract_expr) + // @stage(compute) + // fn f() {} + kWorkgroupSize }; static std::ostream& operator<<(std::ostream& o, Method m) { @@ -576,6 +606,8 @@ static std::ostream& operator<<(std::ostream& o, Method m) { return o << "array-length"; case Method::kSwitch: return o << "switch"; + case Method::kWorkgroupSize: + return o << "workgroup-size"; } return o << ""; } @@ -656,6 +688,10 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) { Case(abstract_expr()->As()), DefaultCase())); break; + case Method::kWorkgroupSize: + Func("f", {}, ty.void_(), {}, + {WorkgroupSize(abstract_expr()), Stage(ast::PipelineStage::kCompute)}); + break; } auto check_types_and_values = [&](const sem::Expression* expr) { @@ -734,11 +770,6 @@ constexpr Method kMatrixMethods[] = { Method::kVar, }; -/// Methods that support materialization for switch cases -constexpr Method kSwitchMethods[] = { - Method::kSwitch, -}; - INSTANTIATE_TEST_SUITE_P( MaterializeScalar, MaterializeAbstractNumericToDefaultType, @@ -798,13 +829,23 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(MaterializeSwitch, MaterializeAbstractNumericToDefaultType, testing::Combine(testing::Values(Expectation::kMaterialize), - testing::ValuesIn(kSwitchMethods), + testing::Values(Method::kSwitch), testing::ValuesIn(std::vector{ Types(0_a, 0.0), // Types(AInt(kHighestI32), kHighestI32), // Types(AInt(kLowestI32), kLowestI32), // }))); +INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize, + MaterializeAbstractNumericToDefaultType, + testing::Combine(testing::Values(Expectation::kMaterialize), + testing::Values(Method::kWorkgroupSize), + testing::ValuesIn(std::vector{ + Types(1_a, 1.0), // + Types(10_a, 10.0), // + Types(65535_a, 65535.0), // + }))); + INSTANTIATE_TEST_SUITE_P(ScalarValueCannotBeRepresented, MaterializeAbstractNumericToDefaultType, testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), @@ -840,7 +881,16 @@ INSTANTIATE_TEST_SUITE_P(MatrixValueCannotBeRepresented, INSTANTIATE_TEST_SUITE_P(SwitchValueCannotBeRepresented, MaterializeAbstractNumericToDefaultType, testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), - testing::ValuesIn(kSwitchMethods), + testing::Values(Method::kSwitch), + testing::ValuesIn(std::vector{ + Types(0_a, kHighestI32 + 1), // + Types(0_a, kLowestI32 - 1), // + }))); + +INSTANTIATE_TEST_SUITE_P(WorkgroupSizeValueCannotBeRepresented, + MaterializeAbstractNumericToDefaultType, + testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), + testing::Values(Method::kWorkgroupSize), testing::ValuesIn(std::vector{ Types(0_a, kHighestI32 + 1), // Types(0_a, kLowestI32 - 1), // diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 807dc62370..c964d8adcb 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -721,52 +721,61 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { } auto values = attr->Values(); - auto any_i32 = false; - auto any_u32 = false; + std::array args = {}; + std::array arg_tys = {}; + size_t arg_count = 0; + + constexpr const char* kErrBadType = + "workgroup_size argument must be either literal or module-scope constant of type i32 " + "or u32"; + for (int i = 0; i < 3; i++) { - // Each argument to this attribute can either be a literal, an - // identifier for a module-scope constants, or nullptr if not specified. - - auto* expr = values[i]; + // Each argument to this attribute can either be a literal, an identifier for a module-scope + // constants, or nullptr if not specified. + auto* value = values[i]; + if (!value) { + break; + } + const auto* expr = Expression(value); if (!expr) { - // Not specified, just use the default. - continue; + return false; } - - auto* expr_sem = Expression(expr); - if (!expr_sem) { + auto* ty = expr->Type(); + if (!ty->IsAnyOf()) { + AddError(kErrBadType, value->source); return false; } - constexpr const char* kErrBadType = - "workgroup_size argument must be either literal or module-scope " - "constant of type i32 or u32"; - constexpr const char* kErrInconsistentType = - "workgroup_size arguments must be of the same type, either i32 " - "or u32"; + args[i] = expr; + arg_tys[i] = ty; + arg_count++; + } - auto* ty = sem_.TypeOf(expr); - bool is_i32 = ty->UnwrapRef()->Is(); - bool is_u32 = ty->UnwrapRef()->Is(); - if (!is_i32 && !is_u32) { - AddError(kErrBadType, expr->source); - return false; - } + auto* common_ty = sem::Type::Common(arg_tys.data(), arg_count); + if (!common_ty) { + AddError("workgroup_size arguments must be of the same type, either i32 or u32", + attr->source); + return false; + } - any_i32 = any_i32 || is_i32; - any_u32 = any_u32 || is_u32; - if (any_i32 && any_u32) { - AddError(kErrInconsistentType, expr->source); + // If all arguments are abstract-integers, then materialize to i32. + if (common_ty->Is()) { + common_ty = builder_->create(); + } + + for (size_t i = 0; i < arg_count; i++) { + auto* materialized = Materialize(args[i], common_ty); + if (!materialized) { return false; } sem::Constant value; - if (auto* user = sem_.Get(expr)->As()) { + if (auto* user = args[i]->As()) { // We have an variable of a module-scope constant. auto* decl = user->Variable()->Declaration(); if (!decl->is_const) { - AddError(kErrBadType, expr->source); + AddError(kErrBadType, values[i]->source); return false; } // Capture the constant if it is pipeline-overridable. @@ -781,8 +790,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { ws[i].value = 0; continue; } - } else if (expr->Is()) { - value = sem_.Get(expr)->ConstantValue(); + } else if (values[i]->Is()) { + value = materialized->ConstantValue(); } else { AddError( "workgroup_size argument must be either a literal or a "