[wgsl-reader] Allow decorations on function return types

Add a return type decoration list field to ast::Function.

Bug: tint:513
Change-Id: I41c1087f21a87731eb48ec7642997da5ae7f2baa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44601
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-03-15 17:01:34 +00:00 committed by Commit Bot service account
parent 20a438a5c3
commit feecbe0d83
14 changed files with 150 additions and 26 deletions

View File

@ -28,13 +28,15 @@ Function::Function(const Source& source,
VariableList params,
type::Type* return_type,
BlockStatement* body,
DecorationList decorations)
DecorationList decorations,
DecorationList return_type_decorations)
: Base(source),
symbol_(symbol),
params_(std::move(params)),
return_type_(return_type),
body_(body),
decorations_(std::move(decorations)) {
decorations_(std::move(decorations)),
return_type_decorations_(std::move(return_type_decorations)) {
for (auto* param : params_) {
TINT_ASSERT(param);
}
@ -77,7 +79,8 @@ Function* Function::Clone(CloneContext* ctx) const {
auto* ret = ctx->Clone(return_type_);
auto* b = ctx->Clone(body_);
auto decos = ctx->Clone(decorations_);
return ctx->dst->create<Function>(src, sym, p, ret, b, decos);
auto ret_decos = ctx->Clone(return_type_decorations_);
return ctx->dst->create<Function>(src, sym, p, ret, b, decos, ret_decos);
}
void Function::to_str(const semantic::Info& sem,

View File

@ -42,12 +42,14 @@ class Function : public Castable<Function, Node> {
/// @param return_type the return type
/// @param body the function body
/// @param decorations the function decorations
/// @param return_type_decorations the return type decorations
Function(const Source& source,
Symbol symbol,
VariableList params,
type::Type* return_type,
BlockStatement* body,
DecorationList decorations);
DecorationList decorations,
DecorationList return_type_decorations);
/// Move constructor
Function(Function&&);
@ -74,6 +76,11 @@ class Function : public Castable<Function, Node> {
/// @returns the function return type.
type::Type* return_type() const { return return_type_; }
/// @returns the decorations attached to the function return type.
const DecorationList& return_type_decorations() const {
return return_type_decorations_;
}
/// @returns a pointer to the last statement of the function or nullptr if
// function is empty
const Statement* get_last_statement() const;
@ -108,6 +115,7 @@ class Function : public Castable<Function, Node> {
type::Type* const return_type_;
BlockStatement* const body_;
DecorationList const decorations_;
DecorationList const return_type_decorations_;
};
/// A list of functions

View File

@ -994,16 +994,18 @@ class ProgramBuilder {
/// @param type the function return type
/// @param body the function body
/// @param decorations the function decorations
/// @param return_type_decorations the function return type decorations
/// @returns the function pointer
ast::Function* Func(Source source,
std::string name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
ast::DecorationList decorations) {
auto* func =
create<ast::Function>(source, Symbols().Register(name), params, type,
create<ast::BlockStatement>(body), decorations);
ast::DecorationList decorations,
ast::DecorationList return_type_decorations = {}) {
auto* func = create<ast::Function>(source, Symbols().Register(name), params,
type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func);
return func;
}
@ -1014,15 +1016,17 @@ class ProgramBuilder {
/// @param type the function return type
/// @param body the function body
/// @param decorations the function decorations
/// @param return_type_decorations the function return type decorations
/// @returns the function pointer
ast::Function* Func(std::string name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
ast::DecorationList decorations) {
auto* func =
create<ast::Function>(Symbols().Register(name), params, type,
create<ast::BlockStatement>(body), decorations);
ast::DecorationList decorations,
ast::DecorationList return_type_decorations = {}) {
auto* func = create<ast::Function>(Symbols().Register(name), params, type,
create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func);
return func;
}

View File

@ -843,10 +843,10 @@ bool FunctionEmitter::Emit() {
auto& statements = statements_stack_[0].GetStatements();
auto* body = create<ast::BlockStatement>(Source{}, statements);
builder_.AST().AddFunction(
create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
std::move(decl.params), decl.return_type, body,
std::move(decl.decorations)));
builder_.AST().AddFunction(create<ast::Function>(
decl.source, builder_.Symbols().Register(decl.name),
std::move(decl.params), decl.return_type, body,
std::move(decl.decorations), ast::DecorationList{}));
// Maintain the invariant by repopulating the one and only element.
statements_stack_.clear();

View File

@ -167,8 +167,13 @@ ParserImpl::FunctionHeader::FunctionHeader(const FunctionHeader&) = default;
ParserImpl::FunctionHeader::FunctionHeader(Source src,
std::string n,
ast::VariableList p,
type::Type* ret_ty)
: source(src), name(n), params(p), return_type(ret_ty) {}
type::Type* ret_ty,
ast::DecorationList ret_decos)
: source(src),
name(n),
params(p),
return_type(ret_ty),
return_type_decorations(ret_decos) {}
ParserImpl::FunctionHeader::~FunctionHeader() = default;
@ -1185,7 +1190,7 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
return create<ast::Function>(
header->source, builder_.Symbols().Register(header->name), header->params,
header->return_type, body.value, decos);
header->return_type, body.value, decos, header->return_type_decorations);
}
// function_type_decl
@ -1225,6 +1230,11 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
if (!expect(use, Token::Type::kArrow))
return Failure::kErrored;
auto decos = decoration_list();
if (decos.errored) {
return Failure::kErrored;
}
auto type = function_type_decl();
if (type.errored) {
errored = true;
@ -1235,8 +1245,8 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
if (errored)
return Failure::kErrored;
return FunctionHeader{source, name.value, std::move(params.value),
type.value};
return FunctionHeader{source, name.value, std::move(params.value), type.value,
std::move(decos.value)};
}
// param_list

View File

@ -218,10 +218,12 @@ class ParserImpl {
/// @param n function name
/// @param p function parameters
/// @param ret_ty function return type
/// @param ret_decos return type decorations
FunctionHeader(Source src,
std::string n,
ast::VariableList p,
type::Type* ret_ty);
type::Type* ret_ty,
ast::DecorationList ret_decos);
/// Destructor
~FunctionHeader();
/// Assignment operator
@ -237,6 +239,8 @@ class ParserImpl {
ast::VariableList params;
/// Function return type
type::Type* return_type;
/// Function return type decorations
ast::DecorationList return_type_decorations;
};
/// VarDeclInfo contains the parsed information for variable declaration.

View File

@ -173,6 +173,37 @@ fn main() -> void { return; })");
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
auto p = parser("fn main() -> [[location(1)]] f32 { return 1.0; }");
auto decos = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(decos.errored);
EXPECT_FALSE(decos.matched);
auto f = p->function_decl(decos.value);
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(f.errored);
EXPECT_TRUE(f.matched);
ASSERT_NE(f.value, nullptr);
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<type::F32>());
ASSERT_EQ(f->params().size(), 0u);
auto& decorations = f->decorations();
EXPECT_EQ(decorations.size(), 0u);
auto& ret_type_decorations = f->return_type_decorations();
ASSERT_EQ(ret_type_decorations.size(), 1u);
auto* loc = ret_type_decorations[0]->As<ast::LocationDecoration>();
ASSERT_TRUE(loc != nullptr);
EXPECT_EQ(loc->value(), 1u);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
auto p = parser("fn main() -> { }");
auto decos = p->decoration_list();

View File

@ -33,6 +33,22 @@ TEST_F(ParserImplTest, FunctionHeader) {
EXPECT_TRUE(f->return_type->Is<type::Void>());
}
TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) {
auto p = parser("fn main() -> [[location(1)]] f32");
auto f = p->function_header();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(f.matched);
EXPECT_FALSE(f.errored);
EXPECT_EQ(f->name, "main");
EXPECT_EQ(f->params.size(), 0u);
EXPECT_TRUE(f->return_type->Is<type::F32>());
ASSERT_TRUE(f->return_type_decorations.size() == 1u);
auto* loc = f->return_type_decorations[0]->As<ast::LocationDecoration>();
ASSERT_TRUE(loc != nullptr);
EXPECT_EQ(loc->value(), 1u);
}
TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
auto p = parser("fn () -> void");
auto f = p->function_header();

View File

@ -229,7 +229,8 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const {
func->source(), ctx.Clone(func->symbol()), new_parameters,
ctx.Clone(func->return_type()),
ctx.dst->create<ast::BlockStatement>(new_body),
ctx.Clone(func->decorations()));
ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func);
}
}

View File

@ -391,7 +391,8 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const {
func->source(), ctx.Clone(func->symbol()), new_parameters,
ctx.Clone(func->return_type()),
ctx.dst->create<ast::BlockStatement>(new_body),
ctx.Clone(func->decorations()));
ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func);
}
}

View File

@ -138,7 +138,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
ctx.Clone(func->decorations()));
ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func);
}
}

View File

@ -58,8 +58,9 @@ ast::Function* Transform::CloneWithStatementsAtStart(
auto* body = ctx->dst->create<ast::BlockStatement>(
ctx->Clone(in->body()->source()), statements);
auto decos = ctx->Clone(in->decorations());
auto ret_decos = ctx->Clone(in->return_type_decorations());
return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
body, decos);
body, decos, ret_decos);
}
void Transform::RenameReservedKeywords(CloneContext* ctx,

View File

@ -148,6 +148,41 @@ INSTANTIATE_TEST_SUITE_P(
false},
DecorationTestParams{DecorationKind::kWorkgroup, true}));
using FunctionReturnTypeDecorationTest = ValidatorDecorationsTestWithParams;
TEST_P(FunctionReturnTypeDecorationTest, Decoration_IsValid) {
auto params = GetParam();
Func("main", ast::VariableList{}, ty.f32(),
ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
ast::DecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex)},
ast::DecorationList{createDecoration(*this, params.kind)});
ValidatorImpl& v = Build();
if (params.should_pass) {
EXPECT_TRUE(v.Validate());
} else {
EXPECT_FALSE(v.Validate());
EXPECT_EQ(v.error(), "decoration is not valid for function return types");
}
}
INSTANTIATE_TEST_SUITE_P(
ValidatorTest,
FunctionReturnTypeDecorationTest,
testing::Values(DecorationTestParams{DecorationKind::kAccess, false},
DecorationTestParams{DecorationKind::kBinding, false},
DecorationTestParams{DecorationKind::kBuiltin, true},
DecorationTestParams{DecorationKind::kConstantId, false},
DecorationTestParams{DecorationKind::kGroup, false},
DecorationTestParams{DecorationKind::kLocation, true},
DecorationTestParams{DecorationKind::kStage, false},
DecorationTestParams{DecorationKind::kStride, false},
DecorationTestParams{DecorationKind::kStructBlock, false},
DecorationTestParams{DecorationKind::kStructMemberOffset,
false},
DecorationTestParams{DecorationKind::kWorkgroup, false}));
using StructDecorationTest = ValidatorDecorationsTestWithParams;
TEST_P(StructDecorationTest, Decoration_IsValid) {
auto params = GetParam();

View File

@ -248,6 +248,15 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) {
"non-void function must end with a return statement");
return false;
}
for (auto* deco : current_function_->return_type_decorations()) {
if (!(deco->Is<ast::BuiltinDecoration>() ||
deco->Is<ast::LocationDecoration>())) {
add_error(deco->source(),
"decoration is not valid for function return types");
return false;
}
}
}
return true;
}