diff --git a/src/ast/type/storage_texture_type.cc b/src/ast/type/storage_texture_type.cc index a61fc01e22..f3f3b9f689 100644 --- a/src/ast/type/storage_texture_type.cc +++ b/src/ast/type/storage_texture_type.cc @@ -164,6 +164,14 @@ StorageTextureType::StorageTextureType(TextureDimension dim, assert(IsValidStorageDimension(dim)); } +void StorageTextureType::set_type(Type* const type) { + type_ = type; +} + +Type* StorageTextureType::type() const { + return type_; +} + StorageTextureType::StorageTextureType(StorageTextureType&&) = default; StorageTextureType::~StorageTextureType() = default; diff --git a/src/ast/type/storage_texture_type.h b/src/ast/type/storage_texture_type.h index 2c27ffd347..ae6c4fdd1c 100644 --- a/src/ast/type/storage_texture_type.h +++ b/src/ast/type/storage_texture_type.h @@ -74,9 +74,10 @@ class StorageTextureType : public TextureType { /// @param dim the dimensionality of the texture /// @param access the access type for the texture /// @param format the image format of the texture - explicit StorageTextureType(TextureDimension dim, - StorageAccess access, - ImageFormat format); + StorageTextureType(TextureDimension dim, + StorageAccess access, + ImageFormat format); + /// Move constructor StorageTextureType(StorageTextureType&&); ~StorageTextureType() override; @@ -84,8 +85,11 @@ class StorageTextureType : public TextureType { /// @returns true if the type is a storage texture type bool IsStorage() const override; - /// @returns the subtype of the sampled texture - Type* type() const { return type_; } + /// @param type the subtype of the storage texture + void set_type(Type* const type); + + /// @returns the subtype of the storage texture set with set_type + Type* type() const; /// @returns the storage access StorageAccess access() const { return storage_access_; } diff --git a/src/ast/type/storage_texture_type_test.cc b/src/ast/type/storage_texture_type_test.cc index 6c834a0717..bc0f5032d3 100644 --- a/src/ast/type/storage_texture_type_test.cc +++ b/src/ast/type/storage_texture_type_test.cc @@ -14,6 +14,9 @@ #include "src/ast/type/storage_texture_type.h" +#include "src/ast/identifier_expression.h" +#include "src/type_determiner.h" + #include "gtest/gtest.h" namespace tint { @@ -72,6 +75,48 @@ TEST_F(StorageTextureTypeTest, TypeName) { EXPECT_EQ(s.type_name(), "__storage_texture_read_2d_array_rgba32float"); } +TEST_F(StorageTextureTypeTest, F32Type) { + Context ctx; + ast::type::Type* s = ctx.type_mgr().Get(std::make_unique( + TextureDimension::k2dArray, StorageAccess::kRead, + ImageFormat::kRgba32Float)); + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(s->IsTexture()); + ASSERT_TRUE(s->AsTexture()->IsStorage()); + EXPECT_TRUE(s->AsTexture()->AsStorage()->type()->IsF32()); +} + +TEST_F(StorageTextureTypeTest, U32Type) { + Context ctx; + ast::type::Type* s = ctx.type_mgr().Get(std::make_unique( + TextureDimension::k2dArray, StorageAccess::kRead, + ImageFormat::kRgba8Unorm)); + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(s->IsTexture()); + ASSERT_TRUE(s->AsTexture()->IsStorage()); + EXPECT_TRUE(s->AsTexture()->AsStorage()->type()->IsU32()); +} + +TEST_F(StorageTextureTypeTest, I32Type) { + Context ctx; + ast::type::Type* s = ctx.type_mgr().Get(std::make_unique( + TextureDimension::k2dArray, StorageAccess::kRead, + ImageFormat::kRgba32Sint)); + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(s->IsTexture()); + ASSERT_TRUE(s->AsTexture()->IsStorage()); + EXPECT_TRUE(s->AsTexture()->AsStorage()->type()->IsI32()); +} + } // namespace } // namespace type } // namespace ast diff --git a/src/type_determiner.cc b/src/type_determiner.cc index a559d8bdb7..9f2e90df56 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -41,9 +41,11 @@ #include "src/ast/type/array_type.h" #include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" #include "src/ast/type/matrix_type.h" #include "src/ast/type/pointer_type.h" #include "src/ast/type/struct_type.h" +#include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/unary_op_expression.h" @@ -177,6 +179,18 @@ void TypeDeterminer::set_referenced_from_function_if_needed( } bool TypeDeterminer::Determine() { + for (auto& iter : ctx_.type_mgr().types()) { + auto& type = iter.second; + if (!type->IsTexture() || !type->AsTexture()->IsStorage()) { + continue; + } + if (!DetermineStorageTextureSubtype(type->AsTexture()->AsStorage())) { + set_error(Source{}, "unable to determine storage texture subtype for: " + + type->type_name()); + return false; + } + } + for (const auto& var : mod_->global_variables()) { variable_stack_.set_global(var->name(), var.get()); @@ -824,6 +838,67 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) { return true; } +bool TypeDeterminer::DetermineStorageTextureSubtype( + ast::type::StorageTextureType* tex) { + if (tex->type() != nullptr) { + return true; + } + + switch (tex->image_format()) { + case ast::type::ImageFormat::kR8Unorm: + case ast::type::ImageFormat::kRg8Unorm: + case ast::type::ImageFormat::kRgba8Unorm: + case ast::type::ImageFormat::kRgba8UnormSrgb: + case ast::type::ImageFormat::kBgra8Unorm: + case ast::type::ImageFormat::kBgra8UnormSrgb: + case ast::type::ImageFormat::kRgb10A2Unorm: + case ast::type::ImageFormat::kR8Uint: + case ast::type::ImageFormat::kR16Uint: + case ast::type::ImageFormat::kRg8Uint: + case ast::type::ImageFormat::kR32Uint: + case ast::type::ImageFormat::kRg16Uint: + case ast::type::ImageFormat::kRgba8Uint: + case ast::type::ImageFormat::kRg32Uint: + case ast::type::ImageFormat::kRgba16Uint: + case ast::type::ImageFormat::kRgba32Uint: { + tex->set_type( + ctx_.type_mgr().Get(std::make_unique())); + return true; + } + + case ast::type::ImageFormat::kR8Snorm: + case ast::type::ImageFormat::kRg8Snorm: + case ast::type::ImageFormat::kRgba8Snorm: + case ast::type::ImageFormat::kR8Sint: + case ast::type::ImageFormat::kR16Sint: + case ast::type::ImageFormat::kRg8Sint: + case ast::type::ImageFormat::kR32Sint: + case ast::type::ImageFormat::kRg16Sint: + case ast::type::ImageFormat::kRgba8Sint: + case ast::type::ImageFormat::kRg32Sint: + case ast::type::ImageFormat::kRgba16Sint: + case ast::type::ImageFormat::kRgba32Sint: { + tex->set_type( + ctx_.type_mgr().Get(std::make_unique())); + return true; + } + + case ast::type::ImageFormat::kR16Float: + case ast::type::ImageFormat::kR32Float: + case ast::type::ImageFormat::kRg16Float: + case ast::type::ImageFormat::kRg11B10Float: + case ast::type::ImageFormat::kRg32Float: + case ast::type::ImageFormat::kRgba16Float: + case ast::type::ImageFormat::kRgba32Float: { + tex->set_type( + ctx_.type_mgr().Get(std::make_unique())); + return true; + } + } + + return false; +} + ast::type::Type* TypeDeterminer::GetImportData( const Source& source, const std::string& path, diff --git a/src/type_determiner.h b/src/type_determiner.h index cf8f8bd1af..290357d054 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -19,6 +19,7 @@ #include #include "src/ast/module.h" +#include "src/ast/type/storage_texture_type.h" #include "src/context.h" #include "src/scope_stack.h" @@ -118,6 +119,8 @@ class TypeDeterminer { bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr); bool DetermineUnaryOp(ast::UnaryOpExpression* expr); + bool DetermineStorageTextureSubtype(ast::type::StorageTextureType* tex); + Context& ctx_; ast::Module* mod_; std::string error_; diff --git a/src/type_manager.h b/src/type_manager.h index eb26a25d9d..2c82d76d40 100644 --- a/src/type_manager.h +++ b/src/type_manager.h @@ -37,10 +37,10 @@ class TypeManager { /// @return the pointer to the registered type ast::type::Type* Get(std::unique_ptr type); - /// Returns the type map, for testing purposes. + /// Returns the type map /// @returns the mapping from name string to type. const std::unordered_map>& - TypesForTesting() { + types() { return types_; } diff --git a/src/type_manager_test.cc b/src/type_manager_test.cc index 1628ffb7e8..217259a6a6 100644 --- a/src/type_manager_test.cc +++ b/src/type_manager_test.cc @@ -57,9 +57,9 @@ TEST_F(TypeManagerTest, ResetClearsPreviousData) { auto* t = tm.Get(std::make_unique()); ASSERT_NE(t, nullptr); - EXPECT_FALSE(tm.TypesForTesting().empty()); + EXPECT_FALSE(tm.types().empty()); tm.Reset(); - EXPECT_TRUE(tm.TypesForTesting().empty()); + EXPECT_TRUE(tm.types().empty()); auto* t2 = tm.Get(std::make_unique()); ASSERT_NE(t2, nullptr);