TypeDeterminer: Resolve swizzles

Have the TD resolve swizzles down to indices, erroring out if they're not valid.

Resolving these at TD time removes swizzle parsing in the HLSL writer, and is generally useful information.

If we don't sanitize in the TD, we can end up trying to construct a resulting vector of an invalid size (> 4) triggering an assert in the type::Vector constructor.

Fixed: chromium:1180634
Bug: tint:79
Change-Id: If1282c933d65eb02d26a8dc7e190f27801ef9dc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42221
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2021-02-24 14:15:02 +00:00 committed by Commit Bot service account
parent 8b1851dbdc
commit 6d612ad478
7 changed files with 118 additions and 42 deletions

View File

@ -486,7 +486,10 @@ class ProgramBuilder {
/// @param expr the expression /// @param expr the expression
/// @return expr /// @return expr
ast::Expression* Expr(ast::Expression* expr) { return expr; } template <typename T>
traits::EnableIfIsType<T, ast::Expression>* Expr(T* expr) {
return expr;
}
/// @param name the identifier name /// @param name the identifier name
/// @return an ast::IdentifierExpression with the given name /// @return an ast::IdentifierExpression with the given name
@ -948,7 +951,7 @@ class ProgramBuilder {
/// @param idx the index argument for the array accessor expression /// @param idx the index argument for the array accessor expression
/// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx` /// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx`
template <typename OBJ, typename IDX> template <typename OBJ, typename IDX>
ast::Expression* MemberAccessor(OBJ&& obj, IDX&& idx) { ast::MemberAccessorExpression* MemberAccessor(OBJ&& obj, IDX&& idx) {
return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)), return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)),
Expr(std::forward<IDX>(idx))); Expr(std::forward<IDX>(idx)));
} }

View File

@ -15,6 +15,8 @@
#ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_ #ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
#define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_ #define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
#include <vector>
#include "src/semantic/expression.h" #include "src/semantic/expression.h"
namespace tint { namespace tint {
@ -29,17 +31,24 @@ class MemberAccessorExpression
/// @param declaration the AST node /// @param declaration the AST node
/// @param type the resolved type of the expression /// @param type the resolved type of the expression
/// @param statement the statement that owns this expression /// @param statement the statement that owns this expression
/// @param is_swizzle true if this member access is for a vector swizzle /// @param swizzle if this member access is for a vector swizzle, the swizzle
/// indices
MemberAccessorExpression(ast::Expression* declaration, MemberAccessorExpression(ast::Expression* declaration,
type::Type* type, type::Type* type,
Statement* statement, Statement* statement,
bool is_swizzle); std::vector<uint32_t> swizzle);
/// Destructor
~MemberAccessorExpression() override;
/// @return true if this member access is for a vector swizzle /// @return true if this member access is for a vector swizzle
bool IsSwizzle() const { return is_swizzle_; } bool IsSwizzle() const { return !swizzle_.empty(); }
/// @return the swizzle indices, if this is a vector swizzle
const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
private: private:
bool const is_swizzle_; std::vector<uint32_t> const swizzle_;
}; };
} // namespace semantic } // namespace semantic

View File

@ -19,11 +19,14 @@ TINT_INSTANTIATE_CLASS_ID(tint::semantic::MemberAccessorExpression);
namespace tint { namespace tint {
namespace semantic { namespace semantic {
MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration, MemberAccessorExpression::MemberAccessorExpression(
ast::Expression* declaration,
type::Type* type, type::Type* type,
Statement* statement, Statement* statement,
bool is_swizzle) std::vector<uint32_t> swizzle)
: Base(declaration, type, statement), is_swizzle_(is_swizzle) {} : Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
MemberAccessorExpression::~MemberAccessorExpression() = default;
} // namespace semantic } // namespace semantic
} // namespace tint } // namespace tint

View File

@ -84,6 +84,13 @@ class Source {
/// @param e the range end location /// @param e the range end location
inline Range(const Location& b, const Location& e) : begin(b), end(e) {} inline Range(const Location& b, const Location& e) : begin(b), end(e) {}
/// Return a column-shifted Range
/// @param n the number of characters to shift by
/// @returns a Range with a #begin and #end column shifted by `n`
inline Range operator+(size_t n) const {
return Range{{begin.line, begin.column + n}, {end.line, end.column + n}};
}
/// The location of the first character in the range. /// The location of the first character in the range.
Location begin; Location begin;
/// The location of one-past the last character in the range. /// The location of one-past the last character in the range.
@ -127,6 +134,13 @@ class Source {
return Source(Range{range.end}, file_path, file_content); return Source(Range{range.end}, file_path, file_content);
} }
/// Return a column-shifted Source
/// @param n the number of characters to shift by
/// @returns a Source with the range's columns shifted by `n`
inline Source operator+(size_t n) const {
return Source(range + n, file_path, file_content);
}
/// range is the span of text this source refers to in #file_path /// range is the span of text this source refers to in #file_path
Range range; Range range;
/// file is the optional file path this source refers to /// file is the optional file path this source refers to

View File

@ -752,7 +752,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
type::Type* ret = nullptr; type::Type* ret = nullptr;
bool is_swizzle = false; std::vector<uint32_t> swizzle;
if (auto* ty = data_type->As<type::Struct>()) { if (auto* ty = data_type->As<type::Struct>()) {
auto* strct = ty->impl(); auto* strct = ty->impl();
@ -777,9 +777,42 @@ bool TypeDeterminer::DetermineMemberAccessor(
ret = builder_->create<type::Pointer>(ret, ptr->storage_class()); ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
} }
} else if (auto* vec = data_type->As<type::Vector>()) { } else if (auto* vec = data_type->As<type::Vector>()) {
is_swizzle = true; std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
auto size = str.size();
swizzle.reserve(str.size());
for (auto c : str) {
switch (c) {
case 'x':
case 'r':
swizzle.emplace_back(0);
break;
case 'y':
case 'g':
swizzle.emplace_back(1);
break;
case 'z':
case 'b':
swizzle.emplace_back(2);
break;
case 'w':
case 'a':
swizzle.emplace_back(3);
break;
default:
diagnostics_.add_error(
"invalid vector swizzle character",
expr->member()->source().Begin() + swizzle.size());
return false;
}
}
if (size < 1 || size > 4) {
diagnostics_.add_error("invalid vector swizzle size",
expr->member()->source());
return false;
}
auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size();
if (size == 1) { if (size == 1) {
// A single element swizzle is just the type of the vector. // A single element swizzle is just the type of the vector.
ret = vec->type(); ret = vec->type();
@ -788,15 +821,15 @@ bool TypeDeterminer::DetermineMemberAccessor(
ret = builder_->create<type::Pointer>(ret, ptr->storage_class()); ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
} }
} else { } else {
// The vector will have a number of components equal to the length of the // The vector will have a number of components equal to the length of
// swizzle. This assumes the validator will check that the swizzle // the swizzle. This assumes the validator will check that the swizzle
// is correct. // is correct.
ret = builder_->create<type::Vector>(vec->type(), ret = builder_->create<type::Vector>(vec->type(),
static_cast<uint32_t>(size)); static_cast<uint32_t>(size));
} }
} else { } else {
diagnostics_.add_error( diagnostics_.add_error(
"v-0007: invalid use of member accessor on a non-vector/non-struct " + "invalid use of member accessor on a non-vector/non-struct " +
data_type->type_name(), data_type->type_name(),
expr->source()); expr->source());
return false; return false;
@ -804,7 +837,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
builder_->Sem().Add(expr, builder_->Sem().Add(expr,
builder_->create<semantic::MemberAccessorExpression>( builder_->create<semantic::MemberAccessorExpression>(
expr, ret, current_statement_, is_swizzle)); expr, ret, current_statement_, std::move(swizzle)));
SetType(expr, ret); SetType(expr, ret);
return true; return true;

View File

@ -55,6 +55,7 @@
#include "src/semantic/call.h" #include "src/semantic/call.h"
#include "src/semantic/expression.h" #include "src/semantic/expression.h"
#include "src/semantic/function.h" #include "src/semantic/function.h"
#include "src/semantic/member_accessor_expression.h"
#include "src/semantic/statement.h" #include "src/semantic/statement.h"
#include "src/semantic/variable.h" #include "src/semantic/variable.h"
#include "src/type/access_control_type.h" #include "src/type/access_control_type.h"
@ -75,6 +76,7 @@
#include "src/type/u32_type.h" #include "src/type/u32_type.h"
#include "src/type/vector_type.h" #include "src/type/vector_type.h"
using ::testing::ElementsAre;
using ::testing::HasSubstr; using ::testing::HasSubstr;
namespace tint { namespace tint {
@ -1005,7 +1007,7 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct_Alias) {
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone); Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* mem = MemberAccessor("my_vec", "xy"); auto* mem = MemberAccessor("my_vec", "xzyw");
WrapInFunction(mem); WrapInFunction(mem);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
@ -1013,13 +1015,14 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
ASSERT_NE(TypeOf(mem), nullptr); ASSERT_NE(TypeOf(mem), nullptr);
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>()); ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>()); EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u); EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
} }
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone); Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* mem = MemberAccessor("my_vec", "x"); auto* mem = MemberAccessor("my_vec", "b");
WrapInFunction(mem); WrapInFunction(mem);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
@ -1029,6 +1032,34 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
auto* ptr = TypeOf(mem)->As<type::Pointer>(); auto* ptr = TypeOf(mem)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::F32>()); ASSERT_TRUE(ptr->type()->Is<type::F32>());
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadChar) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* ident = create<ast::IdentifierExpression>(
Source{{Source::Location{3, 3}, Source::Location{3, 7}}},
Symbols().Register("xyqz"));
auto* mem = MemberAccessor("my_vec", ident);
WrapInFunction(mem);
EXPECT_FALSE(td()->Determine());
EXPECT_EQ(td()->error(), "3:5 error: invalid vector swizzle character");
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadLength) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* ident = create<ast::IdentifierExpression>(
Source{{Source::Location{3, 3}, Source::Location{3, 8}}},
Symbols().Register("zzzzz"));
auto* mem = MemberAccessor("my_vec", ident);
WrapInFunction(mem);
EXPECT_FALSE(td()->Determine());
EXPECT_EQ(td()->error(), "3:3 error: invalid vector swizzle size");
} }
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) { TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {

View File

@ -91,22 +91,6 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
stmts->last()->Is<ast::FallthroughStatement>(); stmts->last()->Is<ast::FallthroughStatement>();
} }
uint32_t convert_swizzle_to_index(const std::string& swizzle) {
if (swizzle == "r" || swizzle == "x") {
return 0;
}
if (swizzle == "g" || swizzle == "y") {
return 1;
}
if (swizzle == "b" || swizzle == "z") {
return 2;
}
if (swizzle == "a" || swizzle == "w") {
return 3;
}
return 0;
}
const char* image_format_to_rwtexture_type(type::ImageFormat image_format) { const char* image_format_to_rwtexture_type(type::ImageFormat image_format) {
switch (image_format) { switch (image_format) {
case type::ImageFormat::kRgba8Unorm: case type::ImageFormat::kRgba8Unorm:
@ -2084,11 +2068,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
out << str_member->offset(); out << str_member->offset();
} else if (res_type->Is<type::Vector>()) { } else if (res_type->Is<type::Vector>()) {
auto swizzle = builder_.Sem().Get(mem)->Swizzle();
// TODO(dsinclair): Swizzle stuff // TODO(dsinclair): Swizzle stuff
// //
// This must be a single element swizzle if we've got a vector at this // This must be a single element swizzle if we've got a vector at this
// point. // point.
if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) { if (swizzle.size() != 1) {
diagnostics_.add_error( diagnostics_.add_error(
"Encountered multi-element swizzle when should have only one " "Encountered multi-element swizzle when should have only one "
"level"); "level");
@ -2098,10 +2084,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32) // TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
// so this is assuming 4. This will need to be fixed when we get f16 or // so this is assuming 4. This will need to be fixed when we get f16 or
// f64 types. // f64 types.
out << "(4 * " out << "(4 * " << swizzle[0] << ")";
<< convert_swizzle_to_index(
builder_.Symbols().NameFor(mem->member()->symbol()))
<< ")";
} else { } else {
diagnostics_.add_error("Invalid result type for member accessor: " + diagnostics_.add_error("Invalid result type for member accessor: " +
res_type->type_name()); res_type->type_name());