[wgsl-reader] Allow decorations on function return types
Add a return type decoration list field to ast::Function. Bug: tint:513 Change-Id: I41c1087f21a87731eb48ec7642997da5ae7f2baa Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44601 Commit-Queue: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
20a438a5c3
commit
feecbe0d83
|
@ -28,13 +28,15 @@ Function::Function(const Source& source,
|
|||
VariableList params,
|
||||
type::Type* return_type,
|
||||
BlockStatement* body,
|
||||
DecorationList decorations)
|
||||
DecorationList decorations,
|
||||
DecorationList return_type_decorations)
|
||||
: Base(source),
|
||||
symbol_(symbol),
|
||||
params_(std::move(params)),
|
||||
return_type_(return_type),
|
||||
body_(body),
|
||||
decorations_(std::move(decorations)) {
|
||||
decorations_(std::move(decorations)),
|
||||
return_type_decorations_(std::move(return_type_decorations)) {
|
||||
for (auto* param : params_) {
|
||||
TINT_ASSERT(param);
|
||||
}
|
||||
|
@ -77,7 +79,8 @@ Function* Function::Clone(CloneContext* ctx) const {
|
|||
auto* ret = ctx->Clone(return_type_);
|
||||
auto* b = ctx->Clone(body_);
|
||||
auto decos = ctx->Clone(decorations_);
|
||||
return ctx->dst->create<Function>(src, sym, p, ret, b, decos);
|
||||
auto ret_decos = ctx->Clone(return_type_decorations_);
|
||||
return ctx->dst->create<Function>(src, sym, p, ret, b, decos, ret_decos);
|
||||
}
|
||||
|
||||
void Function::to_str(const semantic::Info& sem,
|
||||
|
|
|
@ -42,12 +42,14 @@ class Function : public Castable<Function, Node> {
|
|||
/// @param return_type the return type
|
||||
/// @param body the function body
|
||||
/// @param decorations the function decorations
|
||||
/// @param return_type_decorations the return type decorations
|
||||
Function(const Source& source,
|
||||
Symbol symbol,
|
||||
VariableList params,
|
||||
type::Type* return_type,
|
||||
BlockStatement* body,
|
||||
DecorationList decorations);
|
||||
DecorationList decorations,
|
||||
DecorationList return_type_decorations);
|
||||
/// Move constructor
|
||||
Function(Function&&);
|
||||
|
||||
|
@ -74,6 +76,11 @@ class Function : public Castable<Function, Node> {
|
|||
/// @returns the function return type.
|
||||
type::Type* return_type() const { return return_type_; }
|
||||
|
||||
/// @returns the decorations attached to the function return type.
|
||||
const DecorationList& return_type_decorations() const {
|
||||
return return_type_decorations_;
|
||||
}
|
||||
|
||||
/// @returns a pointer to the last statement of the function or nullptr if
|
||||
// function is empty
|
||||
const Statement* get_last_statement() const;
|
||||
|
@ -108,6 +115,7 @@ class Function : public Castable<Function, Node> {
|
|||
type::Type* const return_type_;
|
||||
BlockStatement* const body_;
|
||||
DecorationList const decorations_;
|
||||
DecorationList const return_type_decorations_;
|
||||
};
|
||||
|
||||
/// A list of functions
|
||||
|
|
|
@ -994,16 +994,18 @@ class ProgramBuilder {
|
|||
/// @param type the function return type
|
||||
/// @param body the function body
|
||||
/// @param decorations the function decorations
|
||||
/// @param return_type_decorations the function return type decorations
|
||||
/// @returns the function pointer
|
||||
ast::Function* Func(Source source,
|
||||
std::string name,
|
||||
ast::VariableList params,
|
||||
type::Type* type,
|
||||
ast::StatementList body,
|
||||
ast::DecorationList decorations) {
|
||||
auto* func =
|
||||
create<ast::Function>(source, Symbols().Register(name), params, type,
|
||||
create<ast::BlockStatement>(body), decorations);
|
||||
ast::DecorationList decorations,
|
||||
ast::DecorationList return_type_decorations = {}) {
|
||||
auto* func = create<ast::Function>(source, Symbols().Register(name), params,
|
||||
type, create<ast::BlockStatement>(body),
|
||||
decorations, return_type_decorations);
|
||||
AST().AddFunction(func);
|
||||
return func;
|
||||
}
|
||||
|
@ -1014,15 +1016,17 @@ class ProgramBuilder {
|
|||
/// @param type the function return type
|
||||
/// @param body the function body
|
||||
/// @param decorations the function decorations
|
||||
/// @param return_type_decorations the function return type decorations
|
||||
/// @returns the function pointer
|
||||
ast::Function* Func(std::string name,
|
||||
ast::VariableList params,
|
||||
type::Type* type,
|
||||
ast::StatementList body,
|
||||
ast::DecorationList decorations) {
|
||||
auto* func =
|
||||
create<ast::Function>(Symbols().Register(name), params, type,
|
||||
create<ast::BlockStatement>(body), decorations);
|
||||
ast::DecorationList decorations,
|
||||
ast::DecorationList return_type_decorations = {}) {
|
||||
auto* func = create<ast::Function>(Symbols().Register(name), params, type,
|
||||
create<ast::BlockStatement>(body),
|
||||
decorations, return_type_decorations);
|
||||
AST().AddFunction(func);
|
||||
return func;
|
||||
}
|
||||
|
|
|
@ -843,10 +843,10 @@ bool FunctionEmitter::Emit() {
|
|||
|
||||
auto& statements = statements_stack_[0].GetStatements();
|
||||
auto* body = create<ast::BlockStatement>(Source{}, statements);
|
||||
builder_.AST().AddFunction(
|
||||
create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
|
||||
std::move(decl.params), decl.return_type, body,
|
||||
std::move(decl.decorations)));
|
||||
builder_.AST().AddFunction(create<ast::Function>(
|
||||
decl.source, builder_.Symbols().Register(decl.name),
|
||||
std::move(decl.params), decl.return_type, body,
|
||||
std::move(decl.decorations), ast::DecorationList{}));
|
||||
|
||||
// Maintain the invariant by repopulating the one and only element.
|
||||
statements_stack_.clear();
|
||||
|
|
|
@ -167,8 +167,13 @@ ParserImpl::FunctionHeader::FunctionHeader(const FunctionHeader&) = default;
|
|||
ParserImpl::FunctionHeader::FunctionHeader(Source src,
|
||||
std::string n,
|
||||
ast::VariableList p,
|
||||
type::Type* ret_ty)
|
||||
: source(src), name(n), params(p), return_type(ret_ty) {}
|
||||
type::Type* ret_ty,
|
||||
ast::DecorationList ret_decos)
|
||||
: source(src),
|
||||
name(n),
|
||||
params(p),
|
||||
return_type(ret_ty),
|
||||
return_type_decorations(ret_decos) {}
|
||||
|
||||
ParserImpl::FunctionHeader::~FunctionHeader() = default;
|
||||
|
||||
|
@ -1185,7 +1190,7 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
|
|||
|
||||
return create<ast::Function>(
|
||||
header->source, builder_.Symbols().Register(header->name), header->params,
|
||||
header->return_type, body.value, decos);
|
||||
header->return_type, body.value, decos, header->return_type_decorations);
|
||||
}
|
||||
|
||||
// function_type_decl
|
||||
|
@ -1225,6 +1230,11 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
|
|||
if (!expect(use, Token::Type::kArrow))
|
||||
return Failure::kErrored;
|
||||
|
||||
auto decos = decoration_list();
|
||||
if (decos.errored) {
|
||||
return Failure::kErrored;
|
||||
}
|
||||
|
||||
auto type = function_type_decl();
|
||||
if (type.errored) {
|
||||
errored = true;
|
||||
|
@ -1235,8 +1245,8 @@ Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
|
|||
if (errored)
|
||||
return Failure::kErrored;
|
||||
|
||||
return FunctionHeader{source, name.value, std::move(params.value),
|
||||
type.value};
|
||||
return FunctionHeader{source, name.value, std::move(params.value), type.value,
|
||||
std::move(decos.value)};
|
||||
}
|
||||
|
||||
// param_list
|
||||
|
|
|
@ -218,10 +218,12 @@ class ParserImpl {
|
|||
/// @param n function name
|
||||
/// @param p function parameters
|
||||
/// @param ret_ty function return type
|
||||
/// @param ret_decos return type decorations
|
||||
FunctionHeader(Source src,
|
||||
std::string n,
|
||||
ast::VariableList p,
|
||||
type::Type* ret_ty);
|
||||
type::Type* ret_ty,
|
||||
ast::DecorationList ret_decos);
|
||||
/// Destructor
|
||||
~FunctionHeader();
|
||||
/// Assignment operator
|
||||
|
@ -237,6 +239,8 @@ class ParserImpl {
|
|||
ast::VariableList params;
|
||||
/// Function return type
|
||||
type::Type* return_type;
|
||||
/// Function return type decorations
|
||||
ast::DecorationList return_type_decorations;
|
||||
};
|
||||
|
||||
/// VarDeclInfo contains the parsed information for variable declaration.
|
||||
|
|
|
@ -173,6 +173,37 @@ fn main() -> void { return; })");
|
|||
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
|
||||
auto p = parser("fn main() -> [[location(1)]] f32 { return 1.0; }");
|
||||
auto decos = p->decoration_list();
|
||||
EXPECT_FALSE(p->has_error()) << p->error();
|
||||
ASSERT_FALSE(decos.errored);
|
||||
EXPECT_FALSE(decos.matched);
|
||||
auto f = p->function_decl(decos.value);
|
||||
EXPECT_FALSE(p->has_error()) << p->error();
|
||||
EXPECT_FALSE(f.errored);
|
||||
EXPECT_TRUE(f.matched);
|
||||
ASSERT_NE(f.value, nullptr);
|
||||
|
||||
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
|
||||
ASSERT_NE(f->return_type(), nullptr);
|
||||
EXPECT_TRUE(f->return_type()->Is<type::F32>());
|
||||
ASSERT_EQ(f->params().size(), 0u);
|
||||
|
||||
auto& decorations = f->decorations();
|
||||
EXPECT_EQ(decorations.size(), 0u);
|
||||
|
||||
auto& ret_type_decorations = f->return_type_decorations();
|
||||
ASSERT_EQ(ret_type_decorations.size(), 1u);
|
||||
auto* loc = ret_type_decorations[0]->As<ast::LocationDecoration>();
|
||||
ASSERT_TRUE(loc != nullptr);
|
||||
EXPECT_EQ(loc->value(), 1u);
|
||||
|
||||
auto* body = f->body();
|
||||
ASSERT_EQ(body->size(), 1u);
|
||||
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
|
||||
auto p = parser("fn main() -> { }");
|
||||
auto decos = p->decoration_list();
|
||||
|
|
|
@ -33,6 +33,22 @@ TEST_F(ParserImplTest, FunctionHeader) {
|
|||
EXPECT_TRUE(f->return_type->Is<type::Void>());
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) {
|
||||
auto p = parser("fn main() -> [[location(1)]] f32");
|
||||
auto f = p->function_header();
|
||||
ASSERT_FALSE(p->has_error()) << p->error();
|
||||
EXPECT_TRUE(f.matched);
|
||||
EXPECT_FALSE(f.errored);
|
||||
|
||||
EXPECT_EQ(f->name, "main");
|
||||
EXPECT_EQ(f->params.size(), 0u);
|
||||
EXPECT_TRUE(f->return_type->Is<type::F32>());
|
||||
ASSERT_TRUE(f->return_type_decorations.size() == 1u);
|
||||
auto* loc = f->return_type_decorations[0]->As<ast::LocationDecoration>();
|
||||
ASSERT_TRUE(loc != nullptr);
|
||||
EXPECT_EQ(loc->value(), 1u);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
|
||||
auto p = parser("fn () -> void");
|
||||
auto f = p->function_header();
|
||||
|
|
|
@ -229,7 +229,8 @@ void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
func->source(), ctx.Clone(func->symbol()), new_parameters,
|
||||
ctx.Clone(func->return_type()),
|
||||
ctx.dst->create<ast::BlockStatement>(new_body),
|
||||
ctx.Clone(func->decorations()));
|
||||
ctx.Clone(func->decorations()),
|
||||
ctx.Clone(func->return_type_decorations()));
|
||||
ctx.Replace(func, new_func);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -391,7 +391,8 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
func->source(), ctx.Clone(func->symbol()), new_parameters,
|
||||
ctx.Clone(func->return_type()),
|
||||
ctx.dst->create<ast::BlockStatement>(new_body),
|
||||
ctx.Clone(func->decorations()));
|
||||
ctx.Clone(func->decorations()),
|
||||
ctx.Clone(func->return_type_decorations()));
|
||||
ctx.Replace(func, new_func);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,7 +138,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
auto* new_func = ctx.dst->create<ast::Function>(
|
||||
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
|
||||
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
|
||||
ctx.Clone(func->decorations()));
|
||||
ctx.Clone(func->decorations()),
|
||||
ctx.Clone(func->return_type_decorations()));
|
||||
ctx.Replace(func, new_func);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,8 +58,9 @@ ast::Function* Transform::CloneWithStatementsAtStart(
|
|||
auto* body = ctx->dst->create<ast::BlockStatement>(
|
||||
ctx->Clone(in->body()->source()), statements);
|
||||
auto decos = ctx->Clone(in->decorations());
|
||||
auto ret_decos = ctx->Clone(in->return_type_decorations());
|
||||
return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
|
||||
body, decos);
|
||||
body, decos, ret_decos);
|
||||
}
|
||||
|
||||
void Transform::RenameReservedKeywords(CloneContext* ctx,
|
||||
|
|
|
@ -148,6 +148,41 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
false},
|
||||
DecorationTestParams{DecorationKind::kWorkgroup, true}));
|
||||
|
||||
using FunctionReturnTypeDecorationTest = ValidatorDecorationsTestWithParams;
|
||||
TEST_P(FunctionReturnTypeDecorationTest, Decoration_IsValid) {
|
||||
auto params = GetParam();
|
||||
|
||||
Func("main", ast::VariableList{}, ty.f32(),
|
||||
ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
|
||||
ast::DecorationList{
|
||||
create<ast::StageDecoration>(ast::PipelineStage::kVertex)},
|
||||
ast::DecorationList{createDecoration(*this, params.kind)});
|
||||
|
||||
ValidatorImpl& v = Build();
|
||||
|
||||
if (params.should_pass) {
|
||||
EXPECT_TRUE(v.Validate());
|
||||
} else {
|
||||
EXPECT_FALSE(v.Validate());
|
||||
EXPECT_EQ(v.error(), "decoration is not valid for function return types");
|
||||
}
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ValidatorTest,
|
||||
FunctionReturnTypeDecorationTest,
|
||||
testing::Values(DecorationTestParams{DecorationKind::kAccess, false},
|
||||
DecorationTestParams{DecorationKind::kBinding, false},
|
||||
DecorationTestParams{DecorationKind::kBuiltin, true},
|
||||
DecorationTestParams{DecorationKind::kConstantId, false},
|
||||
DecorationTestParams{DecorationKind::kGroup, false},
|
||||
DecorationTestParams{DecorationKind::kLocation, true},
|
||||
DecorationTestParams{DecorationKind::kStage, false},
|
||||
DecorationTestParams{DecorationKind::kStride, false},
|
||||
DecorationTestParams{DecorationKind::kStructBlock, false},
|
||||
DecorationTestParams{DecorationKind::kStructMemberOffset,
|
||||
false},
|
||||
DecorationTestParams{DecorationKind::kWorkgroup, false}));
|
||||
|
||||
using StructDecorationTest = ValidatorDecorationsTestWithParams;
|
||||
TEST_P(StructDecorationTest, Decoration_IsValid) {
|
||||
auto params = GetParam();
|
||||
|
|
|
@ -248,6 +248,15 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) {
|
|||
"non-void function must end with a return statement");
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* deco : current_function_->return_type_decorations()) {
|
||||
if (!(deco->Is<ast::BuiltinDecoration>() ||
|
||||
deco->Is<ast::LocationDecoration>())) {
|
||||
add_error(deco->source(),
|
||||
"decoration is not valid for function return types");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue