Update `workgroup_size` to use `expression`.
This CL updates the `workgroup_size` attribute to use `expression` values instead of `primary_expression`. Bug: tint:1633 Change-Id: I0afbabd8ee61943469f04a55d56f85920563e2da Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99960 Reviewed-by: Ben Clayton <bclayton@chromium.org> Commit-Queue: Dan Sinclair <dsinclair@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
436b4a2bd0
commit
7517e213cf
|
@ -2968,6 +2968,15 @@ class ProgramBuilder {
|
|||
return WorkgroupSize(std::forward<EXPR_X>(x), nullptr, nullptr);
|
||||
}
|
||||
|
||||
/// Creates an ast::WorkgroupAttribute
|
||||
/// @param source the source information
|
||||
/// @param x the x dimension expression
|
||||
/// @returns the workgroup attribute pointer
|
||||
template <typename EXPR_X>
|
||||
const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x) {
|
||||
return WorkgroupSize(source, std::forward<EXPR_X>(x), nullptr, nullptr);
|
||||
}
|
||||
|
||||
/// Creates an ast::WorkgroupAttribute
|
||||
/// @param source the source information
|
||||
/// @param x the x dimension expression
|
||||
|
|
|
@ -3405,28 +3405,25 @@ Expect<const ast::Attribute*> ParserImpl::expect_attribute() {
|
|||
}
|
||||
|
||||
// attribute
|
||||
// : ATTR 'align' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'binding' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'builtin' PAREN_LEFT builtin_value_name attrib_end
|
||||
// : ATTR 'align' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'binding' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'builtin' PAREN_LEFT builtin_value_name COMMA? PAREN_RIGHT
|
||||
// | ATTR 'const'
|
||||
// | ATTR 'group' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'id' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name attrib_end
|
||||
// | ATTR 'group' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'id' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA? PAREN_RIGHT
|
||||
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA
|
||||
// interpolation_sample_name attrib_end
|
||||
// interpolation_sample_name COMMA? PAREN_RIGHT
|
||||
// | ATTR 'invariant'
|
||||
// | ATTR 'location' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'size' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'workgroup_size' PAREN_LEFT expression attrib_end
|
||||
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression attrib_end
|
||||
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA expression attrib_end
|
||||
// | ATTR 'location' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'size' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA
|
||||
// expression COMMA? PAREN_RIGHT
|
||||
// | ATTR 'vertex'
|
||||
// | ATTR 'fragment'
|
||||
// | ATTR 'compute'
|
||||
//
|
||||
// attrib_end
|
||||
// : COMMA? PAREN_RIGHT
|
||||
//
|
||||
Maybe<const ast::Attribute*> ParserImpl::attribute() {
|
||||
using Result = Maybe<const ast::Attribute*>;
|
||||
auto& t = next();
|
||||
|
@ -3603,7 +3600,7 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
|
|||
const ast::Expression* y = nullptr;
|
||||
const ast::Expression* z = nullptr;
|
||||
|
||||
auto expr = primary_expression();
|
||||
auto expr = expression();
|
||||
if (expr.errored) {
|
||||
return Failure::kErrored;
|
||||
} else if (!expr.matched) {
|
||||
|
@ -3613,7 +3610,7 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
|
|||
|
||||
if (match(Token::Type::kComma)) {
|
||||
if (!peek_is(Token::Type::kParenRight)) {
|
||||
expr = primary_expression();
|
||||
expr = expression();
|
||||
if (expr.errored) {
|
||||
return Failure::kErrored;
|
||||
} else if (!expr.matched) {
|
||||
|
@ -3623,7 +3620,7 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
|
|||
|
||||
if (match(Token::Type::kComma)) {
|
||||
if (!peek_is(Token::Type::kParenRight)) {
|
||||
expr = primary_expression();
|
||||
expr = expression();
|
||||
if (expr.errored) {
|
||||
return Failure::kErrored;
|
||||
} else if (!expr.matched) {
|
||||
|
|
|
@ -41,6 +41,35 @@ TEST_F(ParserImplTest, Attribute_Workgroup) {
|
|||
EXPECT_EQ(values[2], nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, Attribute_Workgroup_Expression) {
|
||||
auto p = parser("workgroup_size(4 + 2)");
|
||||
auto attr = p->attribute();
|
||||
EXPECT_TRUE(attr.matched);
|
||||
EXPECT_FALSE(attr.errored);
|
||||
ASSERT_NE(attr.value, nullptr) << p->error();
|
||||
ASSERT_FALSE(p->has_error());
|
||||
auto* func_attr = attr.value->As<ast::Attribute>();
|
||||
ASSERT_NE(func_attr, nullptr);
|
||||
ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
|
||||
|
||||
auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
|
||||
|
||||
ASSERT_TRUE(values[0]->Is<ast::BinaryExpression>());
|
||||
auto* expr = values[0]->As<ast::BinaryExpression>();
|
||||
EXPECT_EQ(expr->op, ast::BinaryOp::kAdd);
|
||||
|
||||
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 4);
|
||||
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
|
||||
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
EXPECT_EQ(values[1], nullptr);
|
||||
EXPECT_EQ(values[2], nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, Attribute_Workgroup_1Param_TrailingComma) {
|
||||
auto p = parser("workgroup_size(4,)");
|
||||
auto attr = p->attribute();
|
||||
|
@ -99,6 +128,39 @@ TEST_F(ParserImplTest, Attribute_Workgroup_2Param) {
|
|||
EXPECT_EQ(values[2], nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, Attribute_Workgroup_2Param_Expression) {
|
||||
auto p = parser("workgroup_size(4, 5 - 2)");
|
||||
auto attr = p->attribute();
|
||||
EXPECT_TRUE(attr.matched);
|
||||
EXPECT_FALSE(attr.errored);
|
||||
ASSERT_NE(attr.value, nullptr) << p->error();
|
||||
ASSERT_FALSE(p->has_error());
|
||||
auto* func_attr = attr.value->As<ast::Attribute>();
|
||||
ASSERT_NE(func_attr, nullptr) << p->error();
|
||||
ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
|
||||
|
||||
auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
|
||||
|
||||
ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
|
||||
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
|
||||
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(values[1]->Is<ast::BinaryExpression>());
|
||||
auto* expr = values[1]->As<ast::BinaryExpression>();
|
||||
EXPECT_EQ(expr->op, ast::BinaryOp::kSubtract);
|
||||
|
||||
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 5);
|
||||
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
|
||||
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
EXPECT_EQ(values[2], nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, Attribute_Workgroup_2Param_TrailingComma) {
|
||||
auto p = parser("workgroup_size(4, 5,)");
|
||||
auto attr = p->attribute();
|
||||
|
@ -164,6 +226,42 @@ TEST_F(ParserImplTest, Attribute_Workgroup_3Param) {
|
|||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, Attribute_Workgroup_3Param_Expression) {
|
||||
auto p = parser("workgroup_size(4, 5, 6 << 1)");
|
||||
auto attr = p->attribute();
|
||||
EXPECT_TRUE(attr.matched);
|
||||
EXPECT_FALSE(attr.errored);
|
||||
ASSERT_NE(attr.value, nullptr) << p->error();
|
||||
ASSERT_FALSE(p->has_error());
|
||||
auto* func_attr = attr.value->As<ast::Attribute>();
|
||||
ASSERT_NE(func_attr, nullptr);
|
||||
ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
|
||||
|
||||
auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
|
||||
|
||||
ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
|
||||
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
|
||||
EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(values[1]->Is<ast::IntLiteralExpression>());
|
||||
EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->value, 5);
|
||||
EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(values[2]->Is<ast::BinaryExpression>());
|
||||
auto* expr = values[2]->As<ast::BinaryExpression>();
|
||||
EXPECT_EQ(expr->op, ast::BinaryOp::kShiftLeft);
|
||||
|
||||
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 6);
|
||||
EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 1);
|
||||
EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
|
||||
ast::IntLiteralExpression::Suffix::kNone);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, Attribute_Workgroup_3Param_TrailingComma) {
|
||||
auto p = parser("workgroup_size(4, 5, 6,)");
|
||||
auto attr = p->attribute();
|
||||
|
|
|
@ -545,6 +545,19 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
|
|||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Expr) {
|
||||
// @compute @workgroup_size(1 + 2)
|
||||
// fn main() {}
|
||||
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Source{{12, 34}}, Add(1_u, 2_u)),
|
||||
});
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
|
||||
// @compute @workgroup_size(1u, 2, 3_i)
|
||||
// fn main() {}
|
||||
|
@ -750,13 +763,43 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
|
|||
"overridable of type abstract-integer, i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
|
||||
// @compute @workgroup_size(i32(1))
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
|
||||
// @compute @workgroup_size(1 << 2 + 4)
|
||||
// fn main() {}
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1_i)),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size argument must be either a literal, constant, or "
|
||||
"overridable of type abstract-integer, i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
|
||||
// @compute @workgroup_size(1, 1 << 2 + 4)
|
||||
// fn main() {}
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: workgroup_size argument must be either a literal, constant, or "
|
||||
"overridable of type abstract-integer, i32 or u32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
|
||||
// @compute @workgroup_size(1, 1, 1 << 2 + 4)
|
||||
// fn main() {}
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
|
|
@ -965,7 +965,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
|
|||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
// Each argument to this attribute can either be a literal, an identifier for a module-scope
|
||||
// constants, or nullptr if not specified.
|
||||
// constants, a constant expression, or nullptr if not specified.
|
||||
auto* value = values[i];
|
||||
if (!value) {
|
||||
break;
|
||||
|
@ -1023,7 +1023,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
|
|||
ws[i].value = 0;
|
||||
continue;
|
||||
}
|
||||
} else if (values[i]->Is<ast::LiteralExpression>()) {
|
||||
} else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
|
||||
value = materialized->ConstantValue();
|
||||
} else {
|
||||
AddError(kErrBadExpr, values[i]->source);
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
@id(0) override x_dim = 2;
|
||||
|
||||
@compute
|
||||
@workgroup_size(1 + 2, x_dim, clamp((1 - 2) + 4, 0, 5))
|
||||
fn main() {
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
#ifndef WGSL_SPEC_CONSTANT_0
|
||||
#define WGSL_SPEC_CONSTANT_0 2
|
||||
#endif
|
||||
static const int x_dim = WGSL_SPEC_CONSTANT_0;
|
||||
|
||||
[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
|
||||
void main() {
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
#ifndef WGSL_SPEC_CONSTANT_0
|
||||
#define WGSL_SPEC_CONSTANT_0 2
|
||||
#endif
|
||||
static const int x_dim = WGSL_SPEC_CONSTANT_0;
|
||||
|
||||
[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
|
||||
void main() {
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
#version 310 es
|
||||
|
||||
#ifndef WGSL_SPEC_CONSTANT_0
|
||||
#define WGSL_SPEC_CONSTANT_0 2
|
||||
#endif
|
||||
const int x_dim = WGSL_SPEC_CONSTANT_0;
|
||||
void tint_symbol() {
|
||||
}
|
||||
|
||||
layout(local_size_x = 3, local_size_y = WGSL_SPEC_CONSTANT_0, local_size_z = 3) in;
|
||||
void main() {
|
||||
tint_symbol();
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
constant int x_dim [[function_constant(0)]];
|
||||
|
||||
kernel void tint_symbol() {
|
||||
return;
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 12
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
OpName %x_dim "x_dim"
|
||||
OpName %main "main"
|
||||
OpDecorate %x_dim SpecId 0
|
||||
OpDecorate %11 SpecId 0
|
||||
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
|
||||
%int = OpTypeInt 32 1
|
||||
%x_dim = OpSpecConstant %int 2
|
||||
%void = OpTypeVoid
|
||||
%3 = OpTypeFunction %void
|
||||
%uint = OpTypeInt 32 0
|
||||
%v3uint = OpTypeVector %uint 3
|
||||
%uint_3 = OpConstant %uint 3
|
||||
%11 = OpSpecConstant %uint 2
|
||||
%gl_WorkGroupSize = OpSpecConstantComposite %v3uint %uint_3 %11 %uint_3
|
||||
%main = OpFunction %void None %3
|
||||
%6 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,5 @@
|
|||
@id(0) override x_dim = 2;
|
||||
|
||||
@compute @workgroup_size((1 + 2), x_dim, clamp(((1 - 2) + 4), 0, 5))
|
||||
fn main() {
|
||||
}
|
Loading…
Reference in New Issue