diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index 0a0998fba8..c8385e535e 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -1392,9 +1392,19 @@ Maybe ParserImpl::function_header() { } return_decorations = decos.value; + // Apply stride decorations to the type node instead of the function. + ast::DecorationList type_decorations; + auto itr = std::find_if( + return_decorations.begin(), return_decorations.end(), + [](auto* deco) { return Is(deco); }); + if (itr != return_decorations.end()) { + type_decorations.emplace_back(*itr); + return_decorations.erase(itr); + } + auto tok = peek(); - auto type = type_decl(); + auto type = type_decl(type_decorations); if (type.errored) { errored = true; } else if (!type.matched) { @@ -3147,23 +3157,6 @@ Maybe ParserImpl::decoration() { return Failure::kNoMatch; } -template -std::vector ParserImpl::take_decorations(ast::DecorationList& in) { - ast::DecorationList remaining; - std::vector out; - out.reserve(in.size()); - for (auto* deco : in) { - if (auto* t = deco->As()) { - out.emplace_back(t); - } else { - remaining.emplace_back(deco); - } - } - - in = std::move(remaining); - return out; -} - bool ParserImpl::expect_decorations_consumed(const ast::DecorationList& in) { if (in.empty()) { return true; diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h index 99da98f693..232d952e72 100644 --- a/src/reader/wgsl/parser_impl.h +++ b/src/reader/wgsl/parser_impl.h @@ -839,11 +839,6 @@ class ParserImpl { template > T without_error(F&& func); - /// Returns all the decorations taken from `list` that matches the type `T`. - /// Those that do not match are kept in `list`. - template - std::vector take_decorations(ast::DecorationList& list); - /// Reports an error if the decoration list `list` is not empty. /// Used to ensure that all decorations are consumed. bool expect_decorations_consumed(const ast::DecorationList& list); diff --git a/src/reader/wgsl/parser_impl_function_header_test.cc b/src/reader/wgsl/parser_impl_function_header_test.cc index 0cef4aaff9..f15699ffba 100644 --- a/src/reader/wgsl/parser_impl_function_header_test.cc +++ b/src/reader/wgsl/parser_impl_function_header_test.cc @@ -61,6 +61,27 @@ TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) { EXPECT_EQ(loc->value(), 1u); } +TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType_WithArrayStride) { + auto p = parser("fn main() -> [[location(1), stride(16)]] array"); + auto f = p->function_header(); + ASSERT_FALSE(p->has_error()) << p->error(); + EXPECT_TRUE(f.matched); + EXPECT_FALSE(f.errored); + + EXPECT_EQ(f->name, "main"); + EXPECT_EQ(f->params.size(), 0u); + ASSERT_EQ(f->return_type_decorations.size(), 1u); + auto* loc = f->return_type_decorations[0]->As(); + ASSERT_TRUE(loc != nullptr); + EXPECT_EQ(loc->value(), 1u); + + auto* array_type = f->return_type->As(); + ASSERT_EQ(array_type->decorations().size(), 1u); + auto* stride = array_type->decorations()[0]->As(); + ASSERT_TRUE(stride != nullptr); + EXPECT_EQ(stride->stride(), 16u); +} + TEST_F(ParserImplTest, FunctionHeader_MissingIdent) { auto p = parser("fn ()"); auto f = p->function_header();