From 451f2cc68adcbfeb0a6870ec8b4187aae55fbfb8 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 12 May 2021 12:54:21 +0000 Subject: [PATCH] Add ast::DisableValidationDecoration An [[internal]] decoration that specifically disables certain validation checks. Begin with a single kFunctionHasNoBody mode. Migrate the Resolver to using this instead of allowing any InternalDecoration to disable the checks for no-body. Bug: tint:797 Change-Id: I213b9a6844a456775ede06d60e456d9f77a449d0 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50741 Auto-Submit: Ben Clayton Reviewed-by: James Price Commit-Queue: Ben Clayton --- src/BUILD.gn | 2 + src/CMakeLists.txt | 2 + src/ast/disable_validation_decoration.cc | 46 ++++++++ src/ast/disable_validation_decoration.h | 67 ++++++++++++ src/ast/internal_decoration.h | 1 - src/resolver/resolver.cc | 23 +++- src/transform/decompose_storage_access.cc | 15 ++- .../decompose_storage_access_test.cc | 102 +++++++++--------- src/writer/wgsl/generator_impl.cc | 33 +++--- .../wgsl/generator_impl_function_test.cc | 3 +- 10 files changed, 217 insertions(+), 77 deletions(-) create mode 100644 src/ast/disable_validation_decoration.cc create mode 100644 src/ast/disable_validation_decoration.h diff --git a/src/BUILD.gn b/src/BUILD.gn index 1c69ee82d7..844a7038f3 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -298,6 +298,8 @@ libtint_source_set("libtint_core_all_src") { "ast/continue_statement.h", "ast/decoration.cc", "ast/decoration.h", + "ast/disable_validation_decoration.cc", + "ast/disable_validation_decoration.h", "ast/depth_texture.cc", "ast/depth_texture.h", "ast/discard_statement.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7024c3282f..2bd1b09345 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -80,6 +80,8 @@ set(TINT_LIB_SRCS ast/continue_statement.h ast/decoration.cc ast/decoration.h + ast/disable_validation_decoration.cc + ast/disable_validation_decoration.h ast/depth_texture.cc ast/depth_texture.h ast/discard_statement.cc diff --git a/src/ast/disable_validation_decoration.cc b/src/ast/disable_validation_decoration.cc new file mode 100644 index 0000000000..829172ed8a --- /dev/null +++ b/src/ast/disable_validation_decoration.cc @@ -0,0 +1,46 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/disable_validation_decoration.h" +#include "src/clone_context.h" +#include "src/program_builder.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ast::DisableValidationDecoration); + +namespace tint { +namespace ast { + +DisableValidationDecoration::DisableValidationDecoration( + ProgramID program_id, + DisabledValidation validation) + : Base(program_id), validation_(validation) {} + +DisableValidationDecoration::~DisableValidationDecoration() = default; + +std::string DisableValidationDecoration::Name() const { + switch (validation_) { + case DisabledValidation::kFunctionHasNoBody: + return "disable_validation__function_has_no_body"; + } + return ""; +} + +DisableValidationDecoration* DisableValidationDecoration::Clone( + CloneContext* ctx) const { + return ctx->dst->ASTNodes().Create( + ctx->dst->ID(), validation_); +} + +} // namespace ast +} // namespace tint diff --git a/src/ast/disable_validation_decoration.h b/src/ast/disable_validation_decoration.h new file mode 100644 index 0000000000..ec4821be31 --- /dev/null +++ b/src/ast/disable_validation_decoration.h @@ -0,0 +1,67 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_AST_DISABLE_VALIDATION_DECORATION_H_ +#define SRC_AST_DISABLE_VALIDATION_DECORATION_H_ + +#include + +#include "src/ast/internal_decoration.h" + +namespace tint { +namespace ast { + +/// Enumerator of validation features that can be disabled with a +/// DisableValidationDecoration decoration. +enum class DisabledValidation { + /// When applied to a function, the validator will not complain there is no + /// body to a function. + kFunctionHasNoBody, +}; + +/// An internal decoration used to tell the validator to ignore specific +/// violations. Typically generated by transforms that need to produce ASTs that +/// would otherwise cause validation errors. +class DisableValidationDecoration + : public Castable { + public: + /// Constructor + /// @param program_id the identifier of the program that owns this node + /// @param validation the validation to disable + explicit DisableValidationDecoration(ProgramID program_id, + DisabledValidation validation); + + /// Destructor + ~DisableValidationDecoration() override; + + /// @return the validation that this decoration disables + DisabledValidation Validation() const { return validation_; } + + /// @return a short description of the internal decoration which will be + /// displayed in WGSL as `[[internal()]]` (but is not parsable). + std::string Name() const override; + + /// Performs a deep clone of this object using the CloneContext `ctx`. + /// @param ctx the clone context + /// @return the newly cloned object + DisableValidationDecoration* Clone(CloneContext* ctx) const override; + + private: + DisabledValidation const validation_; +}; + +} // namespace ast +} // namespace tint + +#endif // SRC_AST_DISABLE_VALIDATION_DECORATION_H_ diff --git a/src/ast/internal_decoration.h b/src/ast/internal_decoration.h index 1164f7f0f7..94ef1aebe0 100644 --- a/src/ast/internal_decoration.h +++ b/src/ast/internal_decoration.h @@ -25,7 +25,6 @@ namespace ast { /// A decoration used to indicate that a function is tint-internal. /// These decorations are not produced by generators, but instead are usually /// created by transforms for consumption by a particular backend. -/// Functions annotated with this decoration will have relaxed validation. class InternalDecoration : public Castable { public: /// Constructor diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 8095beaa8b..cede2c521d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -26,6 +26,7 @@ #include "src/ast/call_statement.h" #include "src/ast/continue_statement.h" #include "src/ast/depth_texture.h" +#include "src/ast/disable_validation_decoration.h" #include "src/ast/discard_statement.h" #include "src/ast/fallthrough_statement.h" #include "src/ast/if_statement.h" @@ -121,6 +122,21 @@ bool IsValidStorageTextureImageFormat(ast::ImageFormat format) { } } +/// @returns true if the decoration list contains a +/// ast::DisableValidationDecoration with the validation mode equal to +/// `validation` +bool IsValidationDisabled(const ast::DecorationList& decorations, + ast::DisabledValidation validation) { + for (auto* decoration : decorations) { + if (auto* dv = decoration->As()) { + if (dv->Validation() == validation) { + return true; + } + } + } + return false; +} + } // namespace Resolver::Resolver(ProgramBuilder* builder) @@ -749,11 +765,12 @@ bool Resolver::ValidateFunction(const ast::Function* func, func->source()); return false; } - } else if (!ast::HasDecoration( - func->decorations())) { + } else if (!IsValidationDisabled( + func->decorations(), + ast::DisabledValidation::kFunctionHasNoBody)) { TINT_ICE(diagnostics_) << "Function " << builder_->Symbols().NameFor(func->symbol()) - << " has no body and does not have the [[internal]] decoration"; + << " has no body"; } for (auto* deco : func->return_type_decorations()) { diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc index 2f9f07a473..75d98f51c0 100644 --- a/src/transform/decompose_storage_access.cc +++ b/src/transform/decompose_storage_access.cc @@ -22,6 +22,7 @@ #include "src/ast/assignment_statement.h" #include "src/ast/call_statement.h" +#include "src/ast/disable_validation_decoration.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type_name.h" #include "src/program_builder.h" @@ -435,7 +436,12 @@ struct DecomposeStorageAccess::State { auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty); func = ctx.dst->create( ctx.dst->Sym(), params, el_ast_ty, nullptr, - ast::DecorationList{intrinsic}, ast::DecorationList{}); + ast::DecorationList{ + intrinsic, + ctx.dst->ASTNodes().Create( + ctx.dst->ID(), ast::DisabledValidation::kFunctionHasNoBody), + }, + ast::DecorationList{}); } else { ast::ExpressionList values; if (auto* mat_ty = el_ty->As()) { @@ -502,7 +508,12 @@ struct DecomposeStorageAccess::State { if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, el_ty)) { func = ctx.dst->create( ctx.dst->Sym(), params, ctx.dst->ty.void_(), nullptr, - ast::DecorationList{intrinsic}, ast::DecorationList{}); + ast::DecorationList{ + intrinsic, + ctx.dst->ASTNodes().Create( + ctx.dst->ID(), ast::DisabledValidation::kFunctionHasNoBody), + }, + ast::DecorationList{}); } else { ast::StatementList body; diff --git a/src/transform/decompose_storage_access_test.cc b/src/transform/decompose_storage_access_test.cc index 9de479c079..62fa75c7ce 100644 --- a/src/transform/decompose_storage_access_test.cc +++ b/src/transform/decompose_storage_access_test.cc @@ -106,40 +106,40 @@ struct SB { v : array, 2>; }; -[[internal(intrinsic_load_i32)]] +[[internal(intrinsic_load_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32) -> i32 -[[internal(intrinsic_load_u32)]] +[[internal(intrinsic_load_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_1(buffer : [[access(read_write)]] SB, offset : u32) -> u32 -[[internal(intrinsic_load_f32)]] +[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_2(buffer : [[access(read_write)]] SB, offset : u32) -> f32 -[[internal(intrinsic_load_vec2_i32)]] +[[internal(intrinsic_load_vec2_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_3(buffer : [[access(read_write)]] SB, offset : u32) -> vec2 -[[internal(intrinsic_load_vec2_u32)]] +[[internal(intrinsic_load_vec2_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_4(buffer : [[access(read_write)]] SB, offset : u32) -> vec2 -[[internal(intrinsic_load_vec2_f32)]] +[[internal(intrinsic_load_vec2_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_5(buffer : [[access(read_write)]] SB, offset : u32) -> vec2 -[[internal(intrinsic_load_vec3_i32)]] +[[internal(intrinsic_load_vec3_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_6(buffer : [[access(read_write)]] SB, offset : u32) -> vec3 -[[internal(intrinsic_load_vec3_u32)]] +[[internal(intrinsic_load_vec3_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_7(buffer : [[access(read_write)]] SB, offset : u32) -> vec3 -[[internal(intrinsic_load_vec3_f32)]] +[[internal(intrinsic_load_vec3_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_8(buffer : [[access(read_write)]] SB, offset : u32) -> vec3 -[[internal(intrinsic_load_vec4_i32)]] +[[internal(intrinsic_load_vec4_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_9(buffer : [[access(read_write)]] SB, offset : u32) -> vec4 -[[internal(intrinsic_load_vec4_u32)]] +[[internal(intrinsic_load_vec4_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_10(buffer : [[access(read_write)]] SB, offset : u32) -> vec4 -[[internal(intrinsic_load_vec4_f32)]] +[[internal(intrinsic_load_vec4_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_11(buffer : [[access(read_write)]] SB, offset : u32) -> vec4 fn tint_symbol_12(buffer : [[access(read_write)]] SB, offset : u32) -> mat2x2 { @@ -300,40 +300,40 @@ struct SB { v : array, 2>; }; -[[internal(intrinsic_store_i32)]] +[[internal(intrinsic_store_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32, value : i32) -[[internal(intrinsic_store_u32)]] +[[internal(intrinsic_store_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_1(buffer : [[access(read_write)]] SB, offset : u32, value : u32) -[[internal(intrinsic_store_f32)]] +[[internal(intrinsic_store_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_2(buffer : [[access(read_write)]] SB, offset : u32, value : f32) -[[internal(intrinsic_store_vec2_u32)]] +[[internal(intrinsic_store_vec2_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_3(buffer : [[access(read_write)]] SB, offset : u32, value : vec2) -[[internal(intrinsic_store_vec2_f32)]] +[[internal(intrinsic_store_vec2_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_4(buffer : [[access(read_write)]] SB, offset : u32, value : vec2) -[[internal(intrinsic_store_vec2_i32)]] +[[internal(intrinsic_store_vec2_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_5(buffer : [[access(read_write)]] SB, offset : u32, value : vec2) -[[internal(intrinsic_store_vec3_u32)]] +[[internal(intrinsic_store_vec3_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_6(buffer : [[access(read_write)]] SB, offset : u32, value : vec3) -[[internal(intrinsic_store_vec3_f32)]] +[[internal(intrinsic_store_vec3_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_7(buffer : [[access(read_write)]] SB, offset : u32, value : vec3) -[[internal(intrinsic_store_vec3_i32)]] +[[internal(intrinsic_store_vec3_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_8(buffer : [[access(read_write)]] SB, offset : u32, value : vec3) -[[internal(intrinsic_store_vec4_u32)]] +[[internal(intrinsic_store_vec4_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_9(buffer : [[access(read_write)]] SB, offset : u32, value : vec4) -[[internal(intrinsic_store_vec4_f32)]] +[[internal(intrinsic_store_vec4_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_10(buffer : [[access(read_write)]] SB, offset : u32, value : vec4) -[[internal(intrinsic_store_vec4_i32)]] +[[internal(intrinsic_store_vec4_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_11(buffer : [[access(read_write)]] SB, offset : u32, value : vec4) fn tint_symbol_12(buffer : [[access(read_write)]] SB, offset : u32, value : mat2x2) { @@ -492,40 +492,40 @@ struct SB { v : array, 2>; }; -[[internal(intrinsic_load_i32)]] +[[internal(intrinsic_load_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32) -> i32 -[[internal(intrinsic_load_u32)]] +[[internal(intrinsic_load_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_1(buffer : [[access(read_write)]] SB, offset : u32) -> u32 -[[internal(intrinsic_load_f32)]] +[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_2(buffer : [[access(read_write)]] SB, offset : u32) -> f32 -[[internal(intrinsic_load_vec2_i32)]] +[[internal(intrinsic_load_vec2_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_3(buffer : [[access(read_write)]] SB, offset : u32) -> vec2 -[[internal(intrinsic_load_vec2_u32)]] +[[internal(intrinsic_load_vec2_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_4(buffer : [[access(read_write)]] SB, offset : u32) -> vec2 -[[internal(intrinsic_load_vec2_f32)]] +[[internal(intrinsic_load_vec2_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_5(buffer : [[access(read_write)]] SB, offset : u32) -> vec2 -[[internal(intrinsic_load_vec3_i32)]] +[[internal(intrinsic_load_vec3_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_6(buffer : [[access(read_write)]] SB, offset : u32) -> vec3 -[[internal(intrinsic_load_vec3_u32)]] +[[internal(intrinsic_load_vec3_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_7(buffer : [[access(read_write)]] SB, offset : u32) -> vec3 -[[internal(intrinsic_load_vec3_f32)]] +[[internal(intrinsic_load_vec3_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_8(buffer : [[access(read_write)]] SB, offset : u32) -> vec3 -[[internal(intrinsic_load_vec4_i32)]] +[[internal(intrinsic_load_vec4_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_9(buffer : [[access(read_write)]] SB, offset : u32) -> vec4 -[[internal(intrinsic_load_vec4_u32)]] +[[internal(intrinsic_load_vec4_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_10(buffer : [[access(read_write)]] SB, offset : u32) -> vec4 -[[internal(intrinsic_load_vec4_f32)]] +[[internal(intrinsic_load_vec4_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_11(buffer : [[access(read_write)]] SB, offset : u32) -> vec4 fn tint_symbol_12(buffer : [[access(read_write)]] SB, offset : u32) -> mat2x2 { @@ -648,40 +648,40 @@ struct SB { v : array, 2>; }; -[[internal(intrinsic_store_i32)]] +[[internal(intrinsic_store_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32, value : i32) -[[internal(intrinsic_store_u32)]] +[[internal(intrinsic_store_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_1(buffer : [[access(read_write)]] SB, offset : u32, value : u32) -[[internal(intrinsic_store_f32)]] +[[internal(intrinsic_store_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_2(buffer : [[access(read_write)]] SB, offset : u32, value : f32) -[[internal(intrinsic_store_vec2_u32)]] +[[internal(intrinsic_store_vec2_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_3(buffer : [[access(read_write)]] SB, offset : u32, value : vec2) -[[internal(intrinsic_store_vec2_f32)]] +[[internal(intrinsic_store_vec2_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_4(buffer : [[access(read_write)]] SB, offset : u32, value : vec2) -[[internal(intrinsic_store_vec2_i32)]] +[[internal(intrinsic_store_vec2_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_5(buffer : [[access(read_write)]] SB, offset : u32, value : vec2) -[[internal(intrinsic_store_vec3_u32)]] +[[internal(intrinsic_store_vec3_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_6(buffer : [[access(read_write)]] SB, offset : u32, value : vec3) -[[internal(intrinsic_store_vec3_f32)]] +[[internal(intrinsic_store_vec3_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_7(buffer : [[access(read_write)]] SB, offset : u32, value : vec3) -[[internal(intrinsic_store_vec3_i32)]] +[[internal(intrinsic_store_vec3_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_8(buffer : [[access(read_write)]] SB, offset : u32, value : vec3) -[[internal(intrinsic_store_vec4_u32)]] +[[internal(intrinsic_store_vec4_u32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_9(buffer : [[access(read_write)]] SB, offset : u32, value : vec4) -[[internal(intrinsic_store_vec4_f32)]] +[[internal(intrinsic_store_vec4_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_10(buffer : [[access(read_write)]] SB, offset : u32, value : vec4) -[[internal(intrinsic_store_vec4_i32)]] +[[internal(intrinsic_store_vec4_i32), internal(disable_validation__function_has_no_body)]] fn tint_symbol_11(buffer : [[access(read_write)]] SB, offset : u32, value : vec4) fn tint_symbol_12(buffer : [[access(read_write)]] SB, offset : u32, value : mat2x2) { @@ -837,7 +837,7 @@ struct SB { b : [[stride(256)]] array; }; -[[internal(intrinsic_load_f32)]] +[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32) -> f32 [[group(0), binding(0)]] var sb : [[access(read_write)]] SB; @@ -905,7 +905,7 @@ struct SB { b : [[stride(256)]] array; }; -[[internal(intrinsic_load_f32)]] +[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32) -> f32 [[group(0), binding(0)]] var sb : [[access(read_write)]] SB; @@ -992,7 +992,7 @@ struct SB { b : A2_Array; }; -[[internal(intrinsic_load_f32)]] +[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]] fn tint_symbol(buffer : [[access(read_write)]] SB, offset : u32) -> f32 [[group(0), binding(0)]] var sb : [[access(read_write)]] SB; diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index b337f95f40..ac0146ed5b 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -289,24 +289,10 @@ bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) { } bool GeneratorImpl::EmitFunction(ast::Function* func) { - for (auto* deco : func->decorations()) { + if (func->decorations().size()) { make_indent(); - out_ << "[["; - if (auto* workgroup = deco->As()) { - uint32_t x = 0; - uint32_t y = 0; - uint32_t z = 0; - std::tie(x, y, z) = workgroup->values(); - out_ << "workgroup_size(" << std::to_string(x) << ", " - << std::to_string(y) << ", " << std::to_string(z) << ")"; - } - if (auto* stage = deco->As()) { - out_ << "stage(" << stage->value() << ")"; - } - if (auto* internal = deco->As()) { - out_ << "internal(" << internal->Name() << ")"; - } - out_ << "]]" << std::endl; + EmitDecorations(func->decorations()); + out_ << std::endl; } make_indent(); @@ -629,7 +615,16 @@ bool GeneratorImpl::EmitDecorations(const ast::DecorationList& decos) { } first = false; - if (auto* binding = deco->As()) { + if (auto* workgroup = deco->As()) { + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = workgroup->values(); + out_ << "workgroup_size(" << std::to_string(x) << ", " + << std::to_string(y) << ", " << std::to_string(z) << ")"; + } else if (auto* stage = deco->As()) { + out_ << "stage(" << stage->value() << ")"; + } else if (auto* binding = deco->As()) { out_ << "binding(" << binding->value() << ")"; } else if (auto* group = deco->As()) { out_ << "group(" << group->value() << ")"; @@ -643,6 +638,8 @@ bool GeneratorImpl::EmitDecorations(const ast::DecorationList& decos) { out_ << "size(" << size->size() << ")"; } else if (auto* align = deco->As()) { out_ << "align(" << align->align() << ")"; + } else if (auto* internal = deco->As()) { + out_ << "internal(" << internal->Name() << ")"; } else { TINT_ICE(diagnostics_) << "Unsupported decoration '" << deco->TypeInfo().name << "'"; diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index bf831b0495..5630bb9b72 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -130,8 +130,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { gen.increment_indent(); ASSERT_TRUE(gen.EmitFunction(func)); - EXPECT_EQ(gen.result(), R"( [[stage(fragment)]] - [[workgroup_size(2, 4, 6)]] + EXPECT_EQ(gen.result(), R"( [[stage(fragment), workgroup_size(2, 4, 6)]] fn my_func() { discard; return;