From d1f0a14563f114af073165cb7f89e391db80de09 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Tue, 23 Nov 2021 21:46:48 +0000 Subject: [PATCH] resolver: Track global uses in function decorations Fixed: tint:1320 Change-Id: Ib92c37d4de0641d11e508be4d8e05d641e808be9 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/70662 Reviewed-by: James Price Reviewed-by: David Neto Kokoro: Kokoro Commit-Queue: Ben Clayton --- docs/origin-trial-changes.md | 4 ++++ src/resolver/function_validation_test.cc | 23 +++++++++++++++++------ src/resolver/resolver.cc | 20 ++++++++++---------- src/resolver/resolver.h | 5 +++-- src/sem/function.cc | 6 +++--- src/sem/function.h | 12 ++++++++---- src/transform/single_entry_point_test.cc | 20 ++++++++++++++++++++ 7 files changed, 65 insertions(+), 25 deletions(-) diff --git a/docs/origin-trial-changes.md b/docs/origin-trial-changes.md index 3437a84c7b..d34c2bd185 100644 --- a/docs/origin-trial-changes.md +++ b/docs/origin-trial-changes.md @@ -16,6 +16,10 @@ * The `dot()` builtin now supports integer vector types. * Identifiers can now start with a single leading underscore. [tint:1292](https://crbug.com/tint/1292) +### Fixes + +* Fixed an issue where using a module-scoped `let` in a `workgroup_size` may result in a compilation error. [tint:1320](https://crbug.com/tint/1320) + ## Changes for M97 ### Breaking Changes diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index af3b0975b1..302ad88608 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -427,15 +427,26 @@ TEST_F(ResolverFunctionValidationTest, FunctionParamsConst) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) { // let x = 4u; // let x = 8u; - // [[stage(compute), workgroup_size(x, y, 16u] + // [[stage(compute), workgroup_size(x, y, 16u)]] // fn main() {} - GlobalConst("x", ty.u32(), Expr(4u)); - GlobalConst("y", ty.u32(), Expr(8u)); - Func("main", {}, ty.void_(), {}, - {Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Expr("x"), Expr("y"), Expr(16u))}); + auto* x = GlobalConst("x", ty.u32(), Expr(4u)); + auto* y = GlobalConst("y", ty.u32(), Expr(8u)); + auto* func = Func("main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Expr("x"), Expr("y"), Expr(16u))}); ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem_func = Sem().Get(func); + auto* sem_x = Sem().Get(x); + auto* sem_y = Sem().Get(y); + + ASSERT_NE(sem_func, nullptr); + ASSERT_NE(sem_x, nullptr); + ASSERT_NE(sem_y, nullptr); + + EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_x)); + EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y)); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) { diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 0e0abc9e7f..3eb96a3a24 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -634,21 +634,19 @@ sem::Function* Resolver::Function(const ast::Function* decl) { } } - sem::WorkgroupSize ws{}; - if (!WorkgroupSizeFor(decl, ws)) { + auto* func = builder_->create(decl, return_type, parameters); + builder_->Sem().Add(decl, func); + + TINT_SCOPED_ASSIGNMENT(current_function_, func); + + if (!WorkgroupSize(decl)) { return nullptr; } - auto* func = - builder_->create(decl, return_type, parameters, ws); - builder_->Sem().Add(decl, func); - if (decl->IsEntryPoint()) { entry_points_.emplace_back(func); } - TINT_SCOPED_ASSIGNMENT(current_function_, func); - if (decl->body) { Mark(decl->body); if (current_compound_statement_) { @@ -692,9 +690,9 @@ sem::Function* Resolver::Function(const ast::Function* decl) { return func; } -bool Resolver::WorkgroupSizeFor(const ast::Function* func, - sem::WorkgroupSize& ws) { +bool Resolver::WorkgroupSize(const ast::Function* func) { // Set work-group size defaults. + sem::WorkgroupSize ws; for (int i = 0; i < 3; i++) { ws[i].value = 1; ws[i].overridable_const = nullptr; @@ -790,6 +788,8 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func, ws[i].value = is_i32 ? static_cast(value.Elements()[0].i32) : value.Elements()[0].u32; } + + current_function_->SetWorkgroupSize(std::move(ws)); return true; } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index cca4440250..37dea95ae3 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -282,8 +282,9 @@ class Resolver { bool IsValidationEnabled(const ast::DecorationList& decorations, ast::DisabledValidation validation) const; - /// Resolves the WorkgroupSize for the given function - bool WorkgroupSizeFor(const ast::Function*, sem::WorkgroupSize& ws); + /// Resolves the WorkgroupSize for the given function, assigning it to + /// current_function_ + bool WorkgroupSize(const ast::Function*); /// @returns the sem::Type for the ast::Type `ty`, building it if it /// hasn't been constructed already. If an error is raised, nullptr is diff --git a/src/sem/function.cc b/src/sem/function.cc index 709dd03dfc..4e95ff6180 100644 --- a/src/sem/function.cc +++ b/src/sem/function.cc @@ -30,11 +30,11 @@ namespace sem { Function::Function(const ast::Function* declaration, Type* return_type, - std::vector parameters, - sem::WorkgroupSize workgroup_size) + std::vector parameters) : Base(return_type, utils::ToConstPtrVec(parameters)), declaration_(declaration), - workgroup_size_(std::move(workgroup_size)) { + workgroup_size_{WorkgroupDimension{1}, WorkgroupDimension{1}, + WorkgroupDimension{1}} { for (auto* parameter : parameters) { parameter->SetOwner(this); } diff --git a/src/sem/function.h b/src/sem/function.h index d8a58a890c..ea834a7c54 100644 --- a/src/sem/function.h +++ b/src/sem/function.h @@ -62,11 +62,9 @@ class Function : public Castable { /// @param declaration the ast::Function /// @param return_type the return type of the function /// @param parameters the parameters to the function - /// @param workgroup_size the workgroup size Function(const ast::Function* declaration, Type* return_type, - std::vector parameters, - sem::WorkgroupSize workgroup_size); + std::vector parameters); /// Destructor ~Function() override; @@ -77,6 +75,12 @@ class Function : public Castable { /// @returns the workgroup size {x, y, z} for the function. const sem::WorkgroupSize& WorkgroupSize() const { return workgroup_size_; } + /// Sets the workgroup size {x, y, z} for the function. + /// @param workgroup_size the new workgroup size of the function + void SetWorkgroupSize(sem::WorkgroupSize workgroup_size) { + workgroup_size_ = std::move(workgroup_size); + } + /// @returns all directly referenced global variables const utils::UniqueVector& DirectlyReferencedGlobals() const { @@ -243,8 +247,8 @@ class Function : public Castable { bool multisampled) const; const ast::Function* const declaration_; - const sem::WorkgroupSize workgroup_size_; + sem::WorkgroupSize workgroup_size_; utils::UniqueVector directly_referenced_globals_; utils::UniqueVector transitively_referenced_globals_; utils::UniqueVector transitively_called_functions_; diff --git a/src/transform/single_entry_point_test.cc b/src/transform/single_entry_point_test.cc index cdab09423a..b2193b1391 100644 --- a/src/transform/single_entry_point_test.cc +++ b/src/transform/single_entry_point_test.cc @@ -236,6 +236,26 @@ fn comp_main1() { EXPECT_EQ(expect, str(got)); } +TEST_F(SingleEntryPointTest, WorkgroupSizeLetPreserved) { + auto* src = R"( +let size : i32 = 1; + +[[stage(compute), workgroup_size(size)]] +fn main() { +} +)"; + + auto* expect = src; + + SingleEntryPoint::Config cfg("main"); + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + TEST_F(SingleEntryPointTest, OverridableConstants) { auto* src = R"( [[override(1001)]] let c1 : u32 = 1u;