[wgsl-reader] Allow decorations on function return types
Add a return type decoration list field to ast::Function. Bug: tint:513 Change-Id: I41c1087f21a87731eb48ec7642997da5ae7f2baa Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44601 Commit-Queue: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
20a438a5c3
commit
feecbe0d83
|
@ -28,13 +28,15 @@ Function::Function(const Source& source,
|
||||||
VariableList params,
|
VariableList params,
|
||||||
type::Type* return_type,
|
type::Type* return_type,
|
||||||
BlockStatement* body,
|
BlockStatement* body,
|
||||||
DecorationList decorations)
|
DecorationList decorations,
|
||||||
|
DecorationList return_type_decorations)
|
||||||
: Base(source),
|
: Base(source),
|
||||||
symbol_(symbol),
|
symbol_(symbol),
|
||||||
params_(std::move(params)),
|
params_(std::move(params)),
|
||||||
return_type_(return_type),
|
return_type_(return_type),
|
||||||
body_(body),
|
body_(body),
|
||||||
decorations_(std::move(decorations)) {
|
decorations_(std::move(decorations)),
|
||||||
|
return_type_decorations_(std::move(return_type_decorations)) {
|
||||||
for (auto* param : params_) {
|
for (auto* param : params_) {
|
||||||
TINT_ASSERT(param);
|
TINT_ASSERT(param);
|
||||||
}
|
}
|
||||||
|
@ -77,7 +79,8 @@ Function* Function::Clone(CloneContext* ctx) const {
|
||||||
auto* ret = ctx->Clone(return_type_);
|
auto* ret = ctx->Clone(return_type_);
|
||||||
auto* b = ctx->Clone(body_);
|
auto* b = ctx->Clone(body_);
|
||||||
auto decos = ctx->Clone(decorations_);
|
auto decos = ctx->Clone(decorations_);
|
||||||
return ctx->dst->create<Function>(src, sym, p, ret, b, decos);
|
auto ret_decos = ctx->Clone(return_type_decorations_);
|
||||||
|
return ctx->dst->create<Function>(src, sym, p, ret, b, decos, ret_decos);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Function::to_str(const semantic::Info& sem,
|
void Function::to_str(const semantic::Info& sem,
|
||||||
|
|
|
@ -42,12 +42,14 @@ class Function : public Castable<Function, Node> {
|
||||||
/// @param return_type the return type
|
/// @param return_type the return type
|
||||||
/// @param body the function body
|
/// @param body the function body
|
||||||
/// @param decorations the function decorations
|
/// @param decorations the function decorations
|
||||||
|
/// @param return_type_decorations the return type decorations
|
||||||
Function(const Source& source,
|
Function(const Source& source,
|
||||||
Symbol symbol,
|
Symbol symbol,
|
||||||
VariableList params,
|
VariableList params,
|
||||||
type::Type* return_type,
|
type::Type* return_type,
|
||||||
BlockStatement* body,
|
BlockStatement* body,
|
||||||
DecorationList decorations);
|
DecorationList decorations,
|
||||||
|
DecorationList return_type_decorations);
|
||||||
/// Move constructor
|
/// Move constructor
|
||||||
Function(Function&&);
|
Function(Function&&);
|
||||||
|
|
||||||
|
@ -74,6 +76,11 @@ class Function : public Castable<Function, Node> {
|
||||||
/// @returns the function return type.
|
/// @returns the function return type.
|
||||||
type::Type* return_type() const { return return_type_; }
|
type::Type* return_type() const { return return_type_; }
|
||||||
|
|
||||||
|
/// @returns the decorations attached to the function return type.
|
||||||
|
const DecorationList& return_type_decorations() const {
|
||||||
|
return return_type_decorations_;
|
||||||
|
}
|
||||||
|
|
||||||
/// @returns a pointer to the last statement of the function or nullptr if
|
/// @returns a pointer to the last statement of the function or nullptr if
|
||||||
// function is empty
|
// function is empty
|
||||||
const Statement* get_last_statement() const;
|
const Statement* get_last_statement() const;
|
||||||
|
@ -108,6 +115,7 @@ class Function : public Castable<Function, Node> {
|
||||||
type::Type* const return_type_;
|
type::Type* const return_type_;
|
||||||
BlockStatement* const body_;
|
BlockStatement* const body_;
|
||||||
DecorationList const decorations_;
|
DecorationList const decorations_;
|
||||||
|
DecorationList const return_type_decorations_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A list of functions
|
/// A list of functions
|
||||||
|
|
|
@ -994,16 +994,18 @@ class ProgramBuilder {
|
||||||
/// @param type the function return type
|
/// @param type the function return type
|
||||||
/// @param body the function body
|
/// @param body the function body
|
||||||
/// @param decorations the function decorations
|
/// @param decorations the function decorations
|
||||||
|
/// @param return_type_decorations the function return type decorations
|
||||||
/// @returns the function pointer
|
/// @returns the function pointer
|
||||||
ast::Function* Func(Source source,
|
ast::Function* Func(Source source,
|
||||||
std::string name,
|
std::string name,
|
||||||
ast::VariableList params,
|
ast::VariableList params,
|
||||||
type::Type* type,
|
type::Type* type,
|
||||||
ast::StatementList body,
|
ast::StatementList body,
|
||||||
ast::DecorationList decorations) {
|
ast::DecorationList decorations,
|
||||||
auto* func =
|
ast::DecorationList return_type_decorations = {}) {
|
||||||
create<ast::Function>(source, Symbols().Register(name), params, type,
|
auto* func = create<ast::Function>(source, Symbols().Register(name), params,
|
||||||
create<ast::BlockStatement>(body), decorations);
|
type, create<ast::BlockStatement>(body),
|
||||||
|
decorations, return_type_decorations);
|
||||||
AST().AddFunction(func);
|
AST().AddFunction(func);
|
||||||
return func;
|
return func;
|
||||||
}
|
}
|
||||||
|
@ -1014,15 +1016,17 @@ class ProgramBuilder {
|
||||||
/// @param type the function return type
|
/// @param type the function return type
|
||||||
/// @param body the function body
|
/// @param body the function body
|
||||||
/// @param decorations the function decorations
|
/// @param decorations the function decorations
|
||||||
|
/// @param return_type_decorations the function return type decorations
|
||||||
/// @returns the function pointer
|
/// @returns the function pointer
|
||||||
ast::Function* Func(std::string name,
|
ast::Function* Func(std::string name,
|
||||||
ast::VariableList params,
|
ast::VariableList params,
|
||||||
type::Type* type,
|
type::Type* type,
|
||||||
ast::StatementList body,
|
ast::StatementList body,
|
||||||
ast::DecorationList decorations) {
|
ast::DecorationList decorations,
|
||||||
auto* func =
|
ast::DecorationList return_type_decorations = {}) {
|
||||||
create<ast::Function>(Symbols().Register(name), params, type,
|
auto* func = create<ast::Function>(Symbols().Register(name), params, type,
|
||||||
create<ast::BlockStatement>(body), decorations);
|
create<ast::BlockStatement>(body),
|
||||||
|
decorations, return_type_decorations);
|
||||||
AST().AddFunction(func);
|
AST().AddFunction(func);
|
||||||
return func;
|
return func;
|
||||||
}
|
}
|
||||||
|
|
|
@ -843,10 +843,10 @@ bool FunctionEmitter::Emit() {
|
||||||
|
|
||||||
auto& statements = statements_stack_[0].GetStatements();
|
auto& statements = statements_stack_[0].GetStatements();
|
||||||
auto* body = create<ast::BlockStatement>(Source{}, statements);
|
auto* body = create<ast::BlockStatement>(Source{}, statements);
|
||||||
builder_.AST().AddFunction(
|
builder_.AST().AddFunction(create<ast::Function>(
|
||||||
create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
|
decl.source, builder_.Symbols().Register(decl.name),
|
||||||
std::move(decl.params), decl.return_type, body,
|
std::move(decl.params), decl.return_type, body,
|
||||||
std::move(decl.decorations)));
|
std::move(decl.decorations), ast::DecorationList{}));
|
||||||
|
|
||||||
// Maintain the invariant by repopulating the one and only element.
|
// Maintain the invariant by repopulating the one and only element.
|
||||||
statements_stack_.clear();
|
statements_stack_.clear();
|
||||||
|
|
|
@ -167,8 +167,13 @@ ParserImpl::FunctionHeader::FunctionHeader(const FunctionHeader&) = default;
|
||||||
ParserImpl::FunctionHeader::FunctionHeader(Source src,
|
ParserImpl::FunctionHeader::FunctionHeader(Source src,
|
||||||
std::string n,
|
std::string n,
|
||||||
ast::VariableList p,
|
ast::VariableList p,
|
||||||
type::Type* ret_ty)
|
type::Type* ret_ty,
|
||||||
: source(src), name(n), params(p), return_type(ret_ty) {}
|
ast::DecorationList ret_decos)
|
||||||
|
: source(src),
|
||||||
|
name(n),
|
||||||
|
params(p),
|
||||||
|
return_type(ret_ty),
|
||||||
|
return_type_decorations(ret_decos) {}
|
||||||
|
|
||||||
ParserImpl::FunctionHeader::~FunctionHeader() = default;
|
ParserImpl::FunctionHeader::~FunctionHeader() = default;
|
||||||
|
|
||||||
|
@ -1185,7 +1190,7 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
|
||||||
|
|
||||||
return create<ast::Function>(
|
return create<ast::Function>(
|
||||||
header->source, builder_.Symbols().Register(header->name), header->params,
|
header->source, builder_.Symbols().Register(header->name), header->params,
|
||||||
header->return_type, body.value, decos);
|
header->return_type, body.value, decos, header->return_type_decorations);
|
||||||
}
|
}
|
||||||
|
|
||||||
// function_type_decl
|
// function_type_decl
|
||||||
|
@ -1225,6 +1230,11 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
|
||||||
if (!expect(use, Token::Type::kArrow))
|
if (!expect(use, Token::Type::kArrow))
|
||||||
return Failure::kErrored;
|
return Failure::kErrored;
|
||||||
|
|
||||||
|
auto decos = decoration_list();
|
||||||
|
if (decos.errored) {
|
||||||
|
return Failure::kErrored;
|
||||||
|
}
|
||||||
|
|
||||||
auto type = function_type_decl();
|
auto type = function_type_decl();
|
||||||
if (type.errored) {
|
if (type.errored) {
|
||||||
errored = true;
|
errored = true;
|
||||||
|
@ -1235,8 +1245,8 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
|
||||||
if (errored)
|
if (errored)
|
||||||
return Failure::kErrored;
|
return Failure::kErrored;
|
||||||
|
|
||||||
return FunctionHeader{source, name.value, std::move(params.value),
|
return FunctionHeader{source, name.value, std::move(params.value), type.value,
|
||||||
type.value};
|
std::move(decos.value)};
|
||||||
}
|
}
|
||||||
|
|
||||||
// param_list
|
// param_list
|
||||||
|
|
|
@ -218,10 +218,12 @@ class ParserImpl {
|
||||||
/// @param n function name
|
/// @param n function name
|
||||||
/// @param p function parameters
|
/// @param p function parameters
|
||||||
/// @param ret_ty function return type
|
/// @param ret_ty function return type
|
||||||
|
/// @param ret_decos return type decorations
|
||||||
FunctionHeader(Source src,
|
FunctionHeader(Source src,
|
||||||
std::string n,
|
std::string n,
|
||||||
ast::VariableList p,
|
ast::VariableList p,
|
||||||
type::Type* ret_ty);
|
type::Type* ret_ty,
|
||||||
|
ast::DecorationList ret_decos);
|
||||||
/// Destructor
|
/// Destructor
|
||||||
~FunctionHeader();
|
~FunctionHeader();
|
||||||
/// Assignment operator
|
/// Assignment operator
|
||||||
|
@ -237,6 +239,8 @@ class ParserImpl {
|
||||||
ast::VariableList params;
|
ast::VariableList params;
|
||||||
/// Function return type
|
/// Function return type
|
||||||
type::Type* return_type;
|
type::Type* return_type;
|
||||||
|
/// Function return type decorations
|
||||||
|
ast::DecorationList return_type_decorations;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// VarDeclInfo contains the parsed information for variable declaration.
|
/// VarDeclInfo contains the parsed information for variable declaration.
|
||||||
|
|
|
@ -173,6 +173,37 @@ fn main() -> void { return; })");
|
||||||
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
|
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
|
||||||
|
auto p = parser("fn main() -> [[location(1)]] f32 { return 1.0; }");
|
||||||
|
auto decos = p->decoration_list();
|
||||||
|
EXPECT_FALSE(p->has_error()) << p->error();
|
||||||
|
ASSERT_FALSE(decos.errored);
|
||||||
|
EXPECT_FALSE(decos.matched);
|
||||||
|
auto f = p->function_decl(decos.value);
|
||||||
|
EXPECT_FALSE(p->has_error()) << p->error();
|
||||||
|
EXPECT_FALSE(f.errored);
|
||||||
|
EXPECT_TRUE(f.matched);
|
||||||
|
ASSERT_NE(f.value, nullptr);
|
||||||
|
|
||||||
|
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
|
||||||
|
ASSERT_NE(f->return_type(), nullptr);
|
||||||
|
EXPECT_TRUE(f->return_type()->Is<type::F32>());
|
||||||
|
ASSERT_EQ(f->params().size(), 0u);
|
||||||
|
|
||||||
|
auto& decorations = f->decorations();
|
||||||
|
EXPECT_EQ(decorations.size(), 0u);
|
||||||
|
|
||||||
|
auto& ret_type_decorations = f->return_type_decorations();
|
||||||
|
ASSERT_EQ(ret_type_decorations.size(), 1u);
|
||||||
|
auto* loc = ret_type_decorations[0]->As<ast::LocationDecoration>();
|
||||||
|
ASSERT_TRUE(loc != nullptr);
|
||||||
|
EXPECT_EQ(loc->value(), 1u);
|
||||||
|
|
||||||
|
auto* body = f->body();
|
||||||
|
ASSERT_EQ(body->size(), 1u);
|
||||||
|
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
|
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
|
||||||
auto p = parser("fn main() -> { }");
|
auto p = parser("fn main() -> { }");
|
||||||
auto decos = p->decoration_list();
|
auto decos = p->decoration_list();
|
||||||
|
|
|
@ -33,6 +33,22 @@ TEST_F(ParserImplTest, FunctionHeader) {
|
||||||
EXPECT_TRUE(f->return_type->Is<type::Void>());
|
EXPECT_TRUE(f->return_type->Is<type::Void>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) {
|
||||||
|
auto p = parser("fn main() -> [[location(1)]] f32");
|
||||||
|
auto f = p->function_header();
|
||||||
|
ASSERT_FALSE(p->has_error()) << p->error();
|
||||||
|
EXPECT_TRUE(f.matched);
|
||||||
|
EXPECT_FALSE(f.errored);
|
||||||
|
|
||||||
|
EXPECT_EQ(f->name, "main");
|
||||||
|
EXPECT_EQ(f->params.size(), 0u);
|
||||||
|
EXPECT_TRUE(f->return_type->Is<type::F32>());
|
||||||
|
ASSERT_TRUE(f->return_type_decorations.size() == 1u);
|
||||||
|
auto* loc = f->return_type_decorations[0]->As<ast::LocationDecoration>();
|
||||||
|
ASSERT_TRUE(loc != nullptr);
|
||||||
|
EXPECT_EQ(loc->value(), 1u);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
|
TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
|
||||||
auto p = parser("fn () -> void");
|
auto p = parser("fn () -> void");
|
||||||
auto f = p->function_header();
|
auto f = p->function_header();
|
||||||
|
|
|
@ -229,7 +229,8 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
||||||
func->source(), ctx.Clone(func->symbol()), new_parameters,
|
func->source(), ctx.Clone(func->symbol()), new_parameters,
|
||||||
ctx.Clone(func->return_type()),
|
ctx.Clone(func->return_type()),
|
||||||
ctx.dst->create<ast::BlockStatement>(new_body),
|
ctx.dst->create<ast::BlockStatement>(new_body),
|
||||||
ctx.Clone(func->decorations()));
|
ctx.Clone(func->decorations()),
|
||||||
|
ctx.Clone(func->return_type_decorations()));
|
||||||
ctx.Replace(func, new_func);
|
ctx.Replace(func, new_func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -391,7 +391,8 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
||||||
func->source(), ctx.Clone(func->symbol()), new_parameters,
|
func->source(), ctx.Clone(func->symbol()), new_parameters,
|
||||||
ctx.Clone(func->return_type()),
|
ctx.Clone(func->return_type()),
|
||||||
ctx.dst->create<ast::BlockStatement>(new_body),
|
ctx.dst->create<ast::BlockStatement>(new_body),
|
||||||
ctx.Clone(func->decorations()));
|
ctx.Clone(func->decorations()),
|
||||||
|
ctx.Clone(func->return_type_decorations()));
|
||||||
ctx.Replace(func, new_func);
|
ctx.Replace(func, new_func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,7 +138,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
||||||
auto* new_func = ctx.dst->create<ast::Function>(
|
auto* new_func = ctx.dst->create<ast::Function>(
|
||||||
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
|
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
|
||||||
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
|
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
|
||||||
ctx.Clone(func->decorations()));
|
ctx.Clone(func->decorations()),
|
||||||
|
ctx.Clone(func->return_type_decorations()));
|
||||||
ctx.Replace(func, new_func);
|
ctx.Replace(func, new_func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,8 +58,9 @@ ast::Function* Transform::CloneWithStatementsAtStart(
|
||||||
auto* body = ctx->dst->create<ast::BlockStatement>(
|
auto* body = ctx->dst->create<ast::BlockStatement>(
|
||||||
ctx->Clone(in->body()->source()), statements);
|
ctx->Clone(in->body()->source()), statements);
|
||||||
auto decos = ctx->Clone(in->decorations());
|
auto decos = ctx->Clone(in->decorations());
|
||||||
|
auto ret_decos = ctx->Clone(in->return_type_decorations());
|
||||||
return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
|
return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
|
||||||
body, decos);
|
body, decos, ret_decos);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Transform::RenameReservedKeywords(CloneContext* ctx,
|
void Transform::RenameReservedKeywords(CloneContext* ctx,
|
||||||
|
|
|
@ -148,6 +148,41 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
false},
|
false},
|
||||||
DecorationTestParams{DecorationKind::kWorkgroup, true}));
|
DecorationTestParams{DecorationKind::kWorkgroup, true}));
|
||||||
|
|
||||||
|
using FunctionReturnTypeDecorationTest = ValidatorDecorationsTestWithParams;
|
||||||
|
TEST_P(FunctionReturnTypeDecorationTest, Decoration_IsValid) {
|
||||||
|
auto params = GetParam();
|
||||||
|
|
||||||
|
Func("main", ast::VariableList{}, ty.f32(),
|
||||||
|
ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
|
||||||
|
ast::DecorationList{
|
||||||
|
create<ast::StageDecoration>(ast::PipelineStage::kVertex)},
|
||||||
|
ast::DecorationList{createDecoration(*this, params.kind)});
|
||||||
|
|
||||||
|
ValidatorImpl& v = Build();
|
||||||
|
|
||||||
|
if (params.should_pass) {
|
||||||
|
EXPECT_TRUE(v.Validate());
|
||||||
|
} else {
|
||||||
|
EXPECT_FALSE(v.Validate());
|
||||||
|
EXPECT_EQ(v.error(), "decoration is not valid for function return types");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ValidatorTest,
|
||||||
|
FunctionReturnTypeDecorationTest,
|
||||||
|
testing::Values(DecorationTestParams{DecorationKind::kAccess, false},
|
||||||
|
DecorationTestParams{DecorationKind::kBinding, false},
|
||||||
|
DecorationTestParams{DecorationKind::kBuiltin, true},
|
||||||
|
DecorationTestParams{DecorationKind::kConstantId, false},
|
||||||
|
DecorationTestParams{DecorationKind::kGroup, false},
|
||||||
|
DecorationTestParams{DecorationKind::kLocation, true},
|
||||||
|
DecorationTestParams{DecorationKind::kStage, false},
|
||||||
|
DecorationTestParams{DecorationKind::kStride, false},
|
||||||
|
DecorationTestParams{DecorationKind::kStructBlock, false},
|
||||||
|
DecorationTestParams{DecorationKind::kStructMemberOffset,
|
||||||
|
false},
|
||||||
|
DecorationTestParams{DecorationKind::kWorkgroup, false}));
|
||||||
|
|
||||||
using StructDecorationTest = ValidatorDecorationsTestWithParams;
|
using StructDecorationTest = ValidatorDecorationsTestWithParams;
|
||||||
TEST_P(StructDecorationTest, Decoration_IsValid) {
|
TEST_P(StructDecorationTest, Decoration_IsValid) {
|
||||||
auto params = GetParam();
|
auto params = GetParam();
|
||||||
|
|
|
@ -248,6 +248,15 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) {
|
||||||
"non-void function must end with a return statement");
|
"non-void function must end with a return statement");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto* deco : current_function_->return_type_decorations()) {
|
||||||
|
if (!(deco->Is<ast::BuiltinDecoration>() ||
|
||||||
|
deco->Is<ast::LocationDecoration>())) {
|
||||||
|
add_error(deco->source(),
|
||||||
|
"decoration is not valid for function return types");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue