diff --git a/src/program_builder.h b/src/program_builder.h index 24a018640a..d4b8ff6a71 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -486,7 +486,10 @@ class ProgramBuilder { /// @param expr the expression /// @return expr - ast::Expression* Expr(ast::Expression* expr) { return expr; } + template + traits::EnableIfIsType* Expr(T* expr) { + return expr; + } /// @param name the identifier name /// @return an ast::IdentifierExpression with the given name @@ -948,7 +951,7 @@ class ProgramBuilder { /// @param idx the index argument for the array accessor expression /// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx` template - ast::Expression* MemberAccessor(OBJ&& obj, IDX&& idx) { + ast::MemberAccessorExpression* MemberAccessor(OBJ&& obj, IDX&& idx) { return create(Expr(std::forward(obj)), Expr(std::forward(idx))); } diff --git a/src/semantic/member_accessor_expression.h b/src/semantic/member_accessor_expression.h index 4a7d7f60ed..71268b4801 100644 --- a/src/semantic/member_accessor_expression.h +++ b/src/semantic/member_accessor_expression.h @@ -15,6 +15,8 @@ #ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_ #define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_ +#include + #include "src/semantic/expression.h" namespace tint { @@ -29,17 +31,24 @@ class MemberAccessorExpression /// @param declaration the AST node /// @param type the resolved type of the expression /// @param statement the statement that owns this expression - /// @param is_swizzle true if this member access is for a vector swizzle + /// @param swizzle if this member access is for a vector swizzle, the swizzle + /// indices MemberAccessorExpression(ast::Expression* declaration, type::Type* type, Statement* statement, - bool is_swizzle); + std::vector swizzle); + + /// Destructor + ~MemberAccessorExpression() override; /// @return true if this member access is for a vector swizzle - bool IsSwizzle() const { return is_swizzle_; } + bool IsSwizzle() const { return !swizzle_.empty(); } + + /// @return the swizzle indices, if this is a vector swizzle + const std::vector& Swizzle() const { return swizzle_; } private: - bool const is_swizzle_; + std::vector const swizzle_; }; } // namespace semantic diff --git a/src/semantic/sem_member_accessor_expression.cc b/src/semantic/sem_member_accessor_expression.cc index 470f05907b..fdbaeb384b 100644 --- a/src/semantic/sem_member_accessor_expression.cc +++ b/src/semantic/sem_member_accessor_expression.cc @@ -19,11 +19,14 @@ TINT_INSTANTIATE_CLASS_ID(tint::semantic::MemberAccessorExpression); namespace tint { namespace semantic { -MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration, - type::Type* type, - Statement* statement, - bool is_swizzle) - : Base(declaration, type, statement), is_swizzle_(is_swizzle) {} +MemberAccessorExpression::MemberAccessorExpression( + ast::Expression* declaration, + type::Type* type, + Statement* statement, + std::vector swizzle) + : Base(declaration, type, statement), swizzle_(std::move(swizzle)) {} + +MemberAccessorExpression::~MemberAccessorExpression() = default; } // namespace semantic } // namespace tint diff --git a/src/source.h b/src/source.h index 9b6e74ce16..381d7371d0 100644 --- a/src/source.h +++ b/src/source.h @@ -84,6 +84,13 @@ class Source { /// @param e the range end location inline Range(const Location& b, const Location& e) : begin(b), end(e) {} + /// Return a column-shifted Range + /// @param n the number of characters to shift by + /// @returns a Range with a #begin and #end column shifted by `n` + inline Range operator+(size_t n) const { + return Range{{begin.line, begin.column + n}, {end.line, end.column + n}}; + } + /// The location of the first character in the range. Location begin; /// The location of one-past the last character in the range. @@ -127,6 +134,13 @@ class Source { return Source(Range{range.end}, file_path, file_content); } + /// Return a column-shifted Source + /// @param n the number of characters to shift by + /// @returns a Source with the range's columns shifted by `n` + inline Source operator+(size_t n) const { + return Source(range + n, file_path, file_content); + } + /// range is the span of text this source refers to in #file_path Range range; /// file is the optional file path this source refers to diff --git a/src/type_determiner.cc b/src/type_determiner.cc index fd334c255a..066cf0a8fd 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -752,7 +752,7 @@ bool TypeDeterminer::DetermineMemberAccessor( auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); type::Type* ret = nullptr; - bool is_swizzle = false; + std::vector swizzle; if (auto* ty = data_type->As()) { auto* strct = ty->impl(); @@ -777,9 +777,42 @@ bool TypeDeterminer::DetermineMemberAccessor( ret = builder_->create(ret, ptr->storage_class()); } } else if (auto* vec = data_type->As()) { - is_swizzle = true; + std::string str = builder_->Symbols().NameFor(expr->member()->symbol()); + auto size = str.size(); + swizzle.reserve(str.size()); + + for (auto c : str) { + switch (c) { + case 'x': + case 'r': + swizzle.emplace_back(0); + break; + case 'y': + case 'g': + swizzle.emplace_back(1); + break; + case 'z': + case 'b': + swizzle.emplace_back(2); + break; + case 'w': + case 'a': + swizzle.emplace_back(3); + break; + default: + diagnostics_.add_error( + "invalid vector swizzle character", + expr->member()->source().Begin() + swizzle.size()); + return false; + } + } + + if (size < 1 || size > 4) { + diagnostics_.add_error("invalid vector swizzle size", + expr->member()->source()); + return false; + } - auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size(); if (size == 1) { // A single element swizzle is just the type of the vector. ret = vec->type(); @@ -788,15 +821,15 @@ bool TypeDeterminer::DetermineMemberAccessor( ret = builder_->create(ret, ptr->storage_class()); } } else { - // The vector will have a number of components equal to the length of the - // swizzle. This assumes the validator will check that the swizzle + // The vector will have a number of components equal to the length of + // the swizzle. This assumes the validator will check that the swizzle // is correct. ret = builder_->create(vec->type(), static_cast(size)); } } else { diagnostics_.add_error( - "v-0007: invalid use of member accessor on a non-vector/non-struct " + + "invalid use of member accessor on a non-vector/non-struct " + data_type->type_name(), expr->source()); return false; @@ -804,7 +837,7 @@ bool TypeDeterminer::DetermineMemberAccessor( builder_->Sem().Add(expr, builder_->create( - expr, ret, current_statement_, is_swizzle)); + expr, ret, current_statement_, std::move(swizzle))); SetType(expr, ret); return true; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 801260d4f9..6f39628a8d 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -55,6 +55,7 @@ #include "src/semantic/call.h" #include "src/semantic/expression.h" #include "src/semantic/function.h" +#include "src/semantic/member_accessor_expression.h" #include "src/semantic/statement.h" #include "src/semantic/variable.h" #include "src/type/access_control_type.h" @@ -75,6 +76,7 @@ #include "src/type/u32_type.h" #include "src/type/vector_type.h" +using ::testing::ElementsAre; using ::testing::HasSubstr; namespace tint { @@ -1005,7 +1007,7 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct_Alias) { TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { Global("my_vec", ty.vec3(), ast::StorageClass::kNone); - auto* mem = MemberAccessor("my_vec", "xy"); + auto* mem = MemberAccessor("my_vec", "xzyw"); WrapInFunction(mem); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -1013,13 +1015,14 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { ASSERT_NE(TypeOf(mem), nullptr); ASSERT_TRUE(TypeOf(mem)->Is()); EXPECT_TRUE(TypeOf(mem)->As()->type()->Is()); - EXPECT_EQ(TypeOf(mem)->As()->size(), 2u); + EXPECT_EQ(TypeOf(mem)->As()->size(), 4u); + EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3)); } TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { Global("my_vec", ty.vec3(), ast::StorageClass::kNone); - auto* mem = MemberAccessor("my_vec", "x"); + auto* mem = MemberAccessor("my_vec", "b"); WrapInFunction(mem); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -1029,6 +1032,34 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { auto* ptr = TypeOf(mem)->As(); ASSERT_TRUE(ptr->type()->Is()); + EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2)); +} + +TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadChar) { + Global("my_vec", ty.vec3(), ast::StorageClass::kNone); + + auto* ident = create( + Source{{Source::Location{3, 3}, Source::Location{3, 7}}}, + Symbols().Register("xyqz")); + + auto* mem = MemberAccessor("my_vec", ident); + WrapInFunction(mem); + + EXPECT_FALSE(td()->Determine()); + EXPECT_EQ(td()->error(), "3:5 error: invalid vector swizzle character"); +} + +TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadLength) { + Global("my_vec", ty.vec3(), ast::StorageClass::kNone); + + auto* ident = create( + Source{{Source::Location{3, 3}, Source::Location{3, 8}}}, + Symbols().Register("zzzzz")); + auto* mem = MemberAccessor("my_vec", ident); + WrapInFunction(mem); + + EXPECT_FALSE(td()->Determine()); + EXPECT_EQ(td()->error(), "3:3 error: invalid vector swizzle size"); } TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) { diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 0a732f3a01..7fb193dde9 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -91,22 +91,6 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { stmts->last()->Is(); } -uint32_t convert_swizzle_to_index(const std::string& swizzle) { - if (swizzle == "r" || swizzle == "x") { - return 0; - } - if (swizzle == "g" || swizzle == "y") { - return 1; - } - if (swizzle == "b" || swizzle == "z") { - return 2; - } - if (swizzle == "a" || swizzle == "w") { - return 3; - } - return 0; -} - const char* image_format_to_rwtexture_type(type::ImageFormat image_format) { switch (image_format) { case type::ImageFormat::kRgba8Unorm: @@ -2084,11 +2068,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( out << str_member->offset(); } else if (res_type->Is()) { + auto swizzle = builder_.Sem().Get(mem)->Swizzle(); + // TODO(dsinclair): Swizzle stuff // // This must be a single element swizzle if we've got a vector at this // point. - if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) { + if (swizzle.size() != 1) { diagnostics_.add_error( "Encountered multi-element swizzle when should have only one " "level"); @@ -2098,10 +2084,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( // TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32) // so this is assuming 4. This will need to be fixed when we get f16 or // f64 types. - out << "(4 * " - << convert_swizzle_to_index( - builder_.Symbols().NameFor(mem->member()->symbol())) - << ")"; + out << "(4 * " << swizzle[0] << ")"; } else { diagnostics_.add_error("Invalid result type for member accessor: " + res_type->type_name());