ast: Support non-literal workgroup_size parameters

Change the type of the values in an ast::WorkgroupDecoration to be
ast::Expression nodes, so that they can represent both
ast::ScalarExpression (literal) and ast::IdentifierExpression
(module-scope constant).

The Resolver processes these nodes to produce a uint32_t for the
default value on each dimension, and captures a reference to the
module-scope constant if it is overridable (which will soon be used by
the inspector and backends).

The WGSL parser now uses `primary_expression` to parse arguments to
workgroup_size.

Also added some WorkgroupSize() helpers to ProgramBuilder.

Bug: tint:713
Change-Id: I44b7b0021b925c84f25f65e26dc7da6b19ede508
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51262
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-05-19 13:40:08 +00:00 committed by Commit Bot service account
parent 40ac15f157
commit 70f80bb13d
22 changed files with 741 additions and 337 deletions

View File

@ -107,7 +107,7 @@ TEST_F(FunctionTest, Assert_DifferentProgramID_Deco) {
ProgramBuilder b2;
b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
DecorationList{
b2.create<WorkgroupDecoration>(2, 4, 6),
b2.WorkgroupSize(2, 4, 6),
});
},
"internal compiler error");
@ -121,7 +121,7 @@ TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnDeco) {
b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
DecorationList{},
DecorationList{
b2.create<WorkgroupDecoration>(2, 4, 6),
b2.WorkgroupSize(2, 4, 6),
});
},
"internal compiler error");
@ -159,10 +159,14 @@ TEST_F(FunctionTest, ToStr_WithDecoration) {
StatementList{
create<DiscardStatement>(),
},
DecorationList{create<WorkgroupDecoration>(2, 4, 6)});
DecorationList{WorkgroupSize(2, 4, 6)});
EXPECT_EQ(str(f), R"(Function func -> __void
WorkgroupDecoration{2 4 6}
WorkgroupDecoration{
ScalarConstructor[not set]{2}
ScalarConstructor[not set]{4}
ScalarConstructor[not set]{6}
}
()
{
Discard{}

View File

@ -25,6 +25,9 @@ class IntLiteral : public Castable<IntLiteral, Literal> {
public:
~IntLiteral() override;
/// @returns the literal value as an i32
int32_t value_as_i32() const { return static_cast<int32_t>(value_); }
/// @returns the literal value as a u32
uint32_t value_as_u32() const { return value_; }

View File

@ -23,36 +23,36 @@ namespace ast {
WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
const Source& source,
uint32_t x)
: WorkgroupDecoration(program_id, source, x, 1, 1) {}
WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
const Source& source,
uint32_t x,
uint32_t y)
: WorkgroupDecoration(program_id, source, x, y, 1) {}
WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
const Source& source,
uint32_t x,
uint32_t y,
uint32_t z)
ast::Expression* x,
ast::Expression* y,
ast::Expression* z)
: Base(program_id, source), x_(x), y_(y), z_(z) {}
WorkgroupDecoration::~WorkgroupDecoration() = default;
void WorkgroupDecoration::to_str(const sem::Info&,
void WorkgroupDecoration::to_str(const sem::Info& sem,
std::ostream& out,
size_t indent) const {
make_indent(out, indent);
out << "WorkgroupDecoration{" << x_ << " " << y_ << " " << z_ << "}"
<< std::endl;
out << "WorkgroupDecoration{" << std::endl;
x_->to_str(sem, out, indent + 2);
if (y_) {
y_->to_str(sem, out, indent + 2);
if (z_) {
z_->to_str(sem, out, indent + 2);
}
}
make_indent(out, indent);
out << "}" << std::endl;
}
WorkgroupDecoration* WorkgroupDecoration::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source());
return ctx->dst->create<WorkgroupDecoration>(src, x_, y_, z_);
auto* x = ctx->Clone(x_);
auto* y = ctx->Clone(y_);
auto* z = ctx->Clone(z_);
return ctx->dst->create<WorkgroupDecoration>(src, x, y, z);
}
} // namespace ast

View File

@ -15,47 +15,35 @@
#ifndef SRC_AST_WORKGROUP_DECORATION_H_
#define SRC_AST_WORKGROUP_DECORATION_H_
#include <tuple>
#include <array>
#include "src/ast/decoration.h"
namespace tint {
namespace ast {
// Forward declaration
class Expression;
/// A workgroup decoration
class WorkgroupDecoration : public Castable<WorkgroupDecoration, Decoration> {
public:
/// constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the source of this decoration
/// @param x the workgroup x dimension size
WorkgroupDecoration(ProgramID program_id, const Source& source, uint32_t x);
/// constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the source of this decoration
/// @param x the workgroup x dimension size
/// @param y the workgroup x dimension size
/// @param x the workgroup x dimension expression
/// @param y the optional workgroup y dimension expression
/// @param z the optional workgroup z dimension expression
WorkgroupDecoration(ProgramID program_id,
const Source& source,
uint32_t x,
uint32_t y);
/// constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the source of this decoration
/// @param x the workgroup x dimension size
/// @param y the workgroup x dimension size
/// @param z the workgroup x dimension size
WorkgroupDecoration(ProgramID program_id,
const Source& source,
uint32_t x,
uint32_t y,
uint32_t z);
ast::Expression* x,
ast::Expression* y = nullptr,
ast::Expression* z = nullptr);
~WorkgroupDecoration() override;
/// @returns the workgroup dimensions
std::tuple<uint32_t, uint32_t, uint32_t> values() const {
return {x_, y_, z_};
}
std::array<ast::Expression*, 3> values() const { return {x_, y_, z_}; }
/// Outputs the decoration to the given stream
/// @param sem the semantic info for the program
@ -72,9 +60,9 @@ class WorkgroupDecoration : public Castable<WorkgroupDecoration, Decoration> {
WorkgroupDecoration* Clone(CloneContext* ctx) const override;
private:
uint32_t const x_;
uint32_t const y_;
uint32_t const z_;
ast::Expression* x_ = nullptr;
ast::Expression* y_ = nullptr;
ast::Expression* z_ = nullptr;
};
} // namespace ast

View File

@ -24,40 +24,84 @@ namespace {
using WorkgroupDecorationTest = TestHelper;
TEST_F(WorkgroupDecorationTest, Creation_1param) {
auto* d = create<WorkgroupDecoration>(2);
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = d->values();
EXPECT_EQ(x, 2u);
EXPECT_EQ(y, 1u);
EXPECT_EQ(z, 1u);
auto* d = WorkgroupSize(2);
auto values = d->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(x_scalar);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
EXPECT_EQ(values[1], nullptr);
EXPECT_EQ(values[2], nullptr);
}
TEST_F(WorkgroupDecorationTest, Creation_2param) {
auto* d = create<WorkgroupDecoration>(2, 4);
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = d->values();
EXPECT_EQ(x, 2u);
EXPECT_EQ(y, 4u);
EXPECT_EQ(z, 1u);
auto* d = WorkgroupSize(2, 4);
auto values = d->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(x_scalar);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(y_scalar);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
EXPECT_EQ(values[2], nullptr);
}
TEST_F(WorkgroupDecorationTest, Creation_3param) {
auto* d = create<WorkgroupDecoration>(2, 4, 6);
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = d->values();
EXPECT_EQ(x, 2u);
EXPECT_EQ(y, 4u);
EXPECT_EQ(z, 6u);
auto* d = WorkgroupSize(2, 4, 6);
auto values = d->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(x_scalar);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(y_scalar);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_NE(values[2], nullptr);
auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(z_scalar);
ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 6u);
}
TEST_F(WorkgroupDecorationTest, Creation_WithIdentifier) {
auto* d = WorkgroupSize(2, 4, "depth");
auto values = d->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(x_scalar);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_TRUE(y_scalar);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_NE(values[2], nullptr);
auto* z_ident = values[2]->As<ast::IdentifierExpression>();
ASSERT_TRUE(z_ident);
EXPECT_EQ(Symbols().NameFor(z_ident->symbol()), "depth");
}
TEST_F(WorkgroupDecorationTest, ToStr) {
auto* d = create<WorkgroupDecoration>(2, 4, 6);
EXPECT_EQ(str(d), R"(WorkgroupDecoration{2 4 6}
auto* d = WorkgroupSize(2, "height");
EXPECT_EQ(str(d), R"(WorkgroupDecoration{
ScalarConstructor[not set]{2}
Identifier[not set]{height}
}
)");
}

View File

@ -828,10 +828,8 @@ TEST_F(InspectorGetEntryPointTest, DefaultWorkgroupSize) {
}
TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
MakeEmptyBodyFunction("foo", ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
create<ast::WorkgroupDecoration>(8u, 2u, 1u),
});
MakeEmptyBodyFunction(
"foo", {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 1)});
Inspector& inspector = Build();

View File

@ -59,6 +59,7 @@
#include "src/ast/variable_decl_statement.h"
#include "src/ast/vector.h"
#include "src/ast/void.h"
#include "src/ast/workgroup_decoration.h"
#include "src/program.h"
#include "src/program_id.h"
#include "src/sem/array.h"
@ -1914,6 +1915,36 @@ class ProgramBuilder {
return create<ast::StageDecoration>(source_, stage);
}
/// Creates an ast::WorkgroupDecoration
/// @param x the x dimension expression
/// @returns the workgroup decoration pointer
template <typename EXPR_X>
ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x) {
return WorkgroupSize(std::forward<EXPR_X>(x), nullptr, nullptr);
}
/// Creates an ast::WorkgroupDecoration
/// @param x the x dimension expression
/// @param y the y dimension expression
/// @returns the workgroup decoration pointer
template <typename EXPR_X, typename EXPR_Y>
ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y),
nullptr);
}
/// Creates an ast::WorkgroupDecoration
/// @param x the x dimension expression
/// @param y the y dimension expression
/// @param z the z dimension expression
/// @returns the workgroup decoration pointer
template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z>
ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
return create<ast::WorkgroupDecoration>(
source_, Expr(std::forward<EXPR_X>(x)), Expr(std::forward<EXPR_Y>(y)),
Expr(std::forward<EXPR_Z>(z)));
}
/// Sets the current builder source to `src`
/// @param src the Source used for future create() calls
void SetSource(const Source& src) {

View File

@ -3010,26 +3010,35 @@ Maybe<ast::Decoration*> ParserImpl::decoration() {
if (s == kWorkgroupSizeDecoration) {
return expect_paren_block("workgroup_size decoration", [&]() -> Result {
uint32_t x;
uint32_t y = 1;
uint32_t z = 1;
ast::Expression* x = nullptr;
ast::Expression* y = nullptr;
ast::Expression* z = nullptr;
auto val = expect_nonzero_positive_sint("workgroup_size x parameter");
if (val.errored)
auto expr = primary_expression();
if (expr.errored) {
return Failure::kErrored;
x = val.value;
} else if (!expr.matched) {
return add_error(peek(), "expected workgroup_size x parameter");
}
x = std::move(expr.value);
if (match(Token::Type::kComma)) {
val = expect_nonzero_positive_sint("workgroup_size y parameter");
if (val.errored)
expr = primary_expression();
if (expr.errored) {
return Failure::kErrored;
y = val.value;
} else if (!expr.matched) {
return add_error(peek(), "expected workgroup_size y parameter");
}
y = std::move(expr.value);
if (match(Token::Type::kComma)) {
val = expect_nonzero_positive_sint("workgroup_size z parameter");
if (val.errored)
expr = primary_expression();
if (expr.errored) {
return Failure::kErrored;
z = val.value;
} else if (!expr.matched) {
return add_error(peek(), "expected workgroup_size z parameter");
}
z = std::move(expr.value);
}
}

View File

@ -325,74 +325,23 @@ TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeMissingRParen) {
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXInvalid) {
EXPECT("[[workgroup_size(x)]] fn f() {}",
"test.wgsl:1:18 error: expected signed integer literal for "
"workgroup_size x parameter\n"
"[[workgroup_size(x)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXNegative) {
EXPECT("[[workgroup_size(-1)]] fn f() {}",
"test.wgsl:1:18 error: workgroup_size x parameter must be greater "
"than 0\n"
"[[workgroup_size(-1)]] fn f() {}\n"
" ^^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXZero) {
EXPECT("[[workgroup_size(0)]] fn f() {}",
"test.wgsl:1:18 error: workgroup_size x parameter must be greater "
"than 0\n"
"[[workgroup_size(0)]] fn f() {}\n"
EXPECT("[[workgroup_size(@)]] fn f() {}",
"test.wgsl:1:18 error: expected workgroup_size x parameter\n"
"[[workgroup_size(@)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYInvalid) {
EXPECT("[[workgroup_size(1, x)]] fn f() {}",
"test.wgsl:1:21 error: expected signed integer literal for "
"workgroup_size y parameter\n"
"[[workgroup_size(1, x)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYNegative) {
EXPECT("[[workgroup_size(1, -1)]] fn f() {}",
"test.wgsl:1:21 error: workgroup_size y parameter must be greater "
"than 0\n"
"[[workgroup_size(1, -1)]] fn f() {}\n"
" ^^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYZero) {
EXPECT("[[workgroup_size(1, 0)]] fn f() {}",
"test.wgsl:1:21 error: workgroup_size y parameter must be greater "
"than 0\n"
"[[workgroup_size(1, 0)]] fn f() {}\n"
EXPECT("[[workgroup_size(1, @)]] fn f() {}",
"test.wgsl:1:21 error: expected workgroup_size y parameter\n"
"[[workgroup_size(1, @)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZInvalid) {
EXPECT("[[workgroup_size(1, 2, x)]] fn f() {}",
"test.wgsl:1:24 error: expected signed integer literal for "
"workgroup_size z parameter\n"
"[[workgroup_size(1, 2, x)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZNegative) {
EXPECT("[[workgroup_size(1, 2, -1)]] fn f() {}",
"test.wgsl:1:24 error: workgroup_size z parameter must be greater "
"than 0\n"
"[[workgroup_size(1, 2, -1)]] fn f() {}\n"
" ^^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZZero) {
EXPECT("[[workgroup_size(1, 2, 0)]] fn f() {}",
"test.wgsl:1:24 error: workgroup_size z parameter must be greater "
"than 0\n"
"[[workgroup_size(1, 2, 0)]] fn f() {}\n"
EXPECT("[[workgroup_size(1, 2, @)]] fn f() {}",
"test.wgsl:1:24 error: expected workgroup_size z parameter\n"
"[[workgroup_size(1, 2, @)]] fn f() {}\n"
" ^\n");
}

View File

@ -69,13 +69,25 @@ TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
ASSERT_EQ(decorations.size(), 1u);
ASSERT_TRUE(decorations[0]->Is<ast::WorkgroupDecoration>());
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = decorations[0]->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 2u);
EXPECT_EQ(y, 3u);
EXPECT_EQ(z, 4u);
auto values = decorations[0]->As<ast::WorkgroupDecoration>()->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(y_scalar, nullptr);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
ASSERT_NE(values[2], nullptr);
auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(z_scalar, nullptr);
ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
@ -84,7 +96,7 @@ TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
auto p = parser(R"(
[[workgroup_size(2, 3, 4), workgroup_size(5, 6, 7)]]
[[workgroup_size(2, 3, 4), stage(compute)]]
fn main() { return; })");
auto decos = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
@ -104,20 +116,30 @@ fn main() { return; })");
auto& decorations = f->decorations();
ASSERT_EQ(decorations.size(), 2u);
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
ASSERT_TRUE(decorations[0]->Is<ast::WorkgroupDecoration>());
std::tie(x, y, z) = decorations[0]->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 2u);
EXPECT_EQ(y, 3u);
EXPECT_EQ(z, 4u);
auto values = decorations[0]->As<ast::WorkgroupDecoration>()->values();
ASSERT_TRUE(decorations[1]->Is<ast::WorkgroupDecoration>());
std::tie(x, y, z) = decorations[1]->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 5u);
EXPECT_EQ(y, 6u);
EXPECT_EQ(z, 7u);
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(y_scalar, nullptr);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
ASSERT_NE(values[2], nullptr);
auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(z_scalar, nullptr);
ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_TRUE(decorations[1]->Is<ast::StageDecoration>());
EXPECT_EQ(decorations[1]->As<ast::StageDecoration>()->value(),
ast::PipelineStage::kCompute);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
@ -127,7 +149,7 @@ fn main() { return; })");
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
auto p = parser(R"(
[[workgroup_size(2, 3, 4)]]
[[workgroup_size(5, 6, 7)]]
[[stage(compute)]]
fn main() { return; })");
auto decorations = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
@ -147,20 +169,30 @@ fn main() { return; })");
auto& decos = f->decorations();
ASSERT_EQ(decos.size(), 2u);
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
ASSERT_TRUE(decos[0]->Is<ast::WorkgroupDecoration>());
std::tie(x, y, z) = decos[0]->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 2u);
EXPECT_EQ(y, 3u);
EXPECT_EQ(z, 4u);
auto values = decos[0]->As<ast::WorkgroupDecoration>()->values();
ASSERT_TRUE(decos[1]->Is<ast::WorkgroupDecoration>());
std::tie(x, y, z) = decos[1]->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 5u);
EXPECT_EQ(y, 6u);
EXPECT_EQ(z, 7u);
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(y_scalar, nullptr);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
ASSERT_NE(values[2], nullptr);
auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(z_scalar, nullptr);
ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_TRUE(decos[1]->Is<ast::StageDecoration>());
EXPECT_EQ(decos[1]->As<ast::StageDecoration>()->value(),
ast::PipelineStage::kCompute);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);

View File

@ -21,7 +21,7 @@ namespace wgsl {
namespace {
TEST_F(ParserImplTest, DecorationList_Parses) {
auto p = parser("[[workgroup_size(2), workgroup_size(3, 4, 5)]]");
auto p = parser("[[workgroup_size(2), stage(compute)]]");
auto decos = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(decos.errored);
@ -33,18 +33,17 @@ TEST_F(ParserImplTest, DecorationList_Parses) {
ASSERT_NE(deco_0, nullptr);
ASSERT_NE(deco_1, nullptr);
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
ASSERT_TRUE(deco_0->Is<ast::WorkgroupDecoration>());
std::tie(x, y, z) = deco_0->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 2u);
ast::Expression* x = deco_0->As<ast::WorkgroupDecoration>()->values()[0];
ASSERT_NE(x, nullptr);
auto* x_scalar = x->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
ASSERT_TRUE(deco_1->Is<ast::WorkgroupDecoration>());
std::tie(x, y, z) = deco_1->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 3u);
EXPECT_EQ(y, 4u);
EXPECT_EQ(z, 5u);
ASSERT_TRUE(deco_1->Is<ast::StageDecoration>());
EXPECT_EQ(deco_1->As<ast::StageDecoration>()->value(),
ast::PipelineStage::kCompute);
}
TEST_F(ParserImplTest, DecorationList_Empty) {
@ -85,14 +84,12 @@ TEST_F(ParserImplTest, DecorationList_MissingComma) {
}
TEST_F(ParserImplTest, DecorationList_BadDecoration) {
auto p = parser("[[workgroup_size()]]");
auto p = parser("[[stage()]]");
auto decos = p->decoration_list();
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(decos.errored);
EXPECT_FALSE(decos.matched);
EXPECT_EQ(
p->error(),
"1:18: expected signed integer literal for workgroup_size x parameter");
EXPECT_EQ(p->error(), "1:9: invalid value for stage decoration");
}
TEST_F(ParserImplTest, DecorationList_MissingRightAttr) {

View File

@ -32,13 +32,16 @@ TEST_F(ParserImplTest, Decoration_Workgroup) {
ASSERT_NE(func_deco, nullptr);
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 4u);
EXPECT_EQ(y, 1u);
EXPECT_EQ(z, 1u);
auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
EXPECT_EQ(values[1], nullptr);
EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Decoration_Workgroup_2Param) {
@ -52,13 +55,21 @@ TEST_F(ParserImplTest, Decoration_Workgroup_2Param) {
ASSERT_NE(func_deco, nullptr) << p->error();
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 4u);
EXPECT_EQ(y, 5u);
EXPECT_EQ(z, 1u);
auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(y_scalar, nullptr);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 5u);
EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Decoration_Workgroup_3Param) {
@ -72,13 +83,52 @@ TEST_F(ParserImplTest, Decoration_Workgroup_3Param) {
ASSERT_NE(func_deco, nullptr);
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
EXPECT_EQ(x, 4u);
EXPECT_EQ(y, 5u);
EXPECT_EQ(z, 6u);
auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_NE(values[1], nullptr);
auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(y_scalar, nullptr);
ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 5u);
ASSERT_NE(values[2], nullptr);
auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(z_scalar, nullptr);
ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 6u);
}
TEST_F(ParserImplTest, Decoration_Workgroup_WithIdent) {
auto p = parser("workgroup_size(4, height)");
auto deco = p->decoration();
EXPECT_TRUE(deco.matched);
EXPECT_FALSE(deco.errored);
ASSERT_NE(deco.value, nullptr) << p->error();
ASSERT_FALSE(p->has_error());
auto* func_deco = deco.value->As<ast::Decoration>();
ASSERT_NE(func_deco, nullptr);
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
ASSERT_NE(values[0], nullptr);
auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
ASSERT_NE(x_scalar, nullptr);
ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
ASSERT_NE(values[1], nullptr);
auto* y_ident = values[1]->As<ast::IdentifierExpression>();
ASSERT_NE(y_ident, nullptr);
EXPECT_EQ(p->builder().Symbols().NameFor(y_ident->symbol()), "height");
ASSERT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Decoration_Workgroup_TooManyValues) {
@ -91,39 +141,6 @@ TEST_F(ParserImplTest, Decoration_Workgroup_TooManyValues) {
EXPECT_EQ(p->error(), "1:23: expected ')' for workgroup_size decoration");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_X_Value) {
auto p = parser("workgroup_size(-2, 5, 6)");
auto deco = p->decoration();
EXPECT_FALSE(deco.matched);
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(),
"1:16: workgroup_size x parameter must be greater than 0");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_Y_Value) {
auto p = parser("workgroup_size(4, 0, 6)");
auto deco = p->decoration();
EXPECT_FALSE(deco.matched);
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(),
"1:19: workgroup_size y parameter must be greater than 0");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_Z_Value) {
auto p = parser("workgroup_size(4, 5, -3)");
auto deco = p->decoration();
EXPECT_FALSE(deco.matched);
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(),
"1:22: workgroup_size z parameter must be greater than 0");
}
TEST_F(ParserImplTest, Decoration_Workgroup_MissingLeftParam) {
auto p = parser("workgroup_size 4, 5, 6)");
auto deco = p->decoration();
@ -151,9 +168,7 @@ TEST_F(ParserImplTest, Decoration_Workgroup_MissingValues) {
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:16: expected signed integer literal for workgroup_size x parameter");
EXPECT_EQ(p->error(), "1:16: expected workgroup_size x parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Value) {
@ -163,9 +178,7 @@ TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Value) {
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:16: expected signed integer literal for workgroup_size x parameter");
EXPECT_EQ(p->error(), "1:16: expected workgroup_size x parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Comma) {
@ -185,9 +198,7 @@ TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Value) {
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:19: expected signed integer literal for workgroup_size y parameter");
EXPECT_EQ(p->error(), "1:19: expected workgroup_size y parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Comma) {
@ -207,45 +218,7 @@ TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Value) {
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:22: expected signed integer literal for workgroup_size z parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Invalid) {
auto p = parser("workgroup_size(nan)");
auto deco = p->decoration();
EXPECT_FALSE(deco.matched);
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:16: expected signed integer literal for workgroup_size x parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Invalid) {
auto p = parser("workgroup_size(2, nan)");
auto deco = p->decoration();
EXPECT_FALSE(deco.matched);
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:19: expected signed integer literal for workgroup_size y parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Invalid) {
auto p = parser("workgroup_size(2, 3, nan)");
auto deco = p->decoration();
EXPECT_FALSE(deco.matched);
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(
p->error(),
"1:22: expected signed integer literal for workgroup_size z parameter");
EXPECT_EQ(p->error(), "1:22: expected workgroup_size z parameter");
}
TEST_F(ParserImplTest, Decoration_Stage) {

View File

@ -94,7 +94,8 @@ static ast::DecorationList createDecorations(const Source& source,
case DecorationKind::kStructBlock:
return {builder.create<ast::StructBlockDecoration>(source)};
case DecorationKind::kWorkgroup:
return {builder.create<ast::WorkgroupDecoration>(source, 1u, 1u, 1u)};
return {
builder.create<ast::WorkgroupDecoration>(source, builder.Expr(1))};
case DecorationKind::kBindingAndGroup:
return {builder.create<ast::BindingDecoration>(source, 1u),
builder.create<ast::GroupDecoration>(source, 1u)};
@ -664,7 +665,7 @@ using WorkgroupDecoration = ResolverTest;
TEST_F(WorkgroupDecoration, NotAnEntryPoint) {
Func("main", {}, ty.void_(), {},
{create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
{create<ast::WorkgroupDecoration>(Source{{12, 34}}, Expr(1))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@ -675,7 +676,7 @@ TEST_F(WorkgroupDecoration, NotAnEntryPoint) {
TEST_F(WorkgroupDecoration, NotAComputeShader) {
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment),
create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
create<ast::WorkgroupDecoration>(Source{{12, 34}}, Expr(1))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@ -685,9 +686,8 @@ TEST_F(WorkgroupDecoration, NotAComputeShader) {
TEST_F(WorkgroupDecoration, MultipleAttributes) {
Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
create<ast::WorkgroupDecoration>(1u),
create<ast::WorkgroupDecoration>(2u)});
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1),
WorkgroupSize(2)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),

View File

@ -244,5 +244,127 @@ TEST_F(ResolverFunctionValidationTest, FunctionConstInitWithParam) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
// [[stage(compute), workgroup_size(64.0)]
// fn main() {}
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(create<ast::ScalarConstructorExpression>(
Source{Source::Location{12, 34}}, Literal(64.f)))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size parameter must be a literal i32 or an "
"i32 module-scope constant");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
// [[stage(compute), workgroup_size(-2)]
// fn main() {}
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(create<ast::ScalarConstructorExpression>(
Source{Source::Location{12, 34}}, Literal(-2)))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: workgroup_size parameter must be a positive i32 value");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Zero) {
// [[stage(compute), workgroup_size(0)]
// fn main() {}
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(create<ast::ScalarConstructorExpression>(
Source{Source::Location{12, 34}}, Literal(0)))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: workgroup_size parameter must be a positive i32 value");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
// let x = 64.0;
// [[stage(compute), workgroup_size(x)]
// fn main() {}
GlobalConst("x", ty.f32(), Expr(64.f));
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size parameter must be a literal i32 or an "
"i32 module-scope constant");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
// let x = -2;
// [[stage(compute), workgroup_size(x)]
// fn main() {}
GlobalConst("x", ty.i32(), Expr(-2));
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: workgroup_size parameter must be a positive i32 value");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Zero) {
// let x = 0;
// [[stage(compute), workgroup_size(x)]
// fn main() {}
GlobalConst("x", ty.i32(), Expr(0));
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: workgroup_size parameter must be a positive i32 value");
}
TEST_F(ResolverFunctionValidationTest,
WorkgroupSize_Const_NestedZeroValueConstructor) {
// let x = i32(i32(i32()));
// [[stage(compute), workgroup_size(x)]
// fn main() {}
GlobalConst("x", ty.i32(),
Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32()))));
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: workgroup_size parameter must be a positive i32 value");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
// var<private> x = 0;
// [[stage(compute), workgroup_size(x)]
// fn main() {}
Global("x", ty.i32(), ast::StorageClass::kPrivate, Expr(64));
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size parameter must be a literal i32 or an "
"i32 module-scope constant");
}
} // namespace
} // namespace tint

View File

@ -1295,10 +1295,79 @@ bool Resolver::Function(ast::Function* func) {
if (auto* workgroup =
ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
// TODO(crbug.com/tint/713): Handle non-literals.
info->workgroup_size[0].value = std::get<0>(workgroup->values());
info->workgroup_size[1].value = std::get<1>(workgroup->values());
info->workgroup_size[2].value = std::get<2>(workgroup->values());
auto values = workgroup->values();
for (int i = 0; i < 3; i++) {
// Each argument to this decoration can either be a literal, an
// identifier for a module-scope constants, or nullptr if not specified.
if (!values[i]) {
// Not specified, just use the default.
continue;
}
Mark(values[i]);
int32_t value = 0;
if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
// We have an identifier of a module-scope constant.
if (!Identifier(ident)) {
return false;
}
VariableInfo* var;
if (!variable_stack_.get(ident->symbol(), &var) ||
!(var->declaration->is_const() && var->type->Is<sem::I32>())) {
diagnostics_.add_error(
"workgroup_size parameter must be a literal i32 or an i32 "
"module-scope constant",
values[i]->source());
return false;
}
// Capture the constant if an [[override]] attribute is present.
if (ast::HasDecoration<ast::OverrideDecoration>(
var->declaration->decorations())) {
info->workgroup_size[i].overridable_const = var->declaration;
}
auto* constructor = var->declaration->constructor();
if (constructor) {
// Resolve the constructor expression to use as the default value.
if (!GetScalarConstExprValue(constructor, &value)) {
return false;
}
} else {
// No constructor means this value must be overriden by the user.
info->workgroup_size[i].value = 0;
continue;
}
} else if (auto* scalar =
values[i]->As<ast::ScalarConstructorExpression>()) {
// We have a literal.
Mark(scalar->literal());
if (!scalar->literal()->Is<ast::IntLiteral>()) {
diagnostics_.add_error(
"workgroup_size parameter must be a literal i32 or an i32 "
"module-scope constant",
values[i]->source());
return false;
}
if (!GetScalarConstExprValue(scalar, &value)) {
return false;
}
}
// Validate and set the default value for this dimension.
if (value < 1) {
diagnostics_.add_error(
"workgroup_size parameter must be a positive i32 value",
values[i]->source());
return false;
}
info->workgroup_size[i].value = value;
}
}
if (!ValidateFunction(func, info)) {
@ -3098,6 +3167,40 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
return true;
}
template <typename T>
bool Resolver::GetScalarConstExprValue(ast::Expression* expr, T* result) {
if (auto* type_constructor = expr->As<ast::TypeConstructorExpression>()) {
if (type_constructor->values().size() == 0) {
// Zero-valued constructor.
*result = static_cast<T>(0);
return true;
} else if (type_constructor->values().size() == 1) {
// Recurse into the constructor argument expression.
return GetScalarConstExprValue(type_constructor->values()[0], result);
} else {
TINT_ICE(diagnostics_) << "malformed scalar type constructor";
}
} else if (auto* scalar = expr->As<ast::ScalarConstructorExpression>()) {
// Cast literal to result type.
if (auto* int_lit = scalar->literal()->As<ast::IntLiteral>()) {
*result = static_cast<T>(int_lit->value_as_u32());
return true;
} else if (auto* float_lit = scalar->literal()->As<ast::FloatLiteral>()) {
*result = static_cast<T>(float_lit->value());
return true;
} else if (auto* bool_lit = scalar->literal()->As<ast::BoolLiteral>()) {
*result = static_cast<T>(bool_lit->IsTrue());
return true;
} else {
TINT_ICE(diagnostics_) << "unhandled scalar constructor";
}
} else {
TINT_ICE(diagnostics_) << "unhandled constant expression";
}
return false;
}
template <typename F>
bool Resolver::BlockScope(const ast::BlockStatement* block, F&& callback) {
auto* sem_block = builder_->Sem().Get<sem::BlockStatement>(block);

View File

@ -323,6 +323,13 @@ class Resolver {
typ::Type type,
const std::string& type_name);
/// Resolve the value of a scalar const_expr.
/// @param expr the expression
/// @param result pointer to the where the result will be stored
/// @returns true on success, false on error
template <typename T>
bool GetScalarConstExprValue(ast::Expression* expr, T* result);
/// Constructs a new semantic BlockStatement with the given type and with
/// #current_block_ as its parent, assigns this to #current_block_, and then
/// calls `callback`. The original #current_block_ is restored on exit.

View File

@ -26,6 +26,7 @@
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h"
#include "src/ast/override_decoration.h"
#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
@ -889,6 +890,8 @@ TEST_F(ResolverTest, Function_ReturnStatements) {
}
TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
// [[stage(compute)]]
// fn main() {}
auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -905,9 +908,11 @@ TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
}
TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
create<ast::WorkgroupDecoration>(8, 2, 3)});
// [[stage(compute), workgroup_size(8, 2, 3)]]
// fn main() {}
auto* func =
Func("main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 3)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -922,6 +927,134 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
}
TEST_F(ResolverTest, Function_WorkgroupSize_Consts) {
// let width = 16;
// let height = 8;
// let depth = 2;
// [[stage(compute), workgroup_size(width, height, depth)]]
// fn main() {}
GlobalConst("width", ty.i32(), Expr(16));
GlobalConst("height", ty.i32(), Expr(8));
GlobalConst("depth", ty.i32(), Expr(2));
auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize("width", "height", "depth")});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[0].value, 16u);
EXPECT_EQ(func_sem->workgroup_size()[1].value, 8u);
EXPECT_EQ(func_sem->workgroup_size()[2].value, 2u);
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
}
TEST_F(ResolverTest, Function_WorkgroupSize_Consts_NestedInitializer) {
// let width = i32(i32(i32(8)));
// let height = i32(i32(i32(4)));
// [[stage(compute), workgroup_size(width, height)]]
// fn main() {}
GlobalConst("width", ty.i32(),
Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8))));
GlobalConst("height", ty.i32(),
Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 4))));
auto* func = Func(
"main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize("width", "height")});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
EXPECT_EQ(func_sem->workgroup_size()[1].value, 4u);
EXPECT_EQ(func_sem->workgroup_size()[2].value, 1u);
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
}
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
// [[override(0)]] let width = 16;
// [[override(1)]] let height = 8;
// [[override(2)]] let depth = 2;
// [[stage(compute), workgroup_size(width, height, depth)]]
// fn main() {}
auto* width = GlobalConst("width", ty.i32(), Expr(16), {Override(0)});
auto* height = GlobalConst("height", ty.i32(), Expr(8), {Override(1)});
auto* depth = GlobalConst("depth", ty.i32(), Expr(2), {Override(2)});
auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize("width", "height", "depth")});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[0].value, 16u);
EXPECT_EQ(func_sem->workgroup_size()[1].value, 8u);
EXPECT_EQ(func_sem->workgroup_size()[2].value, 2u);
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, width);
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, depth);
}
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
// [[override(0)]] let width : i32;
// [[override(1)]] let height : i32;
// [[override(2)]] let depth : i32;
// [[stage(compute), workgroup_size(width, height, depth)]]
// fn main() {}
auto* width = GlobalConst("width", ty.i32(), nullptr, {Override(0)});
auto* height = GlobalConst("height", ty.i32(), nullptr, {Override(1)});
auto* depth = GlobalConst("depth", ty.i32(), nullptr, {Override(2)});
auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize("width", "height", "depth")});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[0].value, 0u);
EXPECT_EQ(func_sem->workgroup_size()[1].value, 0u);
EXPECT_EQ(func_sem->workgroup_size()[2].value, 0u);
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, width);
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, depth);
}
TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
// [[override(1)]] let height = 2;
// let depth = 3;
// [[stage(compute), workgroup_size(8, height, depth)]]
// fn main() {}
auto* height = GlobalConst("height", ty.i32(), Expr(2), {Override(0)});
GlobalConst("depth", ty.i32(), Expr(3));
auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(8, "height", "depth")});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
EXPECT_EQ(func_sem->workgroup_size()[1].value, 2u);
EXPECT_EQ(func_sem->workgroup_size()[2].value, 3u);
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
}
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* st = Structure("S", {Member("first_member", ty.i32()),
Member("second_member", ty.f32())});

View File

@ -926,7 +926,7 @@ TEST_F(HlslGeneratorImplTest_Function,
},
{
Stage(ast::PipelineStage::kCompute),
create<ast::WorkgroupDecoration>(2u, 4u, 6u),
WorkgroupSize(2, 4, 6),
});
GeneratorImpl& gen = Build();

View File

@ -200,7 +200,7 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Default) {
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) {
auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u),
WorkgroupSize(2, 4, 6),
Stage(ast::PipelineStage::kCompute),
});

View File

@ -601,12 +601,28 @@ bool GeneratorImpl::EmitDecorations(const ast::DecorationList& decos) {
first = false;
if (auto* workgroup = deco->As<ast::WorkgroupDecoration>()) {
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
std::tie(x, y, z) = workgroup->values();
out_ << "workgroup_size(" << std::to_string(x) << ", "
<< std::to_string(y) << ", " << std::to_string(z) << ")";
auto values = workgroup->values();
out_ << "workgroup_size(";
for (int i = 0; i < 3; i++) {
if (values[i]) {
if (i > 0) {
out_ << ", ";
}
if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
if (!EmitIdentifier(ident)) {
return false;
}
} else if (auto* scalar =
values[i]->As<ast::ScalarConstructorExpression>()) {
if (!EmitScalarConstructor(scalar)) {
return false;
}
} else {
TINT_ICE(diagnostics_) << "Unsupported workgroup_size expression";
}
}
}
out_ << ")";
} else if (deco->Is<ast::StructBlockDecoration>()) {
out_ << "block";
} else if (auto* stage = deco->As<ast::StageDecoration>()) {

View File

@ -69,13 +69,10 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) {
TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) {
auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::DiscardStatement>(),
Return(),
},
ast::StatementList{Return()},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
create<ast::WorkgroupDecoration>(2u, 4u, 6u),
WorkgroupSize(2, 4, 6),
});
GeneratorImpl& gen = Build();
@ -85,20 +82,19 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) {
ASSERT_TRUE(gen.EmitFunction(func));
EXPECT_EQ(gen.result(), R"( [[stage(compute), workgroup_size(2, 4, 6)]]
fn my_func() {
discard;
return;
}
)");
}
TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) {
TEST_F(WgslGeneratorImplTest,
Emit_Function_WithDecoration_WorkgroupSize_WithIdent) {
GlobalConst("height", ty.i32(), Expr(2));
auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::DiscardStatement>(),
Return(),
},
ast::StatementList{Return()},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(2, "height"),
});
GeneratorImpl& gen = Build();
@ -106,9 +102,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) {
gen.increment_indent();
ASSERT_TRUE(gen.EmitFunction(func));
EXPECT_EQ(gen.result(), R"( [[stage(fragment)]]
EXPECT_EQ(gen.result(), R"( [[stage(compute), workgroup_size(2, height)]]
fn my_func() {
discard;
return;
}
)");

View File

@ -2,6 +2,6 @@ fn main() -> f32 {
return (((2.0 * 3.0) - 4.0) / 5.0);
}
[[stage(compute), workgroup_size(2, 1, 1)]]
[[stage(compute), workgroup_size(2)]]
fn ep() {
}