diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index e148a49824..75379023bf 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -573,6 +573,7 @@ libtint_source_set("libtint_core_all_src") { "type/sampler.h", "type/short_name.h", "type/storage_texture.h", + "type/struct.h", "type/texture.h", "type/type.h", "type/type_manager.h", @@ -738,6 +739,8 @@ libtint_source_set("libtint_type_src") { "type/short_name.h", "type/storage_texture.cc", "type/storage_texture.h", + "type/struct.cc", + "type/struct.h", "type/texture.cc", "type/texture.h", "type/type.cc", @@ -1228,6 +1231,7 @@ if (tint_build_unittests) { "type/sampler_test.cc", "type/short_name_test.cc", "type/storage_texture_test.cc", + "type/struct_test.cc", "type/texture_test.cc", "type/type_manager_test.cc", "type/type_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index bc279cd864..375c253134 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -335,6 +335,7 @@ list(APPEND TINT_LIB_SRCS sem/sampler_texture_pair.h sem/statement.cc sem/struct.cc + sem/struct.h sem/switch_statement.cc sem/switch_statement.h sem/type_initializer.cc @@ -493,6 +494,8 @@ list(APPEND TINT_LIB_SRCS type/sampler.h type/storage_texture.cc type/storage_texture.h + type/struct.cc + type/struct.h type/texture.cc type/texture.h type/type.cc @@ -946,6 +949,7 @@ if(TINT_BUILD_TESTS) type/sampled_texture_test.cc type/sampler_test.cc type/storage_texture_test.cc + type/struct_test.cc type/texture_test.cc type/type_test.cc type/type_manager_test.cc diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 224c800fb0..03dca1724e 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -497,7 +497,7 @@ class ProgramBuilder { /// @returns the de-aliased array count pointer template traits::EnableIf || - traits::IsTypeOrDerived, + traits::IsTypeOrDerived, T>* create(ARGS&&... args) { AssertNotMoved(); diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 28cb6197da..29eb86e01e 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -34,6 +34,7 @@ #include "src/tint/type/f32.h" #include "src/tint/type/i32.h" #include "src/tint/type/matrix.h" +#include "src/tint/type/struct.h" #include "src/tint/type/u32.h" #include "src/tint/type/vector.h" #include "src/tint/utils/bitcast.h" @@ -415,7 +416,7 @@ struct Composite : ImplConstant { utils::Vector conv_els; conv_els.Reserve(elements.Length()); std::function target_el_ty; - if (auto* str = target_ty->As()) { + if (auto* str = target_ty->As()) { if (str->Members().Length() != elements.Length()) { TINT_ICE(Resolver, builder.Diagnostics()) << "const-eval conversion of structure has mismatched element counts"; @@ -493,7 +494,7 @@ const ImplConstant* ZeroValue(ProgramBuilder& builder, const type::Type* type) { } return nullptr; }, - [&](const sem::Struct* s) -> const ImplConstant* { + [&](const type::StructBase* s) -> const ImplConstant* { utils::Hashmap zero_by_type; utils::Vector zeros; zeros.Reserve(s->Members().Length()); @@ -1448,7 +1449,7 @@ ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr, } ConstEval::Result ConstEval::MemberAccess(const sem::Expression* obj_expr, - const sem::StructMember* member) { + const type::StructMemberBase* member) { auto obj_val = obj_expr->ConstantValue(); if (!obj_val) { return nullptr; diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 81e4575ba0..d7d370d7e8 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -33,8 +33,10 @@ class LiteralExpression; namespace tint::sem { class Constant; class Expression; -class StructMember; } // namespace tint::sem +namespace tint::type { +class StructMemberBase; +} // namespace tint::type namespace tint::resolver { @@ -92,7 +94,7 @@ class ConstEval { /// @param obj the object being accessed /// @param member the member /// @return the result of the member access, or null if the value cannot be calculated - Result MemberAccess(const sem::Expression* obj, const sem::StructMember* member); + Result MemberAccess(const sem::Expression* obj, const type::StructMemberBase* member); /// @param ty the result type /// @param vector the vector being swizzled diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 3c9d54279a..303378611b 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -996,13 +996,13 @@ sem::Function* Resolver::Function(const ast::Function* decl) { if (auto* str = p_ty->As()) { switch (decl->PipelineStage()) { case ast::PipelineStage::kVertex: - str->AddUsage(sem::PipelineStageUsage::kVertexInput); + str->AddUsage(type::PipelineStageUsage::kVertexInput); break; case ast::PipelineStage::kFragment: - str->AddUsage(sem::PipelineStageUsage::kFragmentInput); + str->AddUsage(type::PipelineStageUsage::kFragmentInput); break; case ast::PipelineStage::kCompute: - str->AddUsage(sem::PipelineStageUsage::kComputeInput); + str->AddUsage(type::PipelineStageUsage::kComputeInput); break; case ast::PipelineStage::kNone: break; @@ -1048,13 +1048,13 @@ sem::Function* Resolver::Function(const ast::Function* decl) { switch (decl->PipelineStage()) { case ast::PipelineStage::kVertex: - str->AddUsage(sem::PipelineStageUsage::kVertexOutput); + str->AddUsage(type::PipelineStageUsage::kVertexOutput); break; case ast::PipelineStage::kFragment: - str->AddUsage(sem::PipelineStageUsage::kFragmentOutput); + str->AddUsage(type::PipelineStageUsage::kFragmentOutput); break; case ast::PipelineStage::kCompute: - str->AddUsage(sem::PipelineStageUsage::kComputeOutput); + str->AddUsage(type::PipelineStageUsage::kComputeOutput); break; case ast::PipelineStage::kNone: break; diff --git a/src/tint/resolver/struct_pipeline_stage_use_test.cc b/src/tint/resolver/struct_pipeline_stage_use_test.cc index c4455b7096..cd77c07c7b 100644 --- a/src/tint/resolver/struct_pipeline_stage_use_test.cc +++ b/src/tint/resolver/struct_pipeline_stage_use_test.cc @@ -76,7 +76,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderParam) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kVertexInput)); + UnorderedElementsAre(type::PipelineStageUsage::kVertexInput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderReturnType) { @@ -92,7 +92,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderReturnType) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kVertexOutput)); + UnorderedElementsAre(type::PipelineStageUsage::kVertexOutput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) { @@ -106,7 +106,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kFragmentInput)); + UnorderedElementsAre(type::PipelineStageUsage::kFragmentInput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderReturnType) { @@ -120,7 +120,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderReturnType) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kFragmentOutput)); + UnorderedElementsAre(type::PipelineStageUsage::kFragmentOutput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsComputeShaderParam) { @@ -136,7 +136,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsComputeShaderParam) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kComputeInput)); + UnorderedElementsAre(type::PipelineStageUsage::kComputeInput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) { @@ -155,8 +155,8 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kVertexOutput, - sem::PipelineStageUsage::kFragmentInput)); + UnorderedElementsAre(type::PipelineStageUsage::kVertexOutput, + type::PipelineStageUsage::kFragmentInput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) { @@ -171,7 +171,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kFragmentInput)); + UnorderedElementsAre(type::PipelineStageUsage::kFragmentInput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamLocationSet) { @@ -201,7 +201,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) { auto* sem = TypeOf(s)->As(); ASSERT_NE(sem, nullptr); EXPECT_THAT(sem->PipelineStageUses(), - UnorderedElementsAre(sem::PipelineStageUsage::kFragmentOutput)); + UnorderedElementsAre(type::PipelineStageUsage::kFragmentOutput)); } TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeLocationSet) { diff --git a/src/tint/sem/member_accessor_expression.h b/src/tint/sem/member_accessor_expression.h index ba9a552846..a56c61f744 100644 --- a/src/tint/sem/member_accessor_expression.h +++ b/src/tint/sem/member_accessor_expression.h @@ -23,7 +23,6 @@ namespace tint::ast { class MemberAccessorExpression; } // namespace tint::ast namespace tint::sem { -class Struct; class StructMember; } // namespace tint::sem diff --git a/src/tint/sem/struct.cc b/src/tint/sem/struct.cc index 0db346ae0a..413231ab71 100644 --- a/src/tint/sem/struct.cc +++ b/src/tint/sem/struct.cc @@ -14,44 +14,12 @@ #include "src/tint/sem/struct.h" -#include -#include -#include -#include - #include "src/tint/ast/struct_member.h" -#include "src/tint/symbol_table.h" -#include "src/tint/utils/hash.h" -TINT_INSTANTIATE_TYPEINFO(tint::sem::StructBase); TINT_INSTANTIATE_TYPEINFO(tint::sem::Struct); -TINT_INSTANTIATE_TYPEINFO(tint::sem::StructMemberBase); TINT_INSTANTIATE_TYPEINFO(tint::sem::StructMember); namespace tint::sem { -namespace { - -type::TypeFlags FlagsFrom(utils::VectorRef members) { - type::TypeFlags flags{ - type::TypeFlag::kConstructable, - type::TypeFlag::kCreationFixedFootprint, - type::TypeFlag::kFixedFootprint, - }; - for (auto* member : members) { - if (!member->Type()->IsConstructible()) { - flags.Remove(type::TypeFlag::kConstructable); - } - if (!member->Type()->HasFixedFootprint()) { - flags.Remove(type::TypeFlag::kFixedFootprint); - } - if (!member->Type()->HasCreationFixedFootprint()) { - flags.Remove(type::TypeFlag::kCreationFixedFootprint); - } - } - return flags; -} - -} // namespace Struct::Struct(const ast::Struct* declaration, tint::Source source, @@ -64,122 +32,6 @@ Struct::Struct(const ast::Struct* declaration, Struct::~Struct() = default; -StructBase::StructBase(tint::Source source, - Symbol name, - utils::VectorRef members, - uint32_t align, - uint32_t size, - uint32_t size_no_padding) - : Base(FlagsFrom(members)), - source_(source), - name_(name), - members_(std::move(members)), - align_(align), - size_(size), - size_no_padding_(size_no_padding) {} - -StructBase::~StructBase() = default; - -size_t StructBase::Hash() const { - return utils::Hash(TypeInfo::Of().full_hashcode, name_); -} - -bool StructBase::Equals(const type::Type& other) const { - if (auto* o = other.As()) { - return o->name_ == name_; - } - return false; -} - -const StructMemberBase* StructBase::FindMember(Symbol name) const { - for (auto* member : members_) { - if (member->Name() == name) { - return member; - } - } - return nullptr; -} - -uint32_t StructBase::Align() const { - return align_; -} - -uint32_t StructBase::Size() const { - return size_; -} - -std::string StructBase::FriendlyName(const SymbolTable& symbols) const { - return symbols.NameFor(name_); -} - -std::string StructBase::Layout(const tint::SymbolTable& symbols) const { - std::stringstream ss; - - auto member_name_of = [&](const sem::StructMemberBase* sm) { - return symbols.NameFor(sm->Name()); - }; - - if (Members().IsEmpty()) { - return {}; - } - const auto* const last_member = Members().Back(); - const uint32_t last_member_struct_padding_offset = last_member->Offset() + last_member->Size(); - - // Compute max widths to align output - const auto offset_w = static_cast(::log10(last_member_struct_padding_offset)) + 1; - const auto size_w = static_cast(::log10(Size())) + 1; - const auto align_w = static_cast(::log10(Align())) + 1; - - auto print_struct_begin_line = [&](size_t align, size_t size, std::string struct_name) { - ss << "/* " << std::setw(offset_w) << " " - << "align(" << std::setw(align_w) << align << ") size(" << std::setw(size_w) << size - << ") */ struct " << struct_name << " {\n"; - }; - - auto print_struct_end_line = [&]() { - ss << "/* " << std::setw(offset_w + size_w + align_w) << " " - << "*/ };"; - }; - - auto print_member_line = [&](size_t offset, size_t align, size_t size, std::string s) { - ss << "/* offset(" << std::setw(offset_w) << offset << ") align(" << std::setw(align_w) - << align << ") size(" << std::setw(size_w) << size << ") */ " << s << ";\n"; - }; - - print_struct_begin_line(Align(), Size(), UnwrapRef()->FriendlyName(symbols)); - - for (size_t i = 0; i < Members().Length(); ++i) { - auto* const m = Members()[i]; - - // Output field alignment padding, if any - auto* const prev_member = (i == 0) ? nullptr : Members()[i - 1]; - if (prev_member) { - uint32_t padding = m->Offset() - (prev_member->Offset() + prev_member->Size()); - if (padding > 0) { - size_t padding_offset = m->Offset() - padding; - print_member_line(padding_offset, 1, padding, - "// -- implicit field alignment padding --"); - } - } - - // Output member - std::string member_name = member_name_of(m); - print_member_line(m->Offset(), m->Align(), m->Size(), - member_name + " : " + m->Type()->UnwrapRef()->FriendlyName(symbols)); - } - - // Output struct size padding, if any - uint32_t struct_padding = Size() - last_member_struct_padding_offset; - if (struct_padding > 0) { - print_member_line(last_member_struct_padding_offset, 1, struct_padding, - "// -- implicit struct size padding --"); - } - - print_struct_end_line(); - - return ss.str(); -} - StructMember::StructMember(const ast::StructMember* declaration, tint::Source source, Symbol name, @@ -193,23 +45,4 @@ StructMember::StructMember(const ast::StructMember* declaration, StructMember::~StructMember() = default; -StructMemberBase::StructMemberBase(tint::Source source, - Symbol name, - const type::Type* type, - uint32_t index, - uint32_t offset, - uint32_t align, - uint32_t size, - std::optional location) - : source_(source), - name_(name), - type_(type), - index_(index), - offset_(offset), - align_(align), - size_(size), - location_(location) {} - -StructMemberBase::~StructMemberBase() = default; - } // namespace tint::sem diff --git a/src/tint/sem/struct.h b/src/tint/sem/struct.h index ca2bc4c5e1..ac535ece30 100644 --- a/src/tint/sem/struct.h +++ b/src/tint/sem/struct.h @@ -15,16 +15,12 @@ #ifndef SRC_TINT_SEM_STRUCT_H_ #define SRC_TINT_SEM_STRUCT_H_ -#include - #include -#include -#include #include "src/tint/ast/address_space.h" #include "src/tint/ast/struct.h" -#include "src/tint/sem/node.h" #include "src/tint/symbol.h" +#include "src/tint/type/struct.h" #include "src/tint/type/type.h" #include "src/tint/utils/vector.h" @@ -34,145 +30,15 @@ class StructMember; } // namespace tint::ast namespace tint::sem { class StructMember; -class StructMemberBase; } // namespace tint::sem +namespace tint::type { +class StructMemberBase; +} // namespace tint::type namespace tint::sem { -/// Metadata to capture how a structure is used in a shader module. -enum class PipelineStageUsage { - kVertexInput, - kVertexOutput, - kFragmentInput, - kFragmentOutput, - kComputeInput, - kComputeOutput, -}; - -/// StructBase holds the semantic information for structures. -class StructBase : public Castable { - public: - /// Constructor - /// @param source the source of the structure - /// @param name the name of the structure - /// @param members the structure members - /// @param align the byte alignment of the structure - /// @param size the byte size of the structure - /// @param size_no_padding size of the members without the end of structure - /// alignment padding - StructBase(tint::Source source, - Symbol name, - utils::VectorRef members, - uint32_t align, - uint32_t size, - uint32_t size_no_padding); - - /// Destructor - ~StructBase() override; - - /// @returns a hash of the type. - size_t Hash() const override; - - /// @param other the other type to compare against - /// @returns true if the this type is equal to the given type - bool Equals(const Type& other) const override; - - /// @returns the source of the structure - tint::Source Source() const { return source_; } - - /// @returns the name of the structure - Symbol Name() const { return name_; } - - /// @returns the members of the structure - utils::VectorRef Members() const { return members_; } - - /// @param name the member name to look for - /// @returns the member with the given name, or nullptr if it was not found. - const StructMemberBase* FindMember(Symbol name) const; - - /// @returns the byte alignment of the structure - /// @note this may differ from the alignment of a structure member of this - /// structure type, if the member is annotated with the `@align(n)` - /// attribute. - uint32_t Align() const override; - - /// @returns the byte size of the structure - /// @note this may differ from the size of a structure member of this - /// structure type, if the member is annotated with the `@size(n)` - /// attribute. - uint32_t Size() const override; - - /// @returns the byte size of the members without the end of structure - /// alignment padding - uint32_t SizeNoPadding() const { return size_no_padding_; } - - /// Adds the AddressSpace usage to the structure. - /// @param usage the storage usage - void AddUsage(ast::AddressSpace usage) { address_space_usage_.emplace(usage); } - - /// @returns the set of address space uses of this structure - const std::unordered_set& AddressSpaceUsage() const { - return address_space_usage_; - } - - /// @param usage the ast::AddressSpace usage type to query - /// @returns true iff this structure has been used as the given address space - bool UsedAs(ast::AddressSpace usage) const { return address_space_usage_.count(usage) > 0; } - - /// @returns true iff this structure has been used by address space that's - /// host-shareable. - bool IsHostShareable() const { - for (auto sc : address_space_usage_) { - if (ast::IsHostShareable(sc)) { - return true; - } - } - return false; - } - - /// Adds the pipeline stage usage to the structure. - /// @param usage the storage usage - void AddUsage(PipelineStageUsage usage) { pipeline_stage_uses_.emplace(usage); } - - /// @returns the set of entry point uses of this structure - const std::unordered_set& PipelineStageUses() const { - return pipeline_stage_uses_; - } - - /// @param symbols the program's symbol table - /// @returns the name for this type that closely resembles how it would be - /// declared in WGSL. - std::string FriendlyName(const SymbolTable& symbols) const override; - - /// @param symbols the program's symbol table - /// @returns a multiline string that describes the layout of this struct, - /// including size and alignment information. - std::string Layout(const tint::SymbolTable& symbols) const; - - /// @param concrete the conversion-rank ordered concrete versions of this abstract structure. - void SetConcreteTypes(utils::VectorRef concrete) { - concrete_types_ = concrete; - } - - /// @returns the conversion-rank ordered concrete versions of this abstract structure, or an - /// empty vector if this structure is not abstract. - /// @note only structures returned by builtins may be abstract (e.g. modf, frexp) - utils::VectorRef ConcreteTypes() const { return concrete_types_; } - - private: - const tint::Source source_; - const Symbol name_; - const utils::Vector members_; - const uint32_t align_; - const uint32_t size_; - const uint32_t size_no_padding_; - std::unordered_set address_space_usage_; - std::unordered_set pipeline_stage_uses_; - utils::Vector concrete_types_; -}; - /// Struct holds the semantic information for structures. -class Struct final : public Castable { +class Struct final : public Castable { public: /// Constructor /// @param declaration the AST structure declaration @@ -206,75 +72,8 @@ class Struct final : public Castable { ast::Struct const* const declaration_; }; -/// StructMemberBase holds the semantic information for structure members. -class StructMemberBase : public Castable { - public: - /// Constructor - /// @param source the source of the struct member - /// @param name the name of the structure member - /// @param type the type of the member - /// @param index the index of the member in the structure - /// @param offset the byte offset from the base of the structure - /// @param align the byte alignment of the member - /// @param size the byte size of the member - /// @param location the location attribute, if present - StructMemberBase(tint::Source source, - Symbol name, - const type::Type* type, - uint32_t index, - uint32_t offset, - uint32_t align, - uint32_t size, - std::optional location); - - /// Destructor - ~StructMemberBase() override; - - /// @returns the source the struct member - const tint::Source& Source() const { return source_; } - - /// @returns the name of the structure member - Symbol Name() const { return name_; } - - /// Sets the owning structure to `s` - /// @param s the new structure owner - void SetStruct(const sem::StructBase* s) { struct_ = s; } - - /// @returns the structure that owns this member - const sem::StructBase* Struct() const { return struct_; } - - /// @returns the type of the member - const type::Type* Type() const { return type_; } - - /// @returns the member index - uint32_t Index() const { return index_; } - - /// @returns byte offset from base of structure - uint32_t Offset() const { return offset_; } - - /// @returns the alignment of the member in bytes - uint32_t Align() const { return align_; } - - /// @returns byte size - uint32_t Size() const { return size_; } - - /// @returns the location, if set - std::optional Location() const { return location_; } - - private: - const tint::Source source_; - const Symbol name_; - const sem::StructBase* struct_; - const type::Type* type_; - const uint32_t index_; - const uint32_t offset_; - const uint32_t align_; - const uint32_t size_; - const std::optional location_; -}; - /// StructMember holds the semantic information for structure members. -class StructMember final : public Castable { +class StructMember final : public Castable { public: /// Constructor /// @param declaration the AST declaration node diff --git a/src/tint/sem/struct_test.cc b/src/tint/sem/struct_test.cc index b2f16cf651..8d16c2daa7 100644 --- a/src/tint/sem/struct_test.cc +++ b/src/tint/sem/struct_test.cc @@ -20,9 +20,9 @@ namespace tint::sem { namespace { using namespace tint::number_suffixes; // NOLINT -using StructTest = TestHelper; +using SemStructTest = TestHelper; -TEST_F(StructTest, Creation) { +TEST_F(SemStructTest, Creation) { auto name = Sym("S"); auto* impl = create(name, utils::Empty, utils::Empty); auto* ptr = impl; @@ -34,7 +34,7 @@ TEST_F(StructTest, Creation) { EXPECT_EQ(s->SizeNoPadding(), 16u); } -TEST_F(StructTest, Hash) { +TEST_F(SemStructTest, Hash) { auto* a_impl = create(Sym("a"), utils::Empty, utils::Empty); auto* a = create(a_impl, a_impl->source, a_impl->name, utils::Empty, 4u /* align */, 4u /* size */, 4u /* size_no_padding */); @@ -45,7 +45,7 @@ TEST_F(StructTest, Hash) { EXPECT_NE(a->Hash(), b->Hash()); } -TEST_F(StructTest, Equals) { +TEST_F(SemStructTest, Equals) { auto* a_impl = create(Sym("a"), utils::Empty, utils::Empty); auto* a = create(a_impl, a_impl->source, a_impl->name, utils::Empty, 4u /* align */, 4u /* size */, 4u /* size_no_padding */); @@ -58,7 +58,7 @@ TEST_F(StructTest, Equals) { EXPECT_FALSE(a->Equals(type::Void{})); } -TEST_F(StructTest, FriendlyName) { +TEST_F(SemStructTest, FriendlyName) { auto name = Sym("my_struct"); auto* impl = create(name, utils::Empty, utils::Empty); auto* s = create(impl, impl->source, impl->name, utils::Empty, 4u /* align */, @@ -66,162 +66,5 @@ TEST_F(StructTest, FriendlyName) { EXPECT_EQ(s->FriendlyName(Symbols()), "my_struct"); } -TEST_F(StructTest, Layout) { - auto* inner_st = // - Structure("Inner", utils::Vector{ - Member("a", ty.i32()), - Member("b", ty.u32()), - Member("c", ty.f32()), - Member("d", ty.vec3()), - Member("e", ty.mat4x2()), - }); - - auto* outer_st = Structure("Outer", utils::Vector{ - Member("inner", ty.type_name("Inner")), - Member("a", ty.i32()), - }); - - auto p = Build(); - ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); - - auto* sem_inner_st = p.Sem().Get(inner_st); - auto* sem_outer_st = p.Sem().Get(outer_st); - - EXPECT_EQ(sem_inner_st->Layout(p.Symbols()), - R"(/* align(16) size(64) */ struct Inner { -/* offset( 0) align( 4) size( 4) */ a : i32; -/* offset( 4) align( 4) size( 4) */ b : u32; -/* offset( 8) align( 4) size( 4) */ c : f32; -/* offset(12) align( 1) size( 4) */ // -- implicit field alignment padding --; -/* offset(16) align(16) size(12) */ d : vec3; -/* offset(28) align( 1) size( 4) */ // -- implicit field alignment padding --; -/* offset(32) align( 8) size(32) */ e : mat4x2; -/* */ };)"); - - EXPECT_EQ(sem_outer_st->Layout(p.Symbols()), - R"(/* align(16) size(80) */ struct Outer { -/* offset( 0) align(16) size(64) */ inner : Inner; -/* offset(64) align( 4) size( 4) */ a : i32; -/* offset(68) align( 1) size(12) */ // -- implicit struct size padding --; -/* */ };)"); -} - -TEST_F(StructTest, Location) { - auto* st = Structure("st", utils::Vector{ - Member("a", ty.i32(), utils::Vector{Location(1_u)}), - Member("b", ty.u32()), - }); - - auto p = Build(); - ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); - - auto* sem = p.Sem().Get(st); - ASSERT_EQ(2u, sem->Members().Length()); - - EXPECT_TRUE(sem->Members()[0]->Location().has_value()); - EXPECT_EQ(sem->Members()[0]->Location().value(), 1u); - - EXPECT_FALSE(sem->Members()[1]->Location().has_value()); -} - -TEST_F(StructTest, IsConstructable) { - auto* inner = // - Structure("Inner", utils::Vector{ - Member("a", ty.i32()), - Member("b", ty.u32()), - Member("c", ty.f32()), - Member("d", ty.vec3()), - Member("e", ty.mat4x2()), - }); - - auto* outer = Structure("Outer", utils::Vector{ - Member("inner", ty.type_name("Inner")), - Member("a", ty.i32()), - }); - - auto* outer_runtime_sized_array = - Structure("OuterRuntimeSizedArray", utils::Vector{ - Member("inner", ty.type_name("Inner")), - Member("a", ty.i32()), - Member("runtime_sized_array", ty.array()), - }); - auto p = Build(); - ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); - - auto* sem_inner = p.Sem().Get(inner); - auto* sem_outer = p.Sem().Get(outer); - auto* sem_outer_runtime_sized_array = p.Sem().Get(outer_runtime_sized_array); - - EXPECT_TRUE(sem_inner->IsConstructible()); - EXPECT_TRUE(sem_outer->IsConstructible()); - EXPECT_FALSE(sem_outer_runtime_sized_array->IsConstructible()); -} - -TEST_F(StructTest, HasCreationFixedFootprint) { - auto* inner = // - Structure("Inner", utils::Vector{ - Member("a", ty.i32()), - Member("b", ty.u32()), - Member("c", ty.f32()), - Member("d", ty.vec3()), - Member("e", ty.mat4x2()), - Member("f", ty.array()), - }); - - auto* outer = Structure("Outer", utils::Vector{ - Member("inner", ty.type_name("Inner")), - }); - - auto* outer_with_runtime_sized_array = - Structure("OuterRuntimeSizedArray", utils::Vector{ - Member("inner", ty.type_name("Inner")), - Member("runtime_sized_array", ty.array()), - }); - - auto p = Build(); - ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); - - auto* sem_inner = p.Sem().Get(inner); - auto* sem_outer = p.Sem().Get(outer); - auto* sem_outer_with_runtime_sized_array = p.Sem().Get(outer_with_runtime_sized_array); - - EXPECT_TRUE(sem_inner->HasCreationFixedFootprint()); - EXPECT_TRUE(sem_outer->HasCreationFixedFootprint()); - EXPECT_FALSE(sem_outer_with_runtime_sized_array->HasCreationFixedFootprint()); -} - -TEST_F(StructTest, HasFixedFootprint) { - auto* inner = // - Structure("Inner", utils::Vector{ - Member("a", ty.i32()), - Member("b", ty.u32()), - Member("c", ty.f32()), - Member("d", ty.vec3()), - Member("e", ty.mat4x2()), - Member("f", ty.array()), - }); - - auto* outer = Structure("Outer", utils::Vector{ - Member("inner", ty.type_name("Inner")), - }); - - auto* outer_with_runtime_sized_array = - Structure("OuterRuntimeSizedArray", utils::Vector{ - Member("inner", ty.type_name("Inner")), - Member("runtime_sized_array", ty.array()), - }); - - auto p = Build(); - ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); - - auto* sem_inner = p.Sem().Get(inner); - auto* sem_outer = p.Sem().Get(outer); - auto* sem_outer_with_runtime_sized_array = p.Sem().Get(outer_with_runtime_sized_array); - - EXPECT_TRUE(sem_inner->HasFixedFootprint()); - EXPECT_TRUE(sem_outer->HasFixedFootprint()); - EXPECT_FALSE(sem_outer_with_runtime_sized_array->HasFixedFootprint()); -} - } // namespace } // namespace tint::sem diff --git a/src/tint/type/struct.cc b/src/tint/type/struct.cc new file mode 100644 index 0000000000..bf52c373af --- /dev/null +++ b/src/tint/type/struct.cc @@ -0,0 +1,188 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/type/struct.h" + +#include +#include +#include +#include + +#include "src/tint/symbol_table.h" +#include "src/tint/utils/hash.h" + +TINT_INSTANTIATE_TYPEINFO(tint::type::StructBase); +TINT_INSTANTIATE_TYPEINFO(tint::type::StructMemberBase); + +namespace tint::type { +namespace { + +type::TypeFlags FlagsFrom(utils::VectorRef members) { + type::TypeFlags flags{ + type::TypeFlag::kConstructable, + type::TypeFlag::kCreationFixedFootprint, + type::TypeFlag::kFixedFootprint, + }; + for (auto* member : members) { + if (!member->Type()->IsConstructible()) { + flags.Remove(type::TypeFlag::kConstructable); + } + if (!member->Type()->HasFixedFootprint()) { + flags.Remove(type::TypeFlag::kFixedFootprint); + } + if (!member->Type()->HasCreationFixedFootprint()) { + flags.Remove(type::TypeFlag::kCreationFixedFootprint); + } + } + return flags; +} + +} // namespace + +StructBase::StructBase(tint::Source source, + Symbol name, + utils::VectorRef members, + uint32_t align, + uint32_t size, + uint32_t size_no_padding) + : Base(FlagsFrom(members)), + source_(source), + name_(name), + members_(std::move(members)), + align_(align), + size_(size), + size_no_padding_(size_no_padding) {} + +StructBase::~StructBase() = default; + +size_t StructBase::Hash() const { + return utils::Hash(TypeInfo::Of().full_hashcode, name_); +} + +bool StructBase::Equals(const type::Type& other) const { + if (auto* o = other.As()) { + return o->name_ == name_; + } + return false; +} + +const StructMemberBase* StructBase::FindMember(Symbol name) const { + for (auto* member : members_) { + if (member->Name() == name) { + return member; + } + } + return nullptr; +} + +uint32_t StructBase::Align() const { + return align_; +} + +uint32_t StructBase::Size() const { + return size_; +} + +std::string StructBase::FriendlyName(const SymbolTable& symbols) const { + return symbols.NameFor(name_); +} + +std::string StructBase::Layout(const tint::SymbolTable& symbols) const { + std::stringstream ss; + + auto member_name_of = [&](const type::StructMemberBase* sm) { + return symbols.NameFor(sm->Name()); + }; + + if (Members().IsEmpty()) { + return {}; + } + const auto* const last_member = Members().Back(); + const uint32_t last_member_struct_padding_offset = last_member->Offset() + last_member->Size(); + + // Compute max widths to align output + const auto offset_w = static_cast(::log10(last_member_struct_padding_offset)) + 1; + const auto size_w = static_cast(::log10(Size())) + 1; + const auto align_w = static_cast(::log10(Align())) + 1; + + auto print_struct_begin_line = [&](size_t align, size_t size, std::string struct_name) { + ss << "/* " << std::setw(offset_w) << " " + << "align(" << std::setw(align_w) << align << ") size(" << std::setw(size_w) << size + << ") */ struct " << struct_name << " {\n"; + }; + + auto print_struct_end_line = [&]() { + ss << "/* " << std::setw(offset_w + size_w + align_w) << " " + << "*/ };"; + }; + + auto print_member_line = [&](size_t offset, size_t align, size_t size, std::string s) { + ss << "/* offset(" << std::setw(offset_w) << offset << ") align(" << std::setw(align_w) + << align << ") size(" << std::setw(size_w) << size << ") */ " << s << ";\n"; + }; + + print_struct_begin_line(Align(), Size(), UnwrapRef()->FriendlyName(symbols)); + + for (size_t i = 0; i < Members().Length(); ++i) { + auto* const m = Members()[i]; + + // Output field alignment padding, if any + auto* const prev_member = (i == 0) ? nullptr : Members()[i - 1]; + if (prev_member) { + uint32_t padding = m->Offset() - (prev_member->Offset() + prev_member->Size()); + if (padding > 0) { + size_t padding_offset = m->Offset() - padding; + print_member_line(padding_offset, 1, padding, + "// -- implicit field alignment padding --"); + } + } + + // Output member + std::string member_name = member_name_of(m); + print_member_line(m->Offset(), m->Align(), m->Size(), + member_name + " : " + m->Type()->UnwrapRef()->FriendlyName(symbols)); + } + + // Output struct size padding, if any + uint32_t struct_padding = Size() - last_member_struct_padding_offset; + if (struct_padding > 0) { + print_member_line(last_member_struct_padding_offset, 1, struct_padding, + "// -- implicit struct size padding --"); + } + + print_struct_end_line(); + + return ss.str(); +} + +StructMemberBase::StructMemberBase(tint::Source source, + Symbol name, + const type::Type* type, + uint32_t index, + uint32_t offset, + uint32_t align, + uint32_t size, + std::optional location) + : source_(source), + name_(name), + type_(type), + index_(index), + offset_(offset), + align_(align), + size_(size), + location_(location) {} + +StructMemberBase::~StructMemberBase() = default; + +} // namespace tint::type diff --git a/src/tint/type/struct.h b/src/tint/type/struct.h new file mode 100644 index 0000000000..e6e2bbd1f7 --- /dev/null +++ b/src/tint/type/struct.h @@ -0,0 +1,238 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_TYPE_STRUCT_H_ +#define SRC_TINT_TYPE_STRUCT_H_ + +#include + +#include +#include +#include + +#include "src/tint/ast/address_space.h" +#include "src/tint/symbol.h" +#include "src/tint/type/node.h" +#include "src/tint/type/type.h" +#include "src/tint/utils/vector.h" + +// Forward declarations +namespace tint::type { +class StructMemberBase; +} // namespace tint::type + +namespace tint::type { + +/// Metadata to capture how a structure is used in a shader module. +enum class PipelineStageUsage { + kVertexInput, + kVertexOutput, + kFragmentInput, + kFragmentOutput, + kComputeInput, + kComputeOutput, +}; + +/// StructBase holds the Type information for structures. +class StructBase : public Castable { + public: + /// Constructor + /// @param source the source of the structure + /// @param name the name of the structure + /// @param members the structure members + /// @param align the byte alignment of the structure + /// @param size the byte size of the structure + /// @param size_no_padding size of the members without the end of structure + /// alignment padding + StructBase(tint::Source source, + Symbol name, + utils::VectorRef members, + uint32_t align, + uint32_t size, + uint32_t size_no_padding); + + /// Destructor + ~StructBase() override; + + /// @returns a hash of the type. + size_t Hash() const override; + + /// @param other the other type to compare against + /// @returns true if the this type is equal to the given type + bool Equals(const Type& other) const override; + + /// @returns the source of the structure + tint::Source Source() const { return source_; } + + /// @returns the name of the structure + Symbol Name() const { return name_; } + + /// @returns the members of the structure + utils::VectorRef Members() const { return members_; } + + /// @param name the member name to look for + /// @returns the member with the given name, or nullptr if it was not found. + const StructMemberBase* FindMember(Symbol name) const; + + /// @returns the byte alignment of the structure + /// @note this may differ from the alignment of a structure member of this + /// structure type, if the member is annotated with the `@align(n)` + /// attribute. + uint32_t Align() const override; + + /// @returns the byte size of the structure + /// @note this may differ from the size of a structure member of this + /// structure type, if the member is annotated with the `@size(n)` + /// attribute. + uint32_t Size() const override; + + /// @returns the byte size of the members without the end of structure + /// alignment padding + uint32_t SizeNoPadding() const { return size_no_padding_; } + + /// Adds the AddressSpace usage to the structure. + /// @param usage the storage usage + void AddUsage(ast::AddressSpace usage) { address_space_usage_.emplace(usage); } + + /// @returns the set of address space uses of this structure + const std::unordered_set& AddressSpaceUsage() const { + return address_space_usage_; + } + + /// @param usage the ast::AddressSpace usage type to query + /// @returns true iff this structure has been used as the given address space + bool UsedAs(ast::AddressSpace usage) const { return address_space_usage_.count(usage) > 0; } + + /// @returns true iff this structure has been used by address space that's + /// host-shareable. + bool IsHostShareable() const { + for (auto sc : address_space_usage_) { + if (ast::IsHostShareable(sc)) { + return true; + } + } + return false; + } + + /// Adds the pipeline stage usage to the structure. + /// @param usage the storage usage + void AddUsage(PipelineStageUsage usage) { pipeline_stage_uses_.emplace(usage); } + + /// @returns the set of entry point uses of this structure + const std::unordered_set& PipelineStageUses() const { + return pipeline_stage_uses_; + } + + /// @param symbols the program's symbol table + /// @returns the name for this type that closely resembles how it would be + /// declared in WGSL. + std::string FriendlyName(const SymbolTable& symbols) const override; + + /// @param symbols the program's symbol table + /// @returns a multiline string that describes the layout of this struct, + /// including size and alignment information. + std::string Layout(const tint::SymbolTable& symbols) const; + + /// @param concrete the conversion-rank ordered concrete versions of this abstract structure. + void SetConcreteTypes(utils::VectorRef concrete) { + concrete_types_ = concrete; + } + + /// @returns the conversion-rank ordered concrete versions of this abstract structure, or an + /// empty vector if this structure is not abstract. + /// @note only structures returned by builtins may be abstract (e.g. modf, frexp) + utils::VectorRef ConcreteTypes() const { return concrete_types_; } + + private: + const tint::Source source_; + const Symbol name_; + const utils::Vector members_; + const uint32_t align_; + const uint32_t size_; + const uint32_t size_no_padding_; + std::unordered_set address_space_usage_; + std::unordered_set pipeline_stage_uses_; + utils::Vector concrete_types_; +}; + +/// StructMemberBase holds the type information for structure members. +class StructMemberBase : public Castable { + public: + /// Constructor + /// @param source the source of the struct member + /// @param name the name of the structure member + /// @param type the type of the member + /// @param index the index of the member in the structure + /// @param offset the byte offset from the base of the structure + /// @param align the byte alignment of the member + /// @param size the byte size of the member + /// @param location the location attribute, if present + StructMemberBase(tint::Source source, + Symbol name, + const type::Type* type, + uint32_t index, + uint32_t offset, + uint32_t align, + uint32_t size, + std::optional location); + + /// Destructor + ~StructMemberBase() override; + + /// @returns the source the struct member + const tint::Source& Source() const { return source_; } + + /// @returns the name of the structure member + Symbol Name() const { return name_; } + + /// Sets the owning structure to `s` + /// @param s the new structure owner + void SetStruct(const type::StructBase* s) { struct_ = s; } + + /// @returns the structure that owns this member + const type::StructBase* Struct() const { return struct_; } + + /// @returns the type of the member + const type::Type* Type() const { return type_; } + + /// @returns the member index + uint32_t Index() const { return index_; } + + /// @returns byte offset from base of structure + uint32_t Offset() const { return offset_; } + + /// @returns the alignment of the member in bytes + uint32_t Align() const { return align_; } + + /// @returns byte size + uint32_t Size() const { return size_; } + + /// @returns the location, if set + std::optional Location() const { return location_; } + + private: + const tint::Source source_; + const Symbol name_; + const type::StructBase* struct_; + const type::Type* type_; + const uint32_t index_; + const uint32_t offset_; + const uint32_t align_; + const uint32_t size_; + const std::optional location_; +}; + +} // namespace tint::type + +#endif // SRC_TINT_TYPE_STRUCT_H_ diff --git a/src/tint/type/struct_test.cc b/src/tint/type/struct_test.cc new file mode 100644 index 0000000000..e8992819c4 --- /dev/null +++ b/src/tint/type/struct_test.cc @@ -0,0 +1,219 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/type/struct.h" +#include "src/tint/type/test_helper.h" +#include "src/tint/type/texture.h" + +namespace tint::type { +namespace { + +using namespace tint::number_suffixes; // NOLINT +using TypeStructTest = TestHelper; + +TEST_F(TypeStructTest, Creation) { + auto name = Sym("S"); + auto* s = create(Source{}, name, utils::Empty, 4u /* align */, 8u /* size */, + 16u /* size_no_padding */); + EXPECT_EQ(s->Align(), 4u); + EXPECT_EQ(s->Size(), 8u); + EXPECT_EQ(s->SizeNoPadding(), 16u); +} + +TEST_F(TypeStructTest, Hash) { + auto* a = create(Source{}, Sym("a"), utils::Empty, 4u /* align */, + 4u /* size */, 4u /* size_no_padding */); + auto* b = create(Source{}, Sym("b"), utils::Empty, 4u /* align */, + 4u /* size */, 4u /* size_no_padding */); + + EXPECT_NE(a->Hash(), b->Hash()); +} + +TEST_F(TypeStructTest, Equals) { + auto* a = create(Source{}, Sym("a"), utils::Empty, 4u /* align */, + 4u /* size */, 4u /* size_no_padding */); + auto* b = create(Source{}, Sym("b"), utils::Empty, 4u /* align */, + 4u /* size */, 4u /* size_no_padding */); + + EXPECT_TRUE(a->Equals(*a)); + EXPECT_FALSE(a->Equals(*b)); + EXPECT_FALSE(a->Equals(type::Void{})); +} + +TEST_F(TypeStructTest, FriendlyName) { + auto name = Sym("my_struct"); + auto* s = create(Source{}, name, utils::Empty, 4u /* align */, 4u /* size */, + 4u /* size_no_padding */); + EXPECT_EQ(s->FriendlyName(Symbols()), "my_struct"); +} + +TEST_F(TypeStructTest, Layout) { + auto* inner_st = // + Structure("Inner", utils::Vector{ + Member("a", ty.i32()), + Member("b", ty.u32()), + Member("c", ty.f32()), + Member("d", ty.vec3()), + Member("e", ty.mat4x2()), + }); + + auto* outer_st = Structure("Outer", utils::Vector{ + Member("inner", ty.type_name("Inner")), + Member("a", ty.i32()), + }); + + auto p = Build(); + ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); + + auto* sem_inner_st = p.Sem().Get(inner_st); + auto* sem_outer_st = p.Sem().Get(outer_st); + + EXPECT_EQ(sem_inner_st->Layout(p.Symbols()), + R"(/* align(16) size(64) */ struct Inner { +/* offset( 0) align( 4) size( 4) */ a : i32; +/* offset( 4) align( 4) size( 4) */ b : u32; +/* offset( 8) align( 4) size( 4) */ c : f32; +/* offset(12) align( 1) size( 4) */ // -- implicit field alignment padding --; +/* offset(16) align(16) size(12) */ d : vec3; +/* offset(28) align( 1) size( 4) */ // -- implicit field alignment padding --; +/* offset(32) align( 8) size(32) */ e : mat4x2; +/* */ };)"); + + EXPECT_EQ(sem_outer_st->Layout(p.Symbols()), + R"(/* align(16) size(80) */ struct Outer { +/* offset( 0) align(16) size(64) */ inner : Inner; +/* offset(64) align( 4) size( 4) */ a : i32; +/* offset(68) align( 1) size(12) */ // -- implicit struct size padding --; +/* */ };)"); +} + +TEST_F(TypeStructTest, Location) { + auto* st = Structure("st", utils::Vector{ + Member("a", ty.i32(), utils::Vector{Location(1_u)}), + Member("b", ty.u32()), + }); + + auto p = Build(); + ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); + + auto* sem = p.Sem().Get(st); + ASSERT_EQ(2u, sem->Members().Length()); + + EXPECT_TRUE(sem->Members()[0]->Location().has_value()); + EXPECT_EQ(sem->Members()[0]->Location().value(), 1u); + + EXPECT_FALSE(sem->Members()[1]->Location().has_value()); +} + +TEST_F(TypeStructTest, IsConstructable) { + auto* inner = // + Structure("Inner", utils::Vector{ + Member("a", ty.i32()), + Member("b", ty.u32()), + Member("c", ty.f32()), + Member("d", ty.vec3()), + Member("e", ty.mat4x2()), + }); + + auto* outer = Structure("Outer", utils::Vector{ + Member("inner", ty.type_name("Inner")), + Member("a", ty.i32()), + }); + + auto* outer_runtime_sized_array = + Structure("OuterRuntimeSizedArray", utils::Vector{ + Member("inner", ty.type_name("Inner")), + Member("a", ty.i32()), + Member("runtime_sized_array", ty.array()), + }); + auto p = Build(); + ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); + + auto* sem_inner = p.Sem().Get(inner); + auto* sem_outer = p.Sem().Get(outer); + auto* sem_outer_runtime_sized_array = p.Sem().Get(outer_runtime_sized_array); + + EXPECT_TRUE(sem_inner->IsConstructible()); + EXPECT_TRUE(sem_outer->IsConstructible()); + EXPECT_FALSE(sem_outer_runtime_sized_array->IsConstructible()); +} + +TEST_F(TypeStructTest, HasCreationFixedFootprint) { + auto* inner = // + Structure("Inner", utils::Vector{ + Member("a", ty.i32()), + Member("b", ty.u32()), + Member("c", ty.f32()), + Member("d", ty.vec3()), + Member("e", ty.mat4x2()), + Member("f", ty.array()), + }); + + auto* outer = Structure("Outer", utils::Vector{ + Member("inner", ty.type_name("Inner")), + }); + + auto* outer_with_runtime_sized_array = + Structure("OuterRuntimeSizedArray", utils::Vector{ + Member("inner", ty.type_name("Inner")), + Member("runtime_sized_array", ty.array()), + }); + + auto p = Build(); + ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); + + auto* sem_inner = p.Sem().Get(inner); + auto* sem_outer = p.Sem().Get(outer); + auto* sem_outer_with_runtime_sized_array = p.Sem().Get(outer_with_runtime_sized_array); + + EXPECT_TRUE(sem_inner->HasCreationFixedFootprint()); + EXPECT_TRUE(sem_outer->HasCreationFixedFootprint()); + EXPECT_FALSE(sem_outer_with_runtime_sized_array->HasCreationFixedFootprint()); +} + +TEST_F(TypeStructTest, HasFixedFootprint) { + auto* inner = // + Structure("Inner", utils::Vector{ + Member("a", ty.i32()), + Member("b", ty.u32()), + Member("c", ty.f32()), + Member("d", ty.vec3()), + Member("e", ty.mat4x2()), + Member("f", ty.array()), + }); + + auto* outer = Structure("Outer", utils::Vector{ + Member("inner", ty.type_name("Inner")), + }); + + auto* outer_with_runtime_sized_array = + Structure("OuterRuntimeSizedArray", utils::Vector{ + Member("inner", ty.type_name("Inner")), + Member("runtime_sized_array", ty.array()), + }); + + auto p = Build(); + ASSERT_TRUE(p.IsValid()) << p.Diagnostics().str(); + + auto* sem_inner = p.Sem().Get(inner); + auto* sem_outer = p.Sem().Get(outer); + auto* sem_outer_with_runtime_sized_array = p.Sem().Get(outer_with_runtime_sized_array); + + EXPECT_TRUE(sem_inner->HasFixedFootprint()); + EXPECT_TRUE(sem_outer->HasFixedFootprint()); + EXPECT_FALSE(sem_outer_with_runtime_sized_array->HasFixedFootprint()); +} + +} // namespace +} // namespace tint::type diff --git a/src/tint/type/type.cc b/src/tint/type/type.cc index 3250033b00..0fb90e399c 100644 --- a/src/tint/type/type.cc +++ b/src/tint/type/type.cc @@ -182,7 +182,7 @@ bool Type::HoldsAbstract() const { [&](const type::Vector* v) { return v->type()->HoldsAbstract(); }, [&](const type::Matrix* m) { return m->type()->HoldsAbstract(); }, [&](const sem::Array* a) { return a->ElemType()->HoldsAbstract(); }, - [&](const sem::Struct* s) { + [&](const type::StructBase* s) { for (auto* m : s->Members()) { if (m->Type()->HoldsAbstract()) { return true; @@ -240,7 +240,7 @@ uint32_t Type::ConversionRank(const Type* from, const Type* to) { } return kNoConversion; }, - [&](const sem::Struct* from_str) { + [&](const type::StructBase* from_str) { auto concrete_tys = from_str->ConcreteTypes(); for (size_t i = 0; i < concrete_tys.Length(); i++) { if (concrete_tys[i] == to) { diff --git a/src/tint/type/type_manager.h b/src/tint/type/type_manager.h index c6319e5b9c..49b420a0d8 100644 --- a/src/tint/type/type_manager.h +++ b/src/tint/type/type_manager.h @@ -91,7 +91,7 @@ class TypeManager final { /// pointer is returned. template || - traits::IsTypeOrDerived>, + traits::IsTypeOrDerived>, typename... ARGS> TYPE* GetNode(ARGS&&... args) { return nodes_.Get(std::forward(args)...); @@ -119,8 +119,8 @@ struct hash { size_t operator()(const tint::type::Node& type) const { if (const auto* ac = type.As()) { return ac->Hash(); - } else if (type.Is()) { - return tint::TypeInfo::Of().full_hashcode; + } else if (type.Is()) { + return tint::TypeInfo::Of().full_hashcode; } TINT_ASSERT(Type, false && "Unreachable"); return 0; @@ -139,7 +139,7 @@ struct equal_to { return ac->Equals(*bc); } return false; - } else if (a.Is()) { + } else if (a.Is()) { return &a == &b; } TINT_ASSERT(Type, false && "Unreachable"); diff --git a/src/tint/type/type_test.cc b/src/tint/type/type_test.cc index cbb2e2a380..c061e40961 100644 --- a/src/tint/type/type_test.cc +++ b/src/tint/type/type_test.cc @@ -45,60 +45,54 @@ struct TypeTest : public TestHelper { const type::Matrix* mat4x3_af = create(vec3_af, 4u); const type::Reference* ref_u32 = create(u32, ast::AddressSpace::kPrivate, ast::Access::kReadWrite); - const sem::Struct* str_f32 = create(nullptr, - Source{}, - Sym("str_f32"), - utils::Vector{ - create( - /* declaration */ nullptr, - /* source */ Source{}, - /* name */ Sym("x"), - /* type */ f32, - /* index */ 0u, - /* offset */ 0u, - /* align */ 4u, - /* size */ 4u, - /* location */ std::nullopt), - }, - /* align*/ 4u, - /* size*/ 4u, - /* size_no_padding*/ 4u); - const sem::Struct* str_f16 = create(nullptr, - Source{}, - Sym("str_f16"), - utils::Vector{ - create( - /* declaration */ nullptr, - /* source */ Source{}, - /* name */ Sym("x"), - /* type */ f16, - /* index */ 0u, - /* offset */ 0u, - /* align */ 4u, - /* size */ 4u, - /* location */ std::nullopt), - }, - /* align*/ 4u, - /* size*/ 4u, - /* size_no_padding*/ 4u); - sem::Struct* str_af = create(nullptr, - Source{}, - Sym("str_af"), - utils::Vector{ - create( - /* declaration */ nullptr, - /* source */ Source{}, - /* name */ Sym("x"), - /* type */ af, - /* index */ 0u, - /* offset */ 0u, - /* align */ 4u, - /* size */ 4u, - /* location */ std::nullopt), - }, - /* align*/ 4u, - /* size*/ 4u, - /* size_no_padding*/ 4u); + const type::StructBase* str_f32 = create(Source{}, + Sym("str_f32"), + utils::Vector{ + create( + /* source */ Source{}, + /* name */ Sym("x"), + /* type */ f32, + /* index */ 0u, + /* offset */ 0u, + /* align */ 4u, + /* size */ 4u, + /* location */ std::nullopt), + }, + /* align*/ 4u, + /* size*/ 4u, + /* size_no_padding*/ 4u); + const type::StructBase* str_f16 = create(Source{}, + Sym("str_f16"), + utils::Vector{ + create( + /* source */ Source{}, + /* name */ Sym("x"), + /* type */ f16, + /* index */ 0u, + /* offset */ 0u, + /* align */ 4u, + /* size */ 4u, + /* location */ std::nullopt), + }, + /* align*/ 4u, + /* size*/ 4u, + /* size_no_padding*/ 4u); + type::StructBase* str_af = create(Source{}, + Sym("str_af"), + utils::Vector{ + create( + /* source */ Source{}, + /* name */ Sym("x"), + /* type */ af, + /* index */ 0u, + /* offset */ 0u, + /* align */ 4u, + /* size */ 4u, + /* location */ std::nullopt), + }, + /* align*/ 4u, + /* size*/ 4u, + /* size_no_padding*/ 4u); const sem::Array* arr_i32 = create( /* element */ i32, /* count */ create(5u), diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index f783ad7b70..bccd7bf1e5 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -4155,16 +4155,16 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) { } auto loc = mem->Location().value(); - if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) { + if (pipeline_stage_uses.count(type::PipelineStageUsage::kVertexInput)) { post += " : TEXCOORD" + std::to_string(loc); } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kVertexOutput)) { + type::PipelineStageUsage::kVertexOutput)) { post += " : TEXCOORD" + std::to_string(loc); } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kFragmentInput)) { + type::PipelineStageUsage::kFragmentInput)) { post += " : TEXCOORD" + std::to_string(loc); } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kFragmentOutput)) { + type::PipelineStageUsage::kFragmentOutput)) { post += " : SV_Target" + std::to_string(loc); } else { TINT_ICE(Writer, diagnostics_) << "invalid use of location attribute"; diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 3a0c6705af..9f5b9d9abd 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -2841,16 +2841,16 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) { } uint32_t loc = mem->Location().value(); - if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) { + if (pipeline_stage_uses.count(type::PipelineStageUsage::kVertexInput)) { out << " [[attribute(" + std::to_string(loc) + ")]]"; } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kVertexOutput)) { + type::PipelineStageUsage::kVertexOutput)) { out << " [[user(locn" + std::to_string(loc) + ")]]"; } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kFragmentInput)) { + type::PipelineStageUsage::kFragmentInput)) { out << " [[user(locn" + std::to_string(loc) + ")]]"; } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kFragmentOutput)) { + type::PipelineStageUsage::kFragmentOutput)) { out << " [[color(" + std::to_string(loc) + ")]]"; } else { TINT_ICE(Writer, diagnostics_) << "invalid use of location decoration";