From 7517e213cf031d0216cca3c9e5ab4f4b27ac49a4 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Wed, 24 Aug 2022 21:31:45 +0000 Subject: [PATCH] Update `workgroup_size` to use `expression`. This CL updates the `workgroup_size` attribute to use `expression` values instead of `primary_expression`. Bug: tint:1633 Change-Id: I0afbabd8ee61943469f04a55d56f85920563e2da Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99960 Reviewed-by: Ben Clayton Commit-Queue: Dan Sinclair Kokoro: Kokoro --- src/tint/program_builder.h | 9 ++ src/tint/reader/wgsl/parser_impl.cc | 35 +++---- .../parser_impl_function_attribute_test.cc | 98 +++++++++++++++++++ src/tint/resolver/function_validation_test.cc | 49 +++++++++- src/tint/resolver/resolver.cc | 4 +- .../compute_workgroup_expression.wgsl | 7 ++ ...orkgroup_expression.wgsl.expected.dxc.hlsl | 9 ++ ...orkgroup_expression.wgsl.expected.fxc.hlsl | 9 ++ ...te_workgroup_expression.wgsl.expected.glsl | 14 +++ ...ute_workgroup_expression.wgsl.expected.msl | 9 ++ ..._workgroup_expression.wgsl.expected.spvasm | 26 +++++ ...te_workgroup_expression.wgsl.expected.wgsl | 5 + 12 files changed, 250 insertions(+), 24 deletions(-) create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm create mode 100644 test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 5da80d68ab..c797aaba9e 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -2968,6 +2968,15 @@ class ProgramBuilder { return WorkgroupSize(std::forward(x), nullptr, nullptr); } + /// Creates an ast::WorkgroupAttribute + /// @param source the source information + /// @param x the x dimension expression + /// @returns the workgroup attribute pointer + template + const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x) { + return WorkgroupSize(source, std::forward(x), nullptr, nullptr); + } + /// Creates an ast::WorkgroupAttribute /// @param source the source information /// @param x the x dimension expression diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index eff842b812..bf95b4e73b 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc @@ -3405,28 +3405,25 @@ Expect ParserImpl::expect_attribute() { } // attribute -// : ATTR 'align' PAREN_LEFT expression attrib_end -// | ATTR 'binding' PAREN_LEFT expression attrib_end -// | ATTR 'builtin' PAREN_LEFT builtin_value_name attrib_end +// : ATTR 'align' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'binding' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'builtin' PAREN_LEFT builtin_value_name COMMA? PAREN_RIGHT // | ATTR 'const' -// | ATTR 'group' PAREN_LEFT expression attrib_end -// | ATTR 'id' PAREN_LEFT expression attrib_end -// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name attrib_end +// | ATTR 'group' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'id' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA? PAREN_RIGHT // | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA -// interpolation_sample_name attrib_end +// interpolation_sample_name COMMA? PAREN_RIGHT // | ATTR 'invariant' -// | ATTR 'location' PAREN_LEFT expression attrib_end -// | ATTR 'size' PAREN_LEFT expression attrib_end -// | ATTR 'workgroup_size' PAREN_LEFT expression attrib_end -// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression attrib_end -// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA expression attrib_end +// | ATTR 'location' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'size' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA? PAREN_RIGHT +// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA? PAREN_RIGHT +// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA +// expression COMMA? PAREN_RIGHT // | ATTR 'vertex' // | ATTR 'fragment' // | ATTR 'compute' -// -// attrib_end -// : COMMA? PAREN_RIGHT -// Maybe ParserImpl::attribute() { using Result = Maybe; auto& t = next(); @@ -3603,7 +3600,7 @@ Maybe ParserImpl::attribute() { const ast::Expression* y = nullptr; const ast::Expression* z = nullptr; - auto expr = primary_expression(); + auto expr = expression(); if (expr.errored) { return Failure::kErrored; } else if (!expr.matched) { @@ -3613,7 +3610,7 @@ Maybe ParserImpl::attribute() { if (match(Token::Type::kComma)) { if (!peek_is(Token::Type::kParenRight)) { - expr = primary_expression(); + expr = expression(); if (expr.errored) { return Failure::kErrored; } else if (!expr.matched) { @@ -3623,7 +3620,7 @@ Maybe ParserImpl::attribute() { if (match(Token::Type::kComma)) { if (!peek_is(Token::Type::kParenRight)) { - expr = primary_expression(); + expr = expression(); if (expr.errored) { return Failure::kErrored; } else if (!expr.matched) { diff --git a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc index 8e56ee7952..60e8f861c8 100644 --- a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc +++ b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc @@ -41,6 +41,35 @@ TEST_F(ParserImplTest, Attribute_Workgroup) { EXPECT_EQ(values[2], nullptr); } +TEST_F(ParserImplTest, Attribute_Workgroup_Expression) { + auto p = parser("workgroup_size(4 + 2)"); + auto attr = p->attribute(); + EXPECT_TRUE(attr.matched); + EXPECT_FALSE(attr.errored); + ASSERT_NE(attr.value, nullptr) << p->error(); + ASSERT_FALSE(p->has_error()); + auto* func_attr = attr.value->As(); + ASSERT_NE(func_attr, nullptr); + ASSERT_TRUE(func_attr->Is()); + + auto values = func_attr->As()->Values(); + + ASSERT_TRUE(values[0]->Is()); + auto* expr = values[0]->As(); + EXPECT_EQ(expr->op, ast::BinaryOp::kAdd); + + EXPECT_EQ(expr->lhs->As()->value, 4); + EXPECT_EQ(expr->lhs->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + EXPECT_EQ(expr->rhs->As()->value, 2); + EXPECT_EQ(expr->rhs->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + EXPECT_EQ(values[1], nullptr); + EXPECT_EQ(values[2], nullptr); +} + TEST_F(ParserImplTest, Attribute_Workgroup_1Param_TrailingComma) { auto p = parser("workgroup_size(4,)"); auto attr = p->attribute(); @@ -99,6 +128,39 @@ TEST_F(ParserImplTest, Attribute_Workgroup_2Param) { EXPECT_EQ(values[2], nullptr); } +TEST_F(ParserImplTest, Attribute_Workgroup_2Param_Expression) { + auto p = parser("workgroup_size(4, 5 - 2)"); + auto attr = p->attribute(); + EXPECT_TRUE(attr.matched); + EXPECT_FALSE(attr.errored); + ASSERT_NE(attr.value, nullptr) << p->error(); + ASSERT_FALSE(p->has_error()); + auto* func_attr = attr.value->As(); + ASSERT_NE(func_attr, nullptr) << p->error(); + ASSERT_TRUE(func_attr->Is()); + + auto values = func_attr->As()->Values(); + + ASSERT_TRUE(values[0]->Is()); + EXPECT_EQ(values[0]->As()->value, 4); + EXPECT_EQ(values[0]->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(values[1]->Is()); + auto* expr = values[1]->As(); + EXPECT_EQ(expr->op, ast::BinaryOp::kSubtract); + + EXPECT_EQ(expr->lhs->As()->value, 5); + EXPECT_EQ(expr->lhs->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + EXPECT_EQ(expr->rhs->As()->value, 2); + EXPECT_EQ(expr->rhs->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + EXPECT_EQ(values[2], nullptr); +} + TEST_F(ParserImplTest, Attribute_Workgroup_2Param_TrailingComma) { auto p = parser("workgroup_size(4, 5,)"); auto attr = p->attribute(); @@ -164,6 +226,42 @@ TEST_F(ParserImplTest, Attribute_Workgroup_3Param) { ast::IntLiteralExpression::Suffix::kNone); } +TEST_F(ParserImplTest, Attribute_Workgroup_3Param_Expression) { + auto p = parser("workgroup_size(4, 5, 6 << 1)"); + auto attr = p->attribute(); + EXPECT_TRUE(attr.matched); + EXPECT_FALSE(attr.errored); + ASSERT_NE(attr.value, nullptr) << p->error(); + ASSERT_FALSE(p->has_error()); + auto* func_attr = attr.value->As(); + ASSERT_NE(func_attr, nullptr); + ASSERT_TRUE(func_attr->Is()); + + auto values = func_attr->As()->Values(); + + ASSERT_TRUE(values[0]->Is()); + EXPECT_EQ(values[0]->As()->value, 4); + EXPECT_EQ(values[0]->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(values[1]->Is()); + EXPECT_EQ(values[1]->As()->value, 5); + EXPECT_EQ(values[1]->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(values[2]->Is()); + auto* expr = values[2]->As(); + EXPECT_EQ(expr->op, ast::BinaryOp::kShiftLeft); + + EXPECT_EQ(expr->lhs->As()->value, 6); + EXPECT_EQ(expr->lhs->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); + + EXPECT_EQ(expr->rhs->As()->value, 1); + EXPECT_EQ(expr->rhs->As()->suffix, + ast::IntLiteralExpression::Suffix::kNone); +} + TEST_F(ParserImplTest, Attribute_Workgroup_3Param_TrailingComma) { auto p = parser("workgroup_size(4, 5, 6,)"); auto attr = p->attribute(); diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index f69950427c..22f26205c2 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -545,6 +545,19 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) { ASSERT_TRUE(r()->Resolve()) << r()->error(); } +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Expr) { + // @compute @workgroup_size(1 + 2) + // fn main() {} + + Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Source{{12, 34}}, Add(1_u, 2_u)), + }); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) { // @compute @workgroup_size(1u, 2, 3_i) // fn main() {} @@ -750,13 +763,43 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) { "overridable of type abstract-integer, i32 or u32"); } -TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) { - // @compute @workgroup_size(i32(1)) +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) { + // @compute @workgroup_size(1 << 2 + 4) // fn main() {} Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1_i)), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: workgroup_size argument must be either a literal, constant, or " + "overridable of type abstract-integer, i32 or u32"); +} + +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) { + // @compute @workgroup_size(1, 1 << 2 + 4) + // fn main() {} + Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: workgroup_size argument must be either a literal, constant, or " + "overridable of type abstract-integer, i32 or u32"); +} + +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) { + // @compute @workgroup_size(1, 1, 1 << 2 + 4) + // fn main() {} + Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))), }); EXPECT_FALSE(r()->Resolve()); diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index ed0739923e..90ec2549da 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -965,7 +965,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { for (size_t i = 0; i < 3; i++) { // Each argument to this attribute can either be a literal, an identifier for a module-scope - // constants, or nullptr if not specified. + // constants, a constant expression, or nullptr if not specified. auto* value = values[i]; if (!value) { break; @@ -1023,7 +1023,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { ws[i].value = 0; continue; } - } else if (values[i]->Is()) { + } else if (values[i]->Is() || args[i]->ConstantValue()) { value = materialized->ConstantValue(); } else { AddError(kErrBadExpr, values[i]->source); diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl b/test/tint/shader_io/compute_workgroup_expression.wgsl new file mode 100644 index 0000000000..a465a76db3 --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl @@ -0,0 +1,7 @@ +@id(0) override x_dim = 2; + +@compute +@workgroup_size(1 + 2, x_dim, clamp((1 - 2) + 4, 0, 5)) +fn main() { +} + diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl new file mode 100644 index 0000000000..2098cb3336 --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl @@ -0,0 +1,9 @@ +#ifndef WGSL_SPEC_CONSTANT_0 +#define WGSL_SPEC_CONSTANT_0 2 +#endif +static const int x_dim = WGSL_SPEC_CONSTANT_0; + +[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)] +void main() { + return; +} diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl new file mode 100644 index 0000000000..2098cb3336 --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl @@ -0,0 +1,9 @@ +#ifndef WGSL_SPEC_CONSTANT_0 +#define WGSL_SPEC_CONSTANT_0 2 +#endif +static const int x_dim = WGSL_SPEC_CONSTANT_0; + +[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)] +void main() { + return; +} diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl new file mode 100644 index 0000000000..6ac96fa4aa --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl @@ -0,0 +1,14 @@ +#version 310 es + +#ifndef WGSL_SPEC_CONSTANT_0 +#define WGSL_SPEC_CONSTANT_0 2 +#endif +const int x_dim = WGSL_SPEC_CONSTANT_0; +void tint_symbol() { +} + +layout(local_size_x = 3, local_size_y = WGSL_SPEC_CONSTANT_0, local_size_z = 3) in; +void main() { + tint_symbol(); + return; +} diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl new file mode 100644 index 0000000000..d9f2428573 --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl @@ -0,0 +1,9 @@ +#include + +using namespace metal; +constant int x_dim [[function_constant(0)]]; + +kernel void tint_symbol() { + return; +} + diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm new file mode 100644 index 0000000000..bcd4536dbf --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm @@ -0,0 +1,26 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 12 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpName %x_dim "x_dim" + OpName %main "main" + OpDecorate %x_dim SpecId 0 + OpDecorate %11 SpecId 0 + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %int = OpTypeInt 32 1 + %x_dim = OpSpecConstant %int 2 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_3 = OpConstant %uint 3 + %11 = OpSpecConstant %uint 2 +%gl_WorkGroupSize = OpSpecConstantComposite %v3uint %uint_3 %11 %uint_3 + %main = OpFunction %void None %3 + %6 = OpLabel + OpReturn + OpFunctionEnd diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl new file mode 100644 index 0000000000..c2b260b2f0 --- /dev/null +++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl @@ -0,0 +1,5 @@ +@id(0) override x_dim = 2; + +@compute @workgroup_size((1 + 2), x_dim, clamp(((1 - 2) + 4), 0, 5)) +fn main() { +}