TypeDeterminer: Resolve swizzles

Have the TD resolve swizzles down to indices, erroring out if they're not valid.

Resolving these at TD time removes swizzle parsing in the HLSL writer, and is generally useful information.

If we don't sanitize in the TD, we can end up trying to construct a resulting vector of an invalid size (> 4) triggering an assert in the type::Vector constructor.

Fixed: chromium:1180634
Bug: tint:79
Change-Id: If1282c933d65eb02d26a8dc7e190f27801ef9dc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42221
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2021-02-24 14:15:02 +00:00 committed by Commit Bot service account
parent 8b1851dbdc
commit 6d612ad478
7 changed files with 118 additions and 42 deletions

View File

@ -486,7 +486,10 @@ class ProgramBuilder {
/// @param expr the expression
/// @return expr
ast::Expression* Expr(ast::Expression* expr) { return expr; }
template <typename T>
traits::EnableIfIsType<T, ast::Expression>* Expr(T* expr) {
return expr;
}
/// @param name the identifier name
/// @return an ast::IdentifierExpression with the given name
@ -948,7 +951,7 @@ class ProgramBuilder {
/// @param idx the index argument for the array accessor expression
/// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx`
template <typename OBJ, typename IDX>
ast::Expression* MemberAccessor(OBJ&& obj, IDX&& idx) {
ast::MemberAccessorExpression* MemberAccessor(OBJ&& obj, IDX&& idx) {
return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)),
Expr(std::forward<IDX>(idx)));
}

View File

@ -15,6 +15,8 @@
#ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
#define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
#include <vector>
#include "src/semantic/expression.h"
namespace tint {
@ -29,17 +31,24 @@ class MemberAccessorExpression
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param statement the statement that owns this expression
/// @param is_swizzle true if this member access is for a vector swizzle
/// @param swizzle if this member access is for a vector swizzle, the swizzle
/// indices
MemberAccessorExpression(ast::Expression* declaration,
type::Type* type,
Statement* statement,
bool is_swizzle);
std::vector<uint32_t> swizzle);
/// Destructor
~MemberAccessorExpression() override;
/// @return true if this member access is for a vector swizzle
bool IsSwizzle() const { return is_swizzle_; }
bool IsSwizzle() const { return !swizzle_.empty(); }
/// @return the swizzle indices, if this is a vector swizzle
const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
private:
bool const is_swizzle_;
std::vector<uint32_t> const swizzle_;
};
} // namespace semantic

View File

@ -19,11 +19,14 @@ TINT_INSTANTIATE_CLASS_ID(tint::semantic::MemberAccessorExpression);
namespace tint {
namespace semantic {
MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration,
MemberAccessorExpression::MemberAccessorExpression(
ast::Expression* declaration,
type::Type* type,
Statement* statement,
bool is_swizzle)
: Base(declaration, type, statement), is_swizzle_(is_swizzle) {}
std::vector<uint32_t> swizzle)
: Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
MemberAccessorExpression::~MemberAccessorExpression() = default;
} // namespace semantic
} // namespace tint

View File

@ -84,6 +84,13 @@ class Source {
/// @param e the range end location
inline Range(const Location& b, const Location& e) : begin(b), end(e) {}
/// Return a column-shifted Range
/// @param n the number of characters to shift by
/// @returns a Range with a #begin and #end column shifted by `n`
inline Range operator+(size_t n) const {
return Range{{begin.line, begin.column + n}, {end.line, end.column + n}};
}
/// The location of the first character in the range.
Location begin;
/// The location of one-past the last character in the range.
@ -127,6 +134,13 @@ class Source {
return Source(Range{range.end}, file_path, file_content);
}
/// Return a column-shifted Source
/// @param n the number of characters to shift by
/// @returns a Source with the range's columns shifted by `n`
inline Source operator+(size_t n) const {
return Source(range + n, file_path, file_content);
}
/// range is the span of text this source refers to in #file_path
Range range;
/// file is the optional file path this source refers to

View File

@ -752,7 +752,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
type::Type* ret = nullptr;
bool is_swizzle = false;
std::vector<uint32_t> swizzle;
if (auto* ty = data_type->As<type::Struct>()) {
auto* strct = ty->impl();
@ -777,9 +777,42 @@ bool TypeDeterminer::DetermineMemberAccessor(
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
} else if (auto* vec = data_type->As<type::Vector>()) {
is_swizzle = true;
std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
auto size = str.size();
swizzle.reserve(str.size());
for (auto c : str) {
switch (c) {
case 'x':
case 'r':
swizzle.emplace_back(0);
break;
case 'y':
case 'g':
swizzle.emplace_back(1);
break;
case 'z':
case 'b':
swizzle.emplace_back(2);
break;
case 'w':
case 'a':
swizzle.emplace_back(3);
break;
default:
diagnostics_.add_error(
"invalid vector swizzle character",
expr->member()->source().Begin() + swizzle.size());
return false;
}
}
if (size < 1 || size > 4) {
diagnostics_.add_error("invalid vector swizzle size",
expr->member()->source());
return false;
}
auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size();
if (size == 1) {
// A single element swizzle is just the type of the vector.
ret = vec->type();
@ -788,15 +821,15 @@ bool TypeDeterminer::DetermineMemberAccessor(
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
} else {
// The vector will have a number of components equal to the length of the
// swizzle. This assumes the validator will check that the swizzle
// The vector will have a number of components equal to the length of
// the swizzle. This assumes the validator will check that the swizzle
// is correct.
ret = builder_->create<type::Vector>(vec->type(),
static_cast<uint32_t>(size));
}
} else {
diagnostics_.add_error(
"v-0007: invalid use of member accessor on a non-vector/non-struct " +
"invalid use of member accessor on a non-vector/non-struct " +
data_type->type_name(),
expr->source());
return false;
@ -804,7 +837,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
builder_->Sem().Add(expr,
builder_->create<semantic::MemberAccessorExpression>(
expr, ret, current_statement_, is_swizzle));
expr, ret, current_statement_, std::move(swizzle)));
SetType(expr, ret);
return true;

View File

@ -55,6 +55,7 @@
#include "src/semantic/call.h"
#include "src/semantic/expression.h"
#include "src/semantic/function.h"
#include "src/semantic/member_accessor_expression.h"
#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
#include "src/type/access_control_type.h"
@ -75,6 +76,7 @@
#include "src/type/u32_type.h"
#include "src/type/vector_type.h"
using ::testing::ElementsAre;
using ::testing::HasSubstr;
namespace tint {
@ -1005,7 +1007,7 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct_Alias) {
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* mem = MemberAccessor("my_vec", "xy");
auto* mem = MemberAccessor("my_vec", "xzyw");
WrapInFunction(mem);
EXPECT_TRUE(td()->Determine()) << td()->error();
@ -1013,13 +1015,14 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
ASSERT_NE(TypeOf(mem), nullptr);
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* mem = MemberAccessor("my_vec", "x");
auto* mem = MemberAccessor("my_vec", "b");
WrapInFunction(mem);
EXPECT_TRUE(td()->Determine()) << td()->error();
@ -1029,6 +1032,34 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
auto* ptr = TypeOf(mem)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::F32>());
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadChar) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* ident = create<ast::IdentifierExpression>(
Source{{Source::Location{3, 3}, Source::Location{3, 7}}},
Symbols().Register("xyqz"));
auto* mem = MemberAccessor("my_vec", ident);
WrapInFunction(mem);
EXPECT_FALSE(td()->Determine());
EXPECT_EQ(td()->error(), "3:5 error: invalid vector swizzle character");
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadLength) {
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* ident = create<ast::IdentifierExpression>(
Source{{Source::Location{3, 3}, Source::Location{3, 8}}},
Symbols().Register("zzzzz"));
auto* mem = MemberAccessor("my_vec", ident);
WrapInFunction(mem);
EXPECT_FALSE(td()->Determine());
EXPECT_EQ(td()->error(), "3:3 error: invalid vector swizzle size");
}
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {

View File

@ -91,22 +91,6 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
stmts->last()->Is<ast::FallthroughStatement>();
}
uint32_t convert_swizzle_to_index(const std::string& swizzle) {
if (swizzle == "r" || swizzle == "x") {
return 0;
}
if (swizzle == "g" || swizzle == "y") {
return 1;
}
if (swizzle == "b" || swizzle == "z") {
return 2;
}
if (swizzle == "a" || swizzle == "w") {
return 3;
}
return 0;
}
const char* image_format_to_rwtexture_type(type::ImageFormat image_format) {
switch (image_format) {
case type::ImageFormat::kRgba8Unorm:
@ -2084,11 +2068,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
out << str_member->offset();
} else if (res_type->Is<type::Vector>()) {
auto swizzle = builder_.Sem().Get(mem)->Swizzle();
// TODO(dsinclair): Swizzle stuff
//
// This must be a single element swizzle if we've got a vector at this
// point.
if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) {
if (swizzle.size() != 1) {
diagnostics_.add_error(
"Encountered multi-element swizzle when should have only one "
"level");
@ -2098,10 +2084,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
// so this is assuming 4. This will need to be fixed when we get f16 or
// f64 types.
out << "(4 * "
<< convert_swizzle_to_index(
builder_.Symbols().NameFor(mem->member()->symbol()))
<< ")";
out << "(4 * " << swizzle[0] << ")";
} else {
diagnostics_.add_error("Invalid result type for member accessor: " +
res_type->type_name());