TypeDeterminer: Resolve swizzles
Have the TD resolve swizzles down to indices, erroring out if they're not valid. Resolving these at TD time removes swizzle parsing in the HLSL writer, and is generally useful information. If we don't sanitize in the TD, we can end up trying to construct a resulting vector of an invalid size (> 4) triggering an assert in the type::Vector constructor. Fixed: chromium:1180634 Bug: tint:79 Change-Id: If1282c933d65eb02d26a8dc7e190f27801ef9dc5 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42221 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
8b1851dbdc
commit
6d612ad478
|
@ -486,7 +486,10 @@ class ProgramBuilder {
|
|||
|
||||
/// @param expr the expression
|
||||
/// @return expr
|
||||
ast::Expression* Expr(ast::Expression* expr) { return expr; }
|
||||
template <typename T>
|
||||
traits::EnableIfIsType<T, ast::Expression>* Expr(T* expr) {
|
||||
return expr;
|
||||
}
|
||||
|
||||
/// @param name the identifier name
|
||||
/// @return an ast::IdentifierExpression with the given name
|
||||
|
@ -948,7 +951,7 @@ class ProgramBuilder {
|
|||
/// @param idx the index argument for the array accessor expression
|
||||
/// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx`
|
||||
template <typename OBJ, typename IDX>
|
||||
ast::Expression* MemberAccessor(OBJ&& obj, IDX&& idx) {
|
||||
ast::MemberAccessorExpression* MemberAccessor(OBJ&& obj, IDX&& idx) {
|
||||
return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)),
|
||||
Expr(std::forward<IDX>(idx)));
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
#ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
|
||||
#define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "src/semantic/expression.h"
|
||||
|
||||
namespace tint {
|
||||
|
@ -29,17 +31,24 @@ class MemberAccessorExpression
|
|||
/// @param declaration the AST node
|
||||
/// @param type the resolved type of the expression
|
||||
/// @param statement the statement that owns this expression
|
||||
/// @param is_swizzle true if this member access is for a vector swizzle
|
||||
/// @param swizzle if this member access is for a vector swizzle, the swizzle
|
||||
/// indices
|
||||
MemberAccessorExpression(ast::Expression* declaration,
|
||||
type::Type* type,
|
||||
Statement* statement,
|
||||
bool is_swizzle);
|
||||
std::vector<uint32_t> swizzle);
|
||||
|
||||
/// Destructor
|
||||
~MemberAccessorExpression() override;
|
||||
|
||||
/// @return true if this member access is for a vector swizzle
|
||||
bool IsSwizzle() const { return is_swizzle_; }
|
||||
bool IsSwizzle() const { return !swizzle_.empty(); }
|
||||
|
||||
/// @return the swizzle indices, if this is a vector swizzle
|
||||
const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
|
||||
|
||||
private:
|
||||
bool const is_swizzle_;
|
||||
std::vector<uint32_t> const swizzle_;
|
||||
};
|
||||
|
||||
} // namespace semantic
|
||||
|
|
|
@ -19,11 +19,14 @@ TINT_INSTANTIATE_CLASS_ID(tint::semantic::MemberAccessorExpression);
|
|||
namespace tint {
|
||||
namespace semantic {
|
||||
|
||||
MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration,
|
||||
type::Type* type,
|
||||
Statement* statement,
|
||||
bool is_swizzle)
|
||||
: Base(declaration, type, statement), is_swizzle_(is_swizzle) {}
|
||||
MemberAccessorExpression::MemberAccessorExpression(
|
||||
ast::Expression* declaration,
|
||||
type::Type* type,
|
||||
Statement* statement,
|
||||
std::vector<uint32_t> swizzle)
|
||||
: Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
|
||||
|
||||
MemberAccessorExpression::~MemberAccessorExpression() = default;
|
||||
|
||||
} // namespace semantic
|
||||
} // namespace tint
|
||||
|
|
14
src/source.h
14
src/source.h
|
@ -84,6 +84,13 @@ class Source {
|
|||
/// @param e the range end location
|
||||
inline Range(const Location& b, const Location& e) : begin(b), end(e) {}
|
||||
|
||||
/// Return a column-shifted Range
|
||||
/// @param n the number of characters to shift by
|
||||
/// @returns a Range with a #begin and #end column shifted by `n`
|
||||
inline Range operator+(size_t n) const {
|
||||
return Range{{begin.line, begin.column + n}, {end.line, end.column + n}};
|
||||
}
|
||||
|
||||
/// The location of the first character in the range.
|
||||
Location begin;
|
||||
/// The location of one-past the last character in the range.
|
||||
|
@ -127,6 +134,13 @@ class Source {
|
|||
return Source(Range{range.end}, file_path, file_content);
|
||||
}
|
||||
|
||||
/// Return a column-shifted Source
|
||||
/// @param n the number of characters to shift by
|
||||
/// @returns a Source with the range's columns shifted by `n`
|
||||
inline Source operator+(size_t n) const {
|
||||
return Source(range + n, file_path, file_content);
|
||||
}
|
||||
|
||||
/// range is the span of text this source refers to in #file_path
|
||||
Range range;
|
||||
/// file is the optional file path this source refers to
|
||||
|
|
|
@ -752,7 +752,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
|||
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
|
||||
|
||||
type::Type* ret = nullptr;
|
||||
bool is_swizzle = false;
|
||||
std::vector<uint32_t> swizzle;
|
||||
|
||||
if (auto* ty = data_type->As<type::Struct>()) {
|
||||
auto* strct = ty->impl();
|
||||
|
@ -777,9 +777,42 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
|||
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
|
||||
}
|
||||
} else if (auto* vec = data_type->As<type::Vector>()) {
|
||||
is_swizzle = true;
|
||||
std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
|
||||
auto size = str.size();
|
||||
swizzle.reserve(str.size());
|
||||
|
||||
for (auto c : str) {
|
||||
switch (c) {
|
||||
case 'x':
|
||||
case 'r':
|
||||
swizzle.emplace_back(0);
|
||||
break;
|
||||
case 'y':
|
||||
case 'g':
|
||||
swizzle.emplace_back(1);
|
||||
break;
|
||||
case 'z':
|
||||
case 'b':
|
||||
swizzle.emplace_back(2);
|
||||
break;
|
||||
case 'w':
|
||||
case 'a':
|
||||
swizzle.emplace_back(3);
|
||||
break;
|
||||
default:
|
||||
diagnostics_.add_error(
|
||||
"invalid vector swizzle character",
|
||||
expr->member()->source().Begin() + swizzle.size());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (size < 1 || size > 4) {
|
||||
diagnostics_.add_error("invalid vector swizzle size",
|
||||
expr->member()->source());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size();
|
||||
if (size == 1) {
|
||||
// A single element swizzle is just the type of the vector.
|
||||
ret = vec->type();
|
||||
|
@ -788,15 +821,15 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
|||
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
|
||||
}
|
||||
} else {
|
||||
// The vector will have a number of components equal to the length of the
|
||||
// swizzle. This assumes the validator will check that the swizzle
|
||||
// The vector will have a number of components equal to the length of
|
||||
// the swizzle. This assumes the validator will check that the swizzle
|
||||
// is correct.
|
||||
ret = builder_->create<type::Vector>(vec->type(),
|
||||
static_cast<uint32_t>(size));
|
||||
}
|
||||
} else {
|
||||
diagnostics_.add_error(
|
||||
"v-0007: invalid use of member accessor on a non-vector/non-struct " +
|
||||
"invalid use of member accessor on a non-vector/non-struct " +
|
||||
data_type->type_name(),
|
||||
expr->source());
|
||||
return false;
|
||||
|
@ -804,7 +837,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
|||
|
||||
builder_->Sem().Add(expr,
|
||||
builder_->create<semantic::MemberAccessorExpression>(
|
||||
expr, ret, current_statement_, is_swizzle));
|
||||
expr, ret, current_statement_, std::move(swizzle)));
|
||||
SetType(expr, ret);
|
||||
|
||||
return true;
|
||||
|
|
|
@ -55,6 +55,7 @@
|
|||
#include "src/semantic/call.h"
|
||||
#include "src/semantic/expression.h"
|
||||
#include "src/semantic/function.h"
|
||||
#include "src/semantic/member_accessor_expression.h"
|
||||
#include "src/semantic/statement.h"
|
||||
#include "src/semantic/variable.h"
|
||||
#include "src/type/access_control_type.h"
|
||||
|
@ -75,6 +76,7 @@
|
|||
#include "src/type/u32_type.h"
|
||||
#include "src/type/vector_type.h"
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
namespace tint {
|
||||
|
@ -1005,7 +1007,7 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct_Alias) {
|
|||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
||||
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* mem = MemberAccessor("my_vec", "xy");
|
||||
auto* mem = MemberAccessor("my_vec", "xzyw");
|
||||
WrapInFunction(mem);
|
||||
|
||||
EXPECT_TRUE(td()->Determine()) << td()->error();
|
||||
|
@ -1013,13 +1015,14 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
|||
ASSERT_NE(TypeOf(mem), nullptr);
|
||||
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
|
||||
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
|
||||
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
|
||||
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
||||
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* mem = MemberAccessor("my_vec", "x");
|
||||
auto* mem = MemberAccessor("my_vec", "b");
|
||||
WrapInFunction(mem);
|
||||
|
||||
EXPECT_TRUE(td()->Determine()) << td()->error();
|
||||
|
@ -1029,6 +1032,34 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
|||
|
||||
auto* ptr = TypeOf(mem)->As<type::Pointer>();
|
||||
ASSERT_TRUE(ptr->type()->Is<type::F32>());
|
||||
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadChar) {
|
||||
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* ident = create<ast::IdentifierExpression>(
|
||||
Source{{Source::Location{3, 3}, Source::Location{3, 7}}},
|
||||
Symbols().Register("xyqz"));
|
||||
|
||||
auto* mem = MemberAccessor("my_vec", ident);
|
||||
WrapInFunction(mem);
|
||||
|
||||
EXPECT_FALSE(td()->Determine());
|
||||
EXPECT_EQ(td()->error(), "3:5 error: invalid vector swizzle character");
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadLength) {
|
||||
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* ident = create<ast::IdentifierExpression>(
|
||||
Source{{Source::Location{3, 3}, Source::Location{3, 8}}},
|
||||
Symbols().Register("zzzzz"));
|
||||
auto* mem = MemberAccessor("my_vec", ident);
|
||||
WrapInFunction(mem);
|
||||
|
||||
EXPECT_FALSE(td()->Determine());
|
||||
EXPECT_EQ(td()->error(), "3:3 error: invalid vector swizzle size");
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
|
||||
|
|
|
@ -91,22 +91,6 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
|
|||
stmts->last()->Is<ast::FallthroughStatement>();
|
||||
}
|
||||
|
||||
uint32_t convert_swizzle_to_index(const std::string& swizzle) {
|
||||
if (swizzle == "r" || swizzle == "x") {
|
||||
return 0;
|
||||
}
|
||||
if (swizzle == "g" || swizzle == "y") {
|
||||
return 1;
|
||||
}
|
||||
if (swizzle == "b" || swizzle == "z") {
|
||||
return 2;
|
||||
}
|
||||
if (swizzle == "a" || swizzle == "w") {
|
||||
return 3;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char* image_format_to_rwtexture_type(type::ImageFormat image_format) {
|
||||
switch (image_format) {
|
||||
case type::ImageFormat::kRgba8Unorm:
|
||||
|
@ -2084,11 +2068,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
|
|||
out << str_member->offset();
|
||||
|
||||
} else if (res_type->Is<type::Vector>()) {
|
||||
auto swizzle = builder_.Sem().Get(mem)->Swizzle();
|
||||
|
||||
// TODO(dsinclair): Swizzle stuff
|
||||
//
|
||||
// This must be a single element swizzle if we've got a vector at this
|
||||
// point.
|
||||
if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) {
|
||||
if (swizzle.size() != 1) {
|
||||
diagnostics_.add_error(
|
||||
"Encountered multi-element swizzle when should have only one "
|
||||
"level");
|
||||
|
@ -2098,10 +2084,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
|
|||
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
|
||||
// so this is assuming 4. This will need to be fixed when we get f16 or
|
||||
// f64 types.
|
||||
out << "(4 * "
|
||||
<< convert_swizzle_to_index(
|
||||
builder_.Symbols().NameFor(mem->member()->symbol()))
|
||||
<< ")";
|
||||
out << "(4 * " << swizzle[0] << ")";
|
||||
} else {
|
||||
diagnostics_.add_error("Invalid result type for member accessor: " +
|
||||
res_type->type_name());
|
||||
|
|
Loading…
Reference in New Issue