diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7f0e5b12c4..d23767aa88 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -473,6 +473,7 @@ if(${TINT_BUILD_TESTS}) resolver/resolver_test_helper.h resolver/resolver_test.cc resolver/struct_layout_test.cc + resolver/struct_storage_class_use_test.cc resolver/validation_test.cc scope_stack_test.cc semantic/sem_intrinsic_test.cc diff --git a/src/ast/storage_class.h b/src/ast/storage_class.h index 09bb41ca6f..4232aa5788 100644 --- a/src/ast/storage_class.h +++ b/src/ast/storage_class.h @@ -34,6 +34,12 @@ enum class StorageClass { kFunction }; +/// @returns true if the StorageClass is host-sharable +/// @see https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable +inline bool IsHostSharable(StorageClass sc) { + return sc == ast::StorageClass::kUniform || sc == ast::StorageClass::kStorage; +} + std::ostream& operator<<(std::ostream& out, StorageClass sc); } // namespace ast diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 53a69f9d23..6e16b2ef45 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -151,6 +151,11 @@ bool Resolver::ResolveInternal() { return false; } } + + if (!ApplyStorageClassUsageToType(var->declared_storage_class(), + var->type())) { + return false; + } } if (!Functions(builder_->AST().Functions())) { @@ -200,16 +205,6 @@ bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { bool Resolver::Statements(const ast::StatementList& stmts) { for (auto* stmt : stmts) { - if (auto* decl = stmt->As()) { - if (!VariableDeclStatement(decl)) { - return false; - } - } - - if (!VariableStorageClass(stmt)) { - return false; - } - if (!Statement(stmt)) { return false; } @@ -217,36 +212,6 @@ bool Resolver::Statements(const ast::StatementList& stmts) { return true; } -bool Resolver::VariableStorageClass(ast::Statement* stmt) { - auto* var_decl = stmt->As(); - if (var_decl == nullptr) { - return true; - } - - auto* var = var_decl->variable(); - - auto* info = CreateVariableInfo(var); - variable_to_info_.emplace(var, info); - - // Nothing to do for const - if (var->is_const()) { - return true; - } - - if (info->storage_class == ast::StorageClass::kFunction) { - return true; - } - - if (info->storage_class != ast::StorageClass::kNone) { - diagnostics_.add_error("function variable has a non-function storage class", - stmt->source()); - return false; - } - - info->storage_class = ast::StorageClass::kFunction; - return true; -} - bool Resolver::Statement(ast::Statement* stmt) { auto* sem_statement = builder_->create(stmt); @@ -336,10 +301,7 @@ bool Resolver::Statement(ast::Statement* stmt) { return true; } if (auto* v = stmt->As()) { - variable_stack_.set(v->variable()->symbol(), - variable_to_info_.at(v->variable())); - current_block_->decls.push_back(v->variable()); - return Expression(v->variable()->constructor()); + return VariableDeclStatement(v); } diagnostics_.add_error( @@ -1118,21 +1080,44 @@ bool Resolver::UnaryOp(ast::UnaryOpExpression* expr) { } bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { - auto* ctor = stmt->variable()->constructor(); - if (!ctor) { - return true; - } - - if (auto* sce = ctor->As()) { - auto* lhs_type = stmt->variable()->type()->UnwrapAliasIfNeeded(); - auto* rhs_type = sce->literal()->type()->UnwrapAliasIfNeeded(); - - if (lhs_type != rhs_type) { - diagnostics_.add_error( - "constructor expression type does not match variable type", - stmt->source()); + if (auto* ctor = stmt->variable()->constructor()) { + if (!Expression(ctor)) { return false; } + if (auto* sce = ctor->As()) { + auto* lhs_type = stmt->variable()->type()->UnwrapAliasIfNeeded(); + auto* rhs_type = sce->literal()->type()->UnwrapAliasIfNeeded(); + + if (lhs_type != rhs_type) { + diagnostics_.add_error( + "constructor expression type does not match variable type", + stmt->source()); + return false; + } + } + } + + auto* var = stmt->variable(); + + auto* info = CreateVariableInfo(var); + variable_to_info_.emplace(var, info); + variable_stack_.set(var->symbol(), info); + current_block_->decls.push_back(var); + + if (!var->is_const()) { + if (info->storage_class != ast::StorageClass::kFunction) { + if (info->storage_class != ast::StorageClass::kNone) { + diagnostics_.add_error( + "function variable has a non-function storage class", + stmt->source()); + return false; + } + info->storage_class = ast::StorageClass::kFunction; + } + } + + if (!ApplyStorageClassUsageToType(info->storage_class, var->type())) { + return false; } return true; @@ -1247,9 +1232,10 @@ void Resolver::CreateSemanticNodes() const { for (auto it : struct_info_) { auto* str = it.first; auto* info = it.second; - builder_->Sem().Add(str, builder_->create( - str, std::move(info->members), info->align, - info->size, info->size_no_padding)); + builder_->Sem().Add( + str, builder_->create( + str, std::move(info->members), info->align, info->size, + info->size_no_padding, info->storage_class_usage)); } } @@ -1470,6 +1456,44 @@ Resolver::StructInfo* Resolver::Structure(type::Struct* str) { return info; } +bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, + type::Type* ty) { + ty = ty->UnwrapAliasIfNeeded(); + + if (auto* str = ty->As()) { + auto* info = Structure(str); + if (!info) { + return false; + } + if (info->storage_class_usage.count(sc)) { + return true; // Already applied + } + info->storage_class_usage.emplace(sc); + for (auto* member : str->impl()->members()) { + // TODO(amaiorano): Determine the host-sharable types + bool can_be_host_sharable = true; + if (ast::IsHostSharable(sc) && !can_be_host_sharable) { + std::stringstream err; + err << "Structure '" << str->FriendlyName(builder_->Symbols()) + << "' is used by storage class " << sc + << " which contains a member of non-host-sharable type " + << member->type()->FriendlyName(builder_->Symbols()); + diagnostics_.add_error(err.str(), member->source()); + return false; + } + if (!ApplyStorageClassUsageToType(sc, member->type())) { + return false; + } + } + } + + if (auto* arr = ty->As()) { + return ApplyStorageClassUsageToType(sc, arr->type()); + } + + return true; +} + template bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) { BlockInfo block_info(type, current_block_); diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 4bbe5049b2..7bb56b6de2 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "src/intrinsic_table.h" @@ -124,6 +125,7 @@ class Resolver { uint32_t align = 0; uint32_t size = 0; uint32_t size_no_padding = 0; + std::unordered_set storage_class_usage; }; /// Structure holding semantic information about a block (i.e. scope), such as @@ -206,7 +208,6 @@ class Resolver { bool Statements(const ast::StatementList&); bool UnaryOp(ast::UnaryOpExpression*); bool VariableDeclStatement(const ast::VariableDeclStatement*); - bool VariableStorageClass(ast::Statement*); /// @returns the semantic information for the array `arr`, building it if it /// hasn't been constructed already. If an error is raised, nullptr is @@ -217,6 +218,12 @@ class Resolver { /// been constructed already. If an error is raised, nullptr is returned. StructInfo* Structure(type::Struct* str); + /// Records the storage class usage for the given type, and any transient + /// dependencies of the type. Validates that the type can be used for the + /// given storage class, erroring if it cannot. + /// @returns true on success, false on error + bool ApplyStorageClassUsageToType(ast::StorageClass, type::Type*); + /// @param align the output default alignment in bytes for the type `ty` /// @param size the output default size in bytes for the type `ty` /// @returns true on success, false on error diff --git a/src/resolver/struct_storage_class_use_test.cc b/src/resolver/struct_storage_class_use_test.cc new file mode 100644 index 0000000000..34fa7e27aa --- /dev/null +++ b/src/resolver/struct_storage_class_use_test.cc @@ -0,0 +1,161 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/resolver/resolver.h" + +#include "gmock/gmock.h" +#include "src/resolver/resolver_test_helper.h" +#include "src/semantic/struct.h" + +using ::testing::UnorderedElementsAre; + +namespace tint { +namespace resolver { +namespace { + +using ResolverStorageClassUseTest = ResolverTest; + +TEST_F(ResolverStorageClassUseTest, UnreachableStruct) { + auto* s = Structure("S", {Member("a", ty.f32())}); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_TRUE(sem->StorageClassUsage().empty()); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableFromGlobal) { + auto* s = Structure("S", {Member("a", ty.f32())}); + + Global("g", s, ast::StorageClass::kStorage); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kStorage)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalAlias) { + auto* s = Structure("S", {Member("a", ty.f32())}); + auto* a = ty.alias("A", s); + Global("g", a, ast::StorageClass::kStorage); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kStorage)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) { + auto* s = Structure("S", {Member("a", ty.f32())}); + auto* o = Structure("O", {Member("a", s)}); + Global("g", o, ast::StorageClass::kStorage); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kStorage)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) { + auto* s = Structure("S", {Member("a", ty.f32())}); + auto* a = ty.array(s, 3); + Global("g", a, ast::StorageClass::kStorage); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kStorage)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableFromLocal) { + auto* s = Structure("S", {Member("a", ty.f32())}); + + WrapInFunction(Var("g", s, ast::StorageClass::kFunction)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kFunction)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalAlias) { + auto* s = Structure("S", {Member("a", ty.f32())}); + auto* a = ty.alias("A", s); + WrapInFunction(Var("g", a, ast::StorageClass::kFunction)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kFunction)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalStruct) { + auto* s = Structure("S", {Member("a", ty.f32())}); + auto* o = Structure("O", {Member("a", s)}); + WrapInFunction(Var("g", o, ast::StorageClass::kFunction)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kFunction)); +} + +TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalArray) { + auto* s = Structure("S", {Member("a", ty.f32())}); + auto* a = ty.array(s, 3); + WrapInFunction(Var("g", a, ast::StorageClass::kFunction)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kFunction)); +} + +TEST_F(ResolverStorageClassUseTest, StructMultipleStorageClassUses) { + auto* s = Structure("S", {Member("a", ty.f32())}); + Global("x", s, ast::StorageClass::kStorage); + Global("y", s, ast::StorageClass::kUniform); + WrapInFunction(Var("g", s, ast::StorageClass::kFunction)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(s); + ASSERT_NE(sem, nullptr); + EXPECT_THAT(sem->StorageClassUsage(), + UnorderedElementsAre(ast::StorageClass::kStorage, + ast::StorageClass::kUniform, + ast::StorageClass::kFunction)); +} + +} // namespace +} // namespace resolver +} // namespace tint diff --git a/src/semantic/sem_struct.cc b/src/semantic/sem_struct.cc index c1379f8240..4201517838 100644 --- a/src/semantic/sem_struct.cc +++ b/src/semantic/sem_struct.cc @@ -24,12 +24,14 @@ Struct::Struct(type::Struct* type, StructMemberList members, uint32_t align, uint32_t size, - uint32_t size_no_padding) + uint32_t size_no_padding, + std::unordered_set storage_class_usage) : type_(type), members_(std::move(members)), align_(align), size_(size), - size_no_padding_(size_no_padding) {} + size_no_padding_(size_no_padding), + storage_class_usage_(std::move(storage_class_usage)) {} Struct::~Struct() = default; diff --git a/src/semantic/struct.h b/src/semantic/struct.h index 08db295f0c..d8c42f42e7 100644 --- a/src/semantic/struct.h +++ b/src/semantic/struct.h @@ -17,8 +17,10 @@ #include +#include #include +#include "src/ast/storage_class.h" #include "src/semantic/node.h" namespace tint { @@ -48,11 +50,13 @@ class Struct : public Castable { /// @param size the byte size of the structure /// @param size_no_padding size of the members without the end of structure /// alignment padding + /// @param storage_class_usage a set of all the storage class usages Struct(type::Struct* type, StructMemberList members, uint32_t align, uint32_t size, - uint32_t size_no_padding); + uint32_t size_no_padding, + std::unordered_set storage_class_usage); /// Destructor ~Struct() override; @@ -79,12 +83,24 @@ class Struct : public Castable { /// alignment padding uint32_t SizeNoPadding() const { return size_no_padding_; } + /// @returns the set of storage class uses of this structure + const std::unordered_set& StorageClassUsage() const { + return storage_class_usage_; + } + + /// @param usage the ast::StorageClass usage type to query + /// @returns true iff this structure has been used as the given storage class + bool UsedAs(ast::StorageClass usage) const { + return storage_class_usage_.count(usage) > 0; + } + private: type::Struct* const type_; StructMemberList const members_; uint32_t const align_; uint32_t const size_; uint32_t const size_no_padding_; + std::unordered_set const storage_class_usage_; }; /// StructMember holds the semantic information for structure members. diff --git a/test/BUILD.gn b/test/BUILD.gn index a7bc6740dc..228541551d 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -175,6 +175,7 @@ source_set("tint_unittests_core_src") { "../src/resolver/resolver_test_helper.h", "../src/resolver/resolver_test.cc", "../src/resolver/struct_layout_test.cc", + "../src/resolver/struct_storage_class_use_test.cc", "../src/resolver/validation_test.cc", "../src/scope_stack_test.cc", "../src/semantic/sem_intrinsic_test.cc",