Update `workgroup_size` to use `expression`.

This CL updates the `workgroup_size` attribute to use `expression`
values instead of `primary_expression`.

Bug: tint:1633
Change-Id: I0afbabd8ee61943469f04a55d56f85920563e2da
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99960
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair 2022-08-24 21:31:45 +00:00 committed by Dawn LUCI CQ
parent 436b4a2bd0
commit 7517e213cf
12 changed files with 250 additions and 24 deletions

View File

@ -2968,6 +2968,15 @@ class ProgramBuilder {
return WorkgroupSize(std::forward<EXPR_X>(x), nullptr, nullptr);
}
/// Creates an ast::WorkgroupAttribute
/// @param source the source information
/// @param x the x dimension expression
/// @returns the workgroup attribute pointer
template <typename EXPR_X>
const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x) {
return WorkgroupSize(source, std::forward<EXPR_X>(x), nullptr, nullptr);
}
/// Creates an ast::WorkgroupAttribute
/// @param source the source information
/// @param x the x dimension expression

View File

@ -3405,28 +3405,25 @@ Expect<const ast::Attribute*> ParserImpl::expect_attribute() {
}
// attribute
// : ATTR 'align' PAREN_LEFT expression attrib_end
// | ATTR 'binding' PAREN_LEFT expression attrib_end
// | ATTR 'builtin' PAREN_LEFT builtin_value_name attrib_end
// : ATTR 'align' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'binding' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'builtin' PAREN_LEFT builtin_value_name COMMA? PAREN_RIGHT
// | ATTR 'const'
// | ATTR 'group' PAREN_LEFT expression attrib_end
// | ATTR 'id' PAREN_LEFT expression attrib_end
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name attrib_end
// | ATTR 'group' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'id' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA? PAREN_RIGHT
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA
// interpolation_sample_name attrib_end
// interpolation_sample_name COMMA? PAREN_RIGHT
// | ATTR 'invariant'
// | ATTR 'location' PAREN_LEFT expression attrib_end
// | ATTR 'size' PAREN_LEFT expression attrib_end
// | ATTR 'workgroup_size' PAREN_LEFT expression attrib_end
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression attrib_end
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA expression attrib_end
// | ATTR 'location' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'size' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA? PAREN_RIGHT
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA? PAREN_RIGHT
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA
// expression COMMA? PAREN_RIGHT
// | ATTR 'vertex'
// | ATTR 'fragment'
// | ATTR 'compute'
//
// attrib_end
// : COMMA? PAREN_RIGHT
//
Maybe<const ast::Attribute*> ParserImpl::attribute() {
using Result = Maybe<const ast::Attribute*>;
auto& t = next();
@ -3603,7 +3600,7 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
const ast::Expression* y = nullptr;
const ast::Expression* z = nullptr;
auto expr = primary_expression();
auto expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
@ -3613,7 +3610,7 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
if (match(Token::Type::kComma)) {
if (!peek_is(Token::Type::kParenRight)) {
expr = primary_expression();
expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
@ -3623,7 +3620,7 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
if (match(Token::Type::kComma)) {
if (!peek_is(Token::Type::kParenRight)) {
expr = primary_expression();
expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {

View File

@ -41,6 +41,35 @@ TEST_F(ParserImplTest, Attribute_Workgroup) {
EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Attribute_Workgroup_Expression) {
auto p = parser("workgroup_size(4 + 2)");
auto attr = p->attribute();
EXPECT_TRUE(attr.matched);
EXPECT_FALSE(attr.errored);
ASSERT_NE(attr.value, nullptr) << p->error();
ASSERT_FALSE(p->has_error());
auto* func_attr = attr.value->As<ast::Attribute>();
ASSERT_NE(func_attr, nullptr);
ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
ASSERT_TRUE(values[0]->Is<ast::BinaryExpression>());
auto* expr = values[0]->As<ast::BinaryExpression>();
EXPECT_EQ(expr->op, ast::BinaryOp::kAdd);
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 4);
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(values[1], nullptr);
EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Attribute_Workgroup_1Param_TrailingComma) {
auto p = parser("workgroup_size(4,)");
auto attr = p->attribute();
@ -99,6 +128,39 @@ TEST_F(ParserImplTest, Attribute_Workgroup_2Param) {
EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Attribute_Workgroup_2Param_Expression) {
auto p = parser("workgroup_size(4, 5 - 2)");
auto attr = p->attribute();
EXPECT_TRUE(attr.matched);
EXPECT_FALSE(attr.errored);
ASSERT_NE(attr.value, nullptr) << p->error();
ASSERT_FALSE(p->has_error());
auto* func_attr = attr.value->As<ast::Attribute>();
ASSERT_NE(func_attr, nullptr) << p->error();
ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(values[1]->Is<ast::BinaryExpression>());
auto* expr = values[1]->As<ast::BinaryExpression>();
EXPECT_EQ(expr->op, ast::BinaryOp::kSubtract);
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 5);
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Attribute_Workgroup_2Param_TrailingComma) {
auto p = parser("workgroup_size(4, 5,)");
auto attr = p->attribute();
@ -164,6 +226,42 @@ TEST_F(ParserImplTest, Attribute_Workgroup_3Param) {
ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, Attribute_Workgroup_3Param_Expression) {
auto p = parser("workgroup_size(4, 5, 6 << 1)");
auto attr = p->attribute();
EXPECT_TRUE(attr.matched);
EXPECT_FALSE(attr.errored);
ASSERT_NE(attr.value, nullptr) << p->error();
ASSERT_FALSE(p->has_error());
auto* func_attr = attr.value->As<ast::Attribute>();
ASSERT_NE(func_attr, nullptr);
ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(values[1]->Is<ast::IntLiteralExpression>());
EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->value, 5);
EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(values[2]->Is<ast::BinaryExpression>());
auto* expr = values[2]->As<ast::BinaryExpression>();
EXPECT_EQ(expr->op, ast::BinaryOp::kShiftLeft);
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 6);
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 1);
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, Attribute_Workgroup_3Param_TrailingComma) {
auto p = parser("workgroup_size(4, 5, 6,)");
auto attr = p->attribute();

View File

@ -545,6 +545,19 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Expr) {
// @compute @workgroup_size(1 + 2)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Source{{12, 34}}, Add(1_u, 2_u)),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
// @compute @workgroup_size(1u, 2, 3_i)
// fn main() {}
@ -750,13 +763,43 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
"overridable of type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
// @compute @workgroup_size(i32(1))
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
// @compute @workgroup_size(1 << 2 + 4)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1_i)),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size argument must be either a literal, constant, or "
"overridable of type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
// @compute @workgroup_size(1, 1 << 2 + 4)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size argument must be either a literal, constant, or "
"overridable of type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
// @compute @workgroup_size(1, 1, 1 << 2 + 4)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
});
EXPECT_FALSE(r()->Resolve());

View File

@ -965,7 +965,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
for (size_t i = 0; i < 3; i++) {
// Each argument to this attribute can either be a literal, an identifier for a module-scope
// constants, or nullptr if not specified.
// constants, a constant expression, or nullptr if not specified.
auto* value = values[i];
if (!value) {
break;
@ -1023,7 +1023,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
ws[i].value = 0;
continue;
}
} else if (values[i]->Is<ast::LiteralExpression>()) {
} else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
value = materialized->ConstantValue();
} else {
AddError(kErrBadExpr, values[i]->source);

View File

@ -0,0 +1,7 @@
@id(0) override x_dim = 2;
@compute
@workgroup_size(1 + 2, x_dim, clamp((1 - 2) + 4, 0, 5))
fn main() {
}

View File

@ -0,0 +1,9 @@
#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 2
#endif
static const int x_dim = WGSL_SPEC_CONSTANT_0;
[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
void main() {
return;
}

View File

@ -0,0 +1,9 @@
#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 2
#endif
static const int x_dim = WGSL_SPEC_CONSTANT_0;
[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
void main() {
return;
}

View File

@ -0,0 +1,14 @@
#version 310 es
#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 2
#endif
const int x_dim = WGSL_SPEC_CONSTANT_0;
void tint_symbol() {
}
layout(local_size_x = 3, local_size_y = WGSL_SPEC_CONSTANT_0, local_size_z = 3) in;
void main() {
tint_symbol();
return;
}

View File

@ -0,0 +1,9 @@
#include <metal_stdlib>
using namespace metal;
constant int x_dim [[function_constant(0)]];
kernel void tint_symbol() {
return;
}

View File

@ -0,0 +1,26 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 12
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpName %x_dim "x_dim"
OpName %main "main"
OpDecorate %x_dim SpecId 0
OpDecorate %11 SpecId 0
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
%int = OpTypeInt 32 1
%x_dim = OpSpecConstant %int 2
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%uint_3 = OpConstant %uint 3
%11 = OpSpecConstant %uint 2
%gl_WorkGroupSize = OpSpecConstantComposite %v3uint %uint_3 %11 %uint_3
%main = OpFunction %void None %3
%6 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,5 @@
@id(0) override x_dim = 2;
@compute @workgroup_size((1 + 2), x_dim, clamp(((1 - 2) + 4), 0, 5))
fn main() {
}