[wgsl-reader] Allow decorations on function return types

Add a return type decoration list field to ast::Function.

Bug: tint:513
Change-Id: I41c1087f21a87731eb48ec7642997da5ae7f2baa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44601
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-03-15 17:01:34 +00:00 committed by Commit Bot service account
parent 20a438a5c3
commit feecbe0d83
14 changed files with 150 additions and 26 deletions

View File

@ -28,13 +28,15 @@ Function::Function(const Source& source,
VariableList params, VariableList params,
type::Type* return_type, type::Type* return_type,
BlockStatement* body, BlockStatement* body,
DecorationList decorations) DecorationList decorations,
DecorationList return_type_decorations)
: Base(source), : Base(source),
symbol_(symbol), symbol_(symbol),
params_(std::move(params)), params_(std::move(params)),
return_type_(return_type), return_type_(return_type),
body_(body), body_(body),
decorations_(std::move(decorations)) { decorations_(std::move(decorations)),
return_type_decorations_(std::move(return_type_decorations)) {
for (auto* param : params_) { for (auto* param : params_) {
TINT_ASSERT(param); TINT_ASSERT(param);
} }
@ -77,7 +79,8 @@ Function* Function::Clone(CloneContext* ctx) const {
auto* ret = ctx->Clone(return_type_); auto* ret = ctx->Clone(return_type_);
auto* b = ctx->Clone(body_); auto* b = ctx->Clone(body_);
auto decos = ctx->Clone(decorations_); auto decos = ctx->Clone(decorations_);
return ctx->dst->create<Function>(src, sym, p, ret, b, decos); auto ret_decos = ctx->Clone(return_type_decorations_);
return ctx->dst->create<Function>(src, sym, p, ret, b, decos, ret_decos);
} }
void Function::to_str(const semantic::Info& sem, void Function::to_str(const semantic::Info& sem,

View File

@ -42,12 +42,14 @@ class Function : public Castable<Function, Node> {
/// @param return_type the return type /// @param return_type the return type
/// @param body the function body /// @param body the function body
/// @param decorations the function decorations /// @param decorations the function decorations
/// @param return_type_decorations the return type decorations
Function(const Source& source, Function(const Source& source,
Symbol symbol, Symbol symbol,
VariableList params, VariableList params,
type::Type* return_type, type::Type* return_type,
BlockStatement* body, BlockStatement* body,
DecorationList decorations); DecorationList decorations,
DecorationList return_type_decorations);
/// Move constructor /// Move constructor
Function(Function&&); Function(Function&&);
@ -74,6 +76,11 @@ class Function : public Castable<Function, Node> {
/// @returns the function return type. /// @returns the function return type.
type::Type* return_type() const { return return_type_; } type::Type* return_type() const { return return_type_; }
/// @returns the decorations attached to the function return type.
const DecorationList& return_type_decorations() const {
return return_type_decorations_;
}
/// @returns a pointer to the last statement of the function or nullptr if /// @returns a pointer to the last statement of the function or nullptr if
// function is empty // function is empty
const Statement* get_last_statement() const; const Statement* get_last_statement() const;
@ -108,6 +115,7 @@ class Function : public Castable<Function, Node> {
type::Type* const return_type_; type::Type* const return_type_;
BlockStatement* const body_; BlockStatement* const body_;
DecorationList const decorations_; DecorationList const decorations_;
DecorationList const return_type_decorations_;
}; };
/// A list of functions /// A list of functions

View File

@ -994,16 +994,18 @@ class ProgramBuilder {
/// @param type the function return type /// @param type the function return type
/// @param body the function body /// @param body the function body
/// @param decorations the function decorations /// @param decorations the function decorations
/// @param return_type_decorations the function return type decorations
/// @returns the function pointer /// @returns the function pointer
ast::Function* Func(Source source, ast::Function* Func(Source source,
std::string name, std::string name,
ast::VariableList params, ast::VariableList params,
type::Type* type, type::Type* type,
ast::StatementList body, ast::StatementList body,
ast::DecorationList decorations) { ast::DecorationList decorations,
auto* func = ast::DecorationList return_type_decorations = {}) {
create<ast::Function>(source, Symbols().Register(name), params, type, auto* func = create<ast::Function>(source, Symbols().Register(name), params,
create<ast::BlockStatement>(body), decorations); type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func); AST().AddFunction(func);
return func; return func;
} }
@ -1014,15 +1016,17 @@ class ProgramBuilder {
/// @param type the function return type /// @param type the function return type
/// @param body the function body /// @param body the function body
/// @param decorations the function decorations /// @param decorations the function decorations
/// @param return_type_decorations the function return type decorations
/// @returns the function pointer /// @returns the function pointer
ast::Function* Func(std::string name, ast::Function* Func(std::string name,
ast::VariableList params, ast::VariableList params,
type::Type* type, type::Type* type,
ast::StatementList body, ast::StatementList body,
ast::DecorationList decorations) { ast::DecorationList decorations,
auto* func = ast::DecorationList return_type_decorations = {}) {
create<ast::Function>(Symbols().Register(name), params, type, auto* func = create<ast::Function>(Symbols().Register(name), params, type,
create<ast::BlockStatement>(body), decorations); create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func); AST().AddFunction(func);
return func; return func;
} }

View File

@ -843,10 +843,10 @@ bool FunctionEmitter::Emit() {
auto& statements = statements_stack_[0].GetStatements(); auto& statements = statements_stack_[0].GetStatements();
auto* body = create<ast::BlockStatement>(Source{}, statements); auto* body = create<ast::BlockStatement>(Source{}, statements);
builder_.AST().AddFunction( builder_.AST().AddFunction(create<ast::Function>(
create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name), decl.source, builder_.Symbols().Register(decl.name),
std::move(decl.params), decl.return_type, body, std::move(decl.params), decl.return_type, body,
std::move(decl.decorations))); std::move(decl.decorations), ast::DecorationList{}));
// Maintain the invariant by repopulating the one and only element. // Maintain the invariant by repopulating the one and only element.
statements_stack_.clear(); statements_stack_.clear();

View File

@ -167,8 +167,13 @@ ParserImpl::FunctionHeader::FunctionHeader(const FunctionHeader&) = default;
ParserImpl::FunctionHeader::FunctionHeader(Source src, ParserImpl::FunctionHeader::FunctionHeader(Source src,
std::string n, std::string n,
ast::VariableList p, ast::VariableList p,
type::Type* ret_ty) type::Type* ret_ty,
: source(src), name(n), params(p), return_type(ret_ty) {} ast::DecorationList ret_decos)
: source(src),
name(n),
params(p),
return_type(ret_ty),
return_type_decorations(ret_decos) {}
ParserImpl::FunctionHeader::~FunctionHeader() = default; ParserImpl::FunctionHeader::~FunctionHeader() = default;
@ -1185,7 +1190,7 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
return create<ast::Function>( return create<ast::Function>(
header->source, builder_.Symbols().Register(header->name), header->params, header->source, builder_.Symbols().Register(header->name), header->params,
header->return_type, body.value, decos); header->return_type, body.value, decos, header->return_type_decorations);
} }
// function_type_decl // function_type_decl
@ -1225,6 +1230,11 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
if (!expect(use, Token::Type::kArrow)) if (!expect(use, Token::Type::kArrow))
return Failure::kErrored; return Failure::kErrored;
auto decos = decoration_list();
if (decos.errored) {
return Failure::kErrored;
}
auto type = function_type_decl(); auto type = function_type_decl();
if (type.errored) { if (type.errored) {
errored = true; errored = true;
@ -1235,8 +1245,8 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
if (errored) if (errored)
return Failure::kErrored; return Failure::kErrored;
return FunctionHeader{source, name.value, std::move(params.value), return FunctionHeader{source, name.value, std::move(params.value), type.value,
type.value}; std::move(decos.value)};
} }
// param_list // param_list

View File

@ -218,10 +218,12 @@ class ParserImpl {
/// @param n function name /// @param n function name
/// @param p function parameters /// @param p function parameters
/// @param ret_ty function return type /// @param ret_ty function return type
/// @param ret_decos return type decorations
FunctionHeader(Source src, FunctionHeader(Source src,
std::string n, std::string n,
ast::VariableList p, ast::VariableList p,
type::Type* ret_ty); type::Type* ret_ty,
ast::DecorationList ret_decos);
/// Destructor /// Destructor
~FunctionHeader(); ~FunctionHeader();
/// Assignment operator /// Assignment operator
@ -237,6 +239,8 @@ class ParserImpl {
ast::VariableList params; ast::VariableList params;
/// Function return type /// Function return type
type::Type* return_type; type::Type* return_type;
/// Function return type decorations
ast::DecorationList return_type_decorations;
}; };
/// VarDeclInfo contains the parsed information for variable declaration. /// VarDeclInfo contains the parsed information for variable declaration.

View File

@ -173,6 +173,37 @@ fn main() -> void { return; })");
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>()); EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
} }
TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
auto p = parser("fn main() -> [[location(1)]] f32 { return 1.0; }");
auto decos = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(decos.errored);
EXPECT_FALSE(decos.matched);
auto f = p->function_decl(decos.value);
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(f.errored);
EXPECT_TRUE(f.matched);
ASSERT_NE(f.value, nullptr);
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<type::F32>());
ASSERT_EQ(f->params().size(), 0u);
auto& decorations = f->decorations();
EXPECT_EQ(decorations.size(), 0u);
auto& ret_type_decorations = f->return_type_decorations();
ASSERT_EQ(ret_type_decorations.size(), 1u);
auto* loc = ret_type_decorations[0]->As<ast::LocationDecoration>();
ASSERT_TRUE(loc != nullptr);
EXPECT_EQ(loc->value(), 1u);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) { TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
auto p = parser("fn main() -> { }"); auto p = parser("fn main() -> { }");
auto decos = p->decoration_list(); auto decos = p->decoration_list();

View File

@ -33,6 +33,22 @@ TEST_F(ParserImplTest, FunctionHeader) {
EXPECT_TRUE(f->return_type->Is<type::Void>()); EXPECT_TRUE(f->return_type->Is<type::Void>());
} }
TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) {
auto p = parser("fn main() -> [[location(1)]] f32");
auto f = p->function_header();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(f.matched);
EXPECT_FALSE(f.errored);
EXPECT_EQ(f->name, "main");
EXPECT_EQ(f->params.size(), 0u);
EXPECT_TRUE(f->return_type->Is<type::F32>());
ASSERT_TRUE(f->return_type_decorations.size() == 1u);
auto* loc = f->return_type_decorations[0]->As<ast::LocationDecoration>();
ASSERT_TRUE(loc != nullptr);
EXPECT_EQ(loc->value(), 1u);
}
TEST_F(ParserImplTest, FunctionHeader_MissingIdent) { TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
auto p = parser("fn () -> void"); auto p = parser("fn () -> void");
auto f = p->function_header(); auto f = p->function_header();

View File

@ -229,7 +229,8 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const {
func->source(), ctx.Clone(func->symbol()), new_parameters, func->source(), ctx.Clone(func->symbol()), new_parameters,
ctx.Clone(func->return_type()), ctx.Clone(func->return_type()),
ctx.dst->create<ast::BlockStatement>(new_body), ctx.dst->create<ast::BlockStatement>(new_body),
ctx.Clone(func->decorations())); ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func); ctx.Replace(func, new_func);
} }
} }

View File

@ -391,7 +391,8 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const {
func->source(), ctx.Clone(func->symbol()), new_parameters, func->source(), ctx.Clone(func->symbol()), new_parameters,
ctx.Clone(func->return_type()), ctx.Clone(func->return_type()),
ctx.dst->create<ast::BlockStatement>(new_body), ctx.dst->create<ast::BlockStatement>(new_body),
ctx.Clone(func->decorations())); ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func); ctx.Replace(func, new_func);
} }
} }

View File

@ -138,7 +138,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
auto* new_func = ctx.dst->create<ast::Function>( auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{}, func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
ctx.Clone(func->return_type()), ctx.Clone(func->body()), ctx.Clone(func->return_type()), ctx.Clone(func->body()),
ctx.Clone(func->decorations())); ctx.Clone(func->decorations()),
ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func); ctx.Replace(func, new_func);
} }
} }

View File

@ -58,8 +58,9 @@ ast::Function* Transform::CloneWithStatementsAtStart(
auto* body = ctx->dst->create<ast::BlockStatement>( auto* body = ctx->dst->create<ast::BlockStatement>(
ctx->Clone(in->body()->source()), statements); ctx->Clone(in->body()->source()), statements);
auto decos = ctx->Clone(in->decorations()); auto decos = ctx->Clone(in->decorations());
auto ret_decos = ctx->Clone(in->return_type_decorations());
return ctx->dst->create<ast::Function>(source, symbol, params, return_type, return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
body, decos); body, decos, ret_decos);
} }
void Transform::RenameReservedKeywords(CloneContext* ctx, void Transform::RenameReservedKeywords(CloneContext* ctx,

View File

@ -148,6 +148,41 @@ INSTANTIATE_TEST_SUITE_P(
false}, false},
DecorationTestParams{DecorationKind::kWorkgroup, true})); DecorationTestParams{DecorationKind::kWorkgroup, true}));
using FunctionReturnTypeDecorationTest = ValidatorDecorationsTestWithParams;
TEST_P(FunctionReturnTypeDecorationTest, Decoration_IsValid) {
auto params = GetParam();
Func("main", ast::VariableList{}, ty.f32(),
ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
ast::DecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex)},
ast::DecorationList{createDecoration(*this, params.kind)});
ValidatorImpl& v = Build();
if (params.should_pass) {
EXPECT_TRUE(v.Validate());
} else {
EXPECT_FALSE(v.Validate());
EXPECT_EQ(v.error(), "decoration is not valid for function return types");
}
}
INSTANTIATE_TEST_SUITE_P(
ValidatorTest,
FunctionReturnTypeDecorationTest,
testing::Values(DecorationTestParams{DecorationKind::kAccess, false},
DecorationTestParams{DecorationKind::kBinding, false},
DecorationTestParams{DecorationKind::kBuiltin, true},
DecorationTestParams{DecorationKind::kConstantId, false},
DecorationTestParams{DecorationKind::kGroup, false},
DecorationTestParams{DecorationKind::kLocation, true},
DecorationTestParams{DecorationKind::kStage, false},
DecorationTestParams{DecorationKind::kStride, false},
DecorationTestParams{DecorationKind::kStructBlock, false},
DecorationTestParams{DecorationKind::kStructMemberOffset,
false},
DecorationTestParams{DecorationKind::kWorkgroup, false}));
using StructDecorationTest = ValidatorDecorationsTestWithParams; using StructDecorationTest = ValidatorDecorationsTestWithParams;
TEST_P(StructDecorationTest, Decoration_IsValid) { TEST_P(StructDecorationTest, Decoration_IsValid) {
auto params = GetParam(); auto params = GetParam();

View File

@ -248,6 +248,15 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) {
"non-void function must end with a return statement"); "non-void function must end with a return statement");
return false; return false;
} }
for (auto* deco : current_function_->return_type_decorations()) {
if (!(deco->Is<ast::BuiltinDecoration>() ||
deco->Is<ast::LocationDecoration>())) {
add_error(deco->source(),
"decoration is not valid for function return types");
return false;
}
}
} }
return true; return true;
} }