TypeDeterminer: Resolve swizzles
Have the TD resolve swizzles down to indices, erroring out if they're not valid. Resolving these at TD time removes swizzle parsing in the HLSL writer, and is generally useful information. If we don't sanitize in the TD, we can end up trying to construct a resulting vector of an invalid size (> 4) triggering an assert in the type::Vector constructor. Fixed: chromium:1180634 Bug: tint:79 Change-Id: If1282c933d65eb02d26a8dc7e190f27801ef9dc5 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42221 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
8b1851dbdc
commit
6d612ad478
|
@ -486,7 +486,10 @@ class ProgramBuilder {
|
||||||
|
|
||||||
/// @param expr the expression
|
/// @param expr the expression
|
||||||
/// @return expr
|
/// @return expr
|
||||||
ast::Expression* Expr(ast::Expression* expr) { return expr; }
|
template <typename T>
|
||||||
|
traits::EnableIfIsType<T, ast::Expression>* Expr(T* expr) {
|
||||||
|
return expr;
|
||||||
|
}
|
||||||
|
|
||||||
/// @param name the identifier name
|
/// @param name the identifier name
|
||||||
/// @return an ast::IdentifierExpression with the given name
|
/// @return an ast::IdentifierExpression with the given name
|
||||||
|
@ -948,7 +951,7 @@ class ProgramBuilder {
|
||||||
/// @param idx the index argument for the array accessor expression
|
/// @param idx the index argument for the array accessor expression
|
||||||
/// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx`
|
/// @returns a `ast::MemberAccessorExpression` that indexes `obj` with `idx`
|
||||||
template <typename OBJ, typename IDX>
|
template <typename OBJ, typename IDX>
|
||||||
ast::Expression* MemberAccessor(OBJ&& obj, IDX&& idx) {
|
ast::MemberAccessorExpression* MemberAccessor(OBJ&& obj, IDX&& idx) {
|
||||||
return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)),
|
return create<ast::MemberAccessorExpression>(Expr(std::forward<OBJ>(obj)),
|
||||||
Expr(std::forward<IDX>(idx)));
|
Expr(std::forward<IDX>(idx)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
#ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
|
#ifndef SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
|
||||||
#define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
|
#define SRC_SEMANTIC_MEMBER_ACCESSOR_EXPRESSION_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "src/semantic/expression.h"
|
#include "src/semantic/expression.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
@ -29,17 +31,24 @@ class MemberAccessorExpression
|
||||||
/// @param declaration the AST node
|
/// @param declaration the AST node
|
||||||
/// @param type the resolved type of the expression
|
/// @param type the resolved type of the expression
|
||||||
/// @param statement the statement that owns this expression
|
/// @param statement the statement that owns this expression
|
||||||
/// @param is_swizzle true if this member access is for a vector swizzle
|
/// @param swizzle if this member access is for a vector swizzle, the swizzle
|
||||||
|
/// indices
|
||||||
MemberAccessorExpression(ast::Expression* declaration,
|
MemberAccessorExpression(ast::Expression* declaration,
|
||||||
type::Type* type,
|
type::Type* type,
|
||||||
Statement* statement,
|
Statement* statement,
|
||||||
bool is_swizzle);
|
std::vector<uint32_t> swizzle);
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
~MemberAccessorExpression() override;
|
||||||
|
|
||||||
/// @return true if this member access is for a vector swizzle
|
/// @return true if this member access is for a vector swizzle
|
||||||
bool IsSwizzle() const { return is_swizzle_; }
|
bool IsSwizzle() const { return !swizzle_.empty(); }
|
||||||
|
|
||||||
|
/// @return the swizzle indices, if this is a vector swizzle
|
||||||
|
const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool const is_swizzle_;
|
std::vector<uint32_t> const swizzle_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace semantic
|
} // namespace semantic
|
||||||
|
|
|
@ -19,11 +19,14 @@ TINT_INSTANTIATE_CLASS_ID(tint::semantic::MemberAccessorExpression);
|
||||||
namespace tint {
|
namespace tint {
|
||||||
namespace semantic {
|
namespace semantic {
|
||||||
|
|
||||||
MemberAccessorExpression::MemberAccessorExpression(ast::Expression* declaration,
|
MemberAccessorExpression::MemberAccessorExpression(
|
||||||
|
ast::Expression* declaration,
|
||||||
type::Type* type,
|
type::Type* type,
|
||||||
Statement* statement,
|
Statement* statement,
|
||||||
bool is_swizzle)
|
std::vector<uint32_t> swizzle)
|
||||||
: Base(declaration, type, statement), is_swizzle_(is_swizzle) {}
|
: Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
|
||||||
|
|
||||||
|
MemberAccessorExpression::~MemberAccessorExpression() = default;
|
||||||
|
|
||||||
} // namespace semantic
|
} // namespace semantic
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
14
src/source.h
14
src/source.h
|
@ -84,6 +84,13 @@ class Source {
|
||||||
/// @param e the range end location
|
/// @param e the range end location
|
||||||
inline Range(const Location& b, const Location& e) : begin(b), end(e) {}
|
inline Range(const Location& b, const Location& e) : begin(b), end(e) {}
|
||||||
|
|
||||||
|
/// Return a column-shifted Range
|
||||||
|
/// @param n the number of characters to shift by
|
||||||
|
/// @returns a Range with a #begin and #end column shifted by `n`
|
||||||
|
inline Range operator+(size_t n) const {
|
||||||
|
return Range{{begin.line, begin.column + n}, {end.line, end.column + n}};
|
||||||
|
}
|
||||||
|
|
||||||
/// The location of the first character in the range.
|
/// The location of the first character in the range.
|
||||||
Location begin;
|
Location begin;
|
||||||
/// The location of one-past the last character in the range.
|
/// The location of one-past the last character in the range.
|
||||||
|
@ -127,6 +134,13 @@ class Source {
|
||||||
return Source(Range{range.end}, file_path, file_content);
|
return Source(Range{range.end}, file_path, file_content);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a column-shifted Source
|
||||||
|
/// @param n the number of characters to shift by
|
||||||
|
/// @returns a Source with the range's columns shifted by `n`
|
||||||
|
inline Source operator+(size_t n) const {
|
||||||
|
return Source(range + n, file_path, file_content);
|
||||||
|
}
|
||||||
|
|
||||||
/// range is the span of text this source refers to in #file_path
|
/// range is the span of text this source refers to in #file_path
|
||||||
Range range;
|
Range range;
|
||||||
/// file is the optional file path this source refers to
|
/// file is the optional file path this source refers to
|
||||||
|
|
|
@ -752,7 +752,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
||||||
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
|
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
|
||||||
|
|
||||||
type::Type* ret = nullptr;
|
type::Type* ret = nullptr;
|
||||||
bool is_swizzle = false;
|
std::vector<uint32_t> swizzle;
|
||||||
|
|
||||||
if (auto* ty = data_type->As<type::Struct>()) {
|
if (auto* ty = data_type->As<type::Struct>()) {
|
||||||
auto* strct = ty->impl();
|
auto* strct = ty->impl();
|
||||||
|
@ -777,9 +777,42 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
||||||
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
|
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
|
||||||
}
|
}
|
||||||
} else if (auto* vec = data_type->As<type::Vector>()) {
|
} else if (auto* vec = data_type->As<type::Vector>()) {
|
||||||
is_swizzle = true;
|
std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
|
||||||
|
auto size = str.size();
|
||||||
|
swizzle.reserve(str.size());
|
||||||
|
|
||||||
|
for (auto c : str) {
|
||||||
|
switch (c) {
|
||||||
|
case 'x':
|
||||||
|
case 'r':
|
||||||
|
swizzle.emplace_back(0);
|
||||||
|
break;
|
||||||
|
case 'y':
|
||||||
|
case 'g':
|
||||||
|
swizzle.emplace_back(1);
|
||||||
|
break;
|
||||||
|
case 'z':
|
||||||
|
case 'b':
|
||||||
|
swizzle.emplace_back(2);
|
||||||
|
break;
|
||||||
|
case 'w':
|
||||||
|
case 'a':
|
||||||
|
swizzle.emplace_back(3);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
diagnostics_.add_error(
|
||||||
|
"invalid vector swizzle character",
|
||||||
|
expr->member()->source().Begin() + swizzle.size());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (size < 1 || size > 4) {
|
||||||
|
diagnostics_.add_error("invalid vector swizzle size",
|
||||||
|
expr->member()->source());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto size = builder_->Symbols().NameFor(expr->member()->symbol()).size();
|
|
||||||
if (size == 1) {
|
if (size == 1) {
|
||||||
// A single element swizzle is just the type of the vector.
|
// A single element swizzle is just the type of the vector.
|
||||||
ret = vec->type();
|
ret = vec->type();
|
||||||
|
@ -788,15 +821,15 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
||||||
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
|
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// The vector will have a number of components equal to the length of the
|
// The vector will have a number of components equal to the length of
|
||||||
// swizzle. This assumes the validator will check that the swizzle
|
// the swizzle. This assumes the validator will check that the swizzle
|
||||||
// is correct.
|
// is correct.
|
||||||
ret = builder_->create<type::Vector>(vec->type(),
|
ret = builder_->create<type::Vector>(vec->type(),
|
||||||
static_cast<uint32_t>(size));
|
static_cast<uint32_t>(size));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
diagnostics_.add_error(
|
diagnostics_.add_error(
|
||||||
"v-0007: invalid use of member accessor on a non-vector/non-struct " +
|
"invalid use of member accessor on a non-vector/non-struct " +
|
||||||
data_type->type_name(),
|
data_type->type_name(),
|
||||||
expr->source());
|
expr->source());
|
||||||
return false;
|
return false;
|
||||||
|
@ -804,7 +837,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
||||||
|
|
||||||
builder_->Sem().Add(expr,
|
builder_->Sem().Add(expr,
|
||||||
builder_->create<semantic::MemberAccessorExpression>(
|
builder_->create<semantic::MemberAccessorExpression>(
|
||||||
expr, ret, current_statement_, is_swizzle));
|
expr, ret, current_statement_, std::move(swizzle)));
|
||||||
SetType(expr, ret);
|
SetType(expr, ret);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -55,6 +55,7 @@
|
||||||
#include "src/semantic/call.h"
|
#include "src/semantic/call.h"
|
||||||
#include "src/semantic/expression.h"
|
#include "src/semantic/expression.h"
|
||||||
#include "src/semantic/function.h"
|
#include "src/semantic/function.h"
|
||||||
|
#include "src/semantic/member_accessor_expression.h"
|
||||||
#include "src/semantic/statement.h"
|
#include "src/semantic/statement.h"
|
||||||
#include "src/semantic/variable.h"
|
#include "src/semantic/variable.h"
|
||||||
#include "src/type/access_control_type.h"
|
#include "src/type/access_control_type.h"
|
||||||
|
@ -75,6 +76,7 @@
|
||||||
#include "src/type/u32_type.h"
|
#include "src/type/u32_type.h"
|
||||||
#include "src/type/vector_type.h"
|
#include "src/type/vector_type.h"
|
||||||
|
|
||||||
|
using ::testing::ElementsAre;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
@ -1005,7 +1007,7 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct_Alias) {
|
||||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
||||||
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||||
|
|
||||||
auto* mem = MemberAccessor("my_vec", "xy");
|
auto* mem = MemberAccessor("my_vec", "xzyw");
|
||||||
WrapInFunction(mem);
|
WrapInFunction(mem);
|
||||||
|
|
||||||
EXPECT_TRUE(td()->Determine()) << td()->error();
|
EXPECT_TRUE(td()->Determine()) << td()->error();
|
||||||
|
@ -1013,13 +1015,14 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
||||||
ASSERT_NE(TypeOf(mem), nullptr);
|
ASSERT_NE(TypeOf(mem), nullptr);
|
||||||
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
|
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
|
||||||
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
|
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
|
||||||
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
|
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
|
||||||
|
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
||||||
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||||
|
|
||||||
auto* mem = MemberAccessor("my_vec", "x");
|
auto* mem = MemberAccessor("my_vec", "b");
|
||||||
WrapInFunction(mem);
|
WrapInFunction(mem);
|
||||||
|
|
||||||
EXPECT_TRUE(td()->Determine()) << td()->error();
|
EXPECT_TRUE(td()->Determine()) << td()->error();
|
||||||
|
@ -1029,6 +1032,34 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
|
||||||
|
|
||||||
auto* ptr = TypeOf(mem)->As<type::Pointer>();
|
auto* ptr = TypeOf(mem)->As<type::Pointer>();
|
||||||
ASSERT_TRUE(ptr->type()->Is<type::F32>());
|
ASSERT_TRUE(ptr->type()->Is<type::F32>());
|
||||||
|
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadChar) {
|
||||||
|
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
auto* ident = create<ast::IdentifierExpression>(
|
||||||
|
Source{{Source::Location{3, 3}, Source::Location{3, 7}}},
|
||||||
|
Symbols().Register("xyqz"));
|
||||||
|
|
||||||
|
auto* mem = MemberAccessor("my_vec", ident);
|
||||||
|
WrapInFunction(mem);
|
||||||
|
|
||||||
|
EXPECT_FALSE(td()->Determine());
|
||||||
|
EXPECT_EQ(td()->error(), "3:5 error: invalid vector swizzle character");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_BadLength) {
|
||||||
|
Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
auto* ident = create<ast::IdentifierExpression>(
|
||||||
|
Source{{Source::Location{3, 3}, Source::Location{3, 8}}},
|
||||||
|
Symbols().Register("zzzzz"));
|
||||||
|
auto* mem = MemberAccessor("my_vec", ident);
|
||||||
|
WrapInFunction(mem);
|
||||||
|
|
||||||
|
EXPECT_FALSE(td()->Determine());
|
||||||
|
EXPECT_EQ(td()->error(), "3:3 error: invalid vector swizzle size");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
|
TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
|
||||||
|
|
|
@ -91,22 +91,6 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
|
||||||
stmts->last()->Is<ast::FallthroughStatement>();
|
stmts->last()->Is<ast::FallthroughStatement>();
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t convert_swizzle_to_index(const std::string& swizzle) {
|
|
||||||
if (swizzle == "r" || swizzle == "x") {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
if (swizzle == "g" || swizzle == "y") {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (swizzle == "b" || swizzle == "z") {
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
if (swizzle == "a" || swizzle == "w") {
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char* image_format_to_rwtexture_type(type::ImageFormat image_format) {
|
const char* image_format_to_rwtexture_type(type::ImageFormat image_format) {
|
||||||
switch (image_format) {
|
switch (image_format) {
|
||||||
case type::ImageFormat::kRgba8Unorm:
|
case type::ImageFormat::kRgba8Unorm:
|
||||||
|
@ -2084,11 +2068,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
|
||||||
out << str_member->offset();
|
out << str_member->offset();
|
||||||
|
|
||||||
} else if (res_type->Is<type::Vector>()) {
|
} else if (res_type->Is<type::Vector>()) {
|
||||||
|
auto swizzle = builder_.Sem().Get(mem)->Swizzle();
|
||||||
|
|
||||||
// TODO(dsinclair): Swizzle stuff
|
// TODO(dsinclair): Swizzle stuff
|
||||||
//
|
//
|
||||||
// This must be a single element swizzle if we've got a vector at this
|
// This must be a single element swizzle if we've got a vector at this
|
||||||
// point.
|
// point.
|
||||||
if (builder_.Symbols().NameFor(mem->member()->symbol()).size() != 1) {
|
if (swizzle.size() != 1) {
|
||||||
diagnostics_.add_error(
|
diagnostics_.add_error(
|
||||||
"Encountered multi-element swizzle when should have only one "
|
"Encountered multi-element swizzle when should have only one "
|
||||||
"level");
|
"level");
|
||||||
|
@ -2098,10 +2084,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
|
||||||
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
|
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
|
||||||
// so this is assuming 4. This will need to be fixed when we get f16 or
|
// so this is assuming 4. This will need to be fixed when we get f16 or
|
||||||
// f64 types.
|
// f64 types.
|
||||||
out << "(4 * "
|
out << "(4 * " << swizzle[0] << ")";
|
||||||
<< convert_swizzle_to_index(
|
|
||||||
builder_.Symbols().NameFor(mem->member()->symbol()))
|
|
||||||
<< ")";
|
|
||||||
} else {
|
} else {
|
||||||
diagnostics_.add_error("Invalid result type for member accessor: " +
|
diagnostics_.add_error("Invalid result type for member accessor: " +
|
||||||
res_type->type_name());
|
res_type->type_name());
|
||||||
|
|
Loading…
Reference in New Issue