Validate scalar constructor and implement conversion to vecN<bool> in spir-v backend

After implementing validation and fairly exhaustive tests, discovered
that conversion of scalar vector to bool vector did not work in the
spir-v backend. For module scope variables, we use and rely on the
FoldConstants transform to ensure no conversion needs to take place.
This is necessary because we cannot easily introduce temporary values
and refer to them when casting at module scope. Note that for the same
reason, module-level conversions are always constant foldable, so this
works. For function-level conversions, implemented support to emit a
comparison against a zero value, and store the result in the bool
vector.

Bug: tint:865
Change-Id: I0528045e803f176e03428bc7eac31ae06920bbd7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54744
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2021-06-18 15:32:21 +00:00
committed by Tint LUCI CQ
parent 0507273047
commit adbbd0ba66
19 changed files with 851 additions and 56 deletions

View File

@@ -1988,6 +1988,9 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
if (auto* mat_type = type->As<sem::Matrix>()) {
return ValidateMatrixConstructor(type_ctor, mat_type);
}
if (type->is_scalar()) {
return ValidateScalarConstructor(type_ctor, type);
}
if (auto* arr_type = type->As<sem::Array>()) {
return ValidateArrayConstructor(type_ctor, arr_type);
}
@@ -2154,6 +2157,45 @@ bool Resolver::ValidateMatrixConstructor(
return true;
}
bool Resolver::ValidateScalarConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Type* type) {
if (ctor->values().size() == 0) {
return true;
}
if (ctor->values().size() > 1) {
diagnostics_.add_error("expected zero or one value in constructor, got " +
std::to_string(ctor->values().size()),
ctor->source());
return false;
}
// Validate constructor
auto* value = ctor->values()[0];
auto* value_type = TypeOf(value)->UnwrapRef();
using Bool = sem::Bool;
using I32 = sem::I32;
using U32 = sem::U32;
using F32 = sem::F32;
const bool is_valid =
(type->Is<Bool>() && value_type->IsAnyOf<Bool, I32, U32, F32>()) ||
(type->Is<I32>() && value_type->IsAnyOf<I32, U32, F32>()) ||
(type->Is<U32>() && value_type->IsAnyOf<I32, U32, F32>()) ||
(type->Is<F32>() && value_type->IsAnyOf<I32, U32, F32>());
if (!is_valid) {
diagnostics_.add_error("cannot construct '" + TypeNameOf(ctor) +
"' with a value of type '" + TypeNameOf(value) +
"'",
ctor->source());
return false;
}
return true;
}
bool Resolver::Identifier(ast::IdentifierExpression* expr) {
auto symbol = expr->symbol();
VariableInfo* var;

View File

@@ -285,6 +285,8 @@ class Resolver {
const std::string& rhs_type_name);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Type* type);
bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Array* arr_type);
bool ValidateTypeDecl(const ast::TypeDecl* named_type) const;

View File

@@ -384,6 +384,23 @@ struct DataType<alias<T, ID>> {
}
};
/// Struct of all creation pointer types
struct CreatePtrs {
/// ast node type create function
ast_type_func_ptr ast;
/// ast expression type create function
ast_expr_func_ptr expr;
/// sem type create function
sem_type_func_ptr sem;
};
/// Returns a CreatePtrs struct instance with all creation pointer types for
/// type `T`
template <typename T>
constexpr CreatePtrs CreatePtrsFor() {
return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
}
} // namespace builder
} // namespace resolver

View File

@@ -20,31 +20,24 @@ namespace resolver {
namespace {
// Helpers and typedefs
template <typename T>
using DataType = builder::DataType<T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T>
using alias = builder::alias<T>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
using builder::alias;
using builder::alias1;
using builder::alias2;
using builder::alias3;
using builder::CreatePtrs;
using builder::CreatePtrsFor;
using builder::DataType;
using builder::f32;
using builder::i32;
using builder::mat2x2;
using builder::mat2x3;
using builder::mat3x2;
using builder::mat3x3;
using builder::mat4x4;
using builder::u32;
using builder::vec2;
using builder::vec3;
using builder::vec4;
class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
public testing::Test {};
@@ -235,6 +228,185 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
} // namespace InferTypeTest
namespace ConversionConstructorTest {
struct Params {
builder::ast_type_func_ptr lhs_type;
builder::ast_type_func_ptr rhs_type;
builder::ast_expr_func_ptr rhs_value_expr;
};
template <typename LhsType, typename RhsType>
constexpr Params ParamsFor() {
return Params{DataType<LhsType>::AST, DataType<RhsType>::AST,
DataType<RhsType>::Expr};
}
static constexpr Params valid_cases[] = {
// Direct init (non-conversions)
ParamsFor<bool, bool>(), //
ParamsFor<i32, i32>(), //
ParamsFor<u32, u32>(), //
ParamsFor<f32, f32>(), //
ParamsFor<vec3<bool>, vec3<bool>>(), //
ParamsFor<vec3<i32>, vec3<i32>>(), //
ParamsFor<vec3<u32>, vec3<u32>>(), //
ParamsFor<vec3<f32>, vec3<f32>>(), //
// Splat
ParamsFor<vec3<bool>, bool>(), //
ParamsFor<vec3<i32>, i32>(), //
ParamsFor<vec3<u32>, u32>(), //
ParamsFor<vec3<f32>, f32>(), //
// Conversion
ParamsFor<bool, u32>(), //
ParamsFor<bool, i32>(), //
ParamsFor<bool, f32>(), //
ParamsFor<i32, u32>(), //
ParamsFor<i32, f32>(), //
ParamsFor<u32, i32>(), //
ParamsFor<u32, f32>(), //
ParamsFor<f32, u32>(), //
ParamsFor<f32, i32>(), //
ParamsFor<vec3<bool>, vec3<u32>>(), //
ParamsFor<vec3<bool>, vec3<i32>>(), //
ParamsFor<vec3<bool>, vec3<f32>>(), //
ParamsFor<vec3<i32>, vec3<u32>>(), //
ParamsFor<vec3<i32>, vec3<f32>>(), //
ParamsFor<vec3<u32>, vec3<i32>>(), //
ParamsFor<vec3<u32>, vec3<f32>>(), //
ParamsFor<vec3<f32>, vec3<u32>>(), //
ParamsFor<vec3<f32>, vec3<i32>>(), //
};
using ConversionConstructorValidTest = ResolverTestWithParam<Params>;
TEST_P(ConversionConstructorValidTest, All) {
auto& params = GetParam();
// var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
auto* lhs_type1 = params.lhs_type(*this);
auto* lhs_type2 = params.lhs_type(*this);
auto* rhs_type = params.rhs_type(*this);
auto* rhs_value_expr = params.rhs_value_expr(*this, 0);
std::stringstream ss;
ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
<< FriendlyName(rhs_type) << "(<rhs value expr>))";
SCOPED_TRACE(ss.str());
auto* a = Var("a", lhs_type1, ast::StorageClass::kNone,
Construct(lhs_type2, Construct(rhs_type, rhs_value_expr)));
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
auto* a_ident = Expr("a");
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
ConversionConstructorValidTest,
testing::ValuesIn(valid_cases));
constexpr CreatePtrs all_types[] = {
CreatePtrsFor<bool>(), //
CreatePtrsFor<u32>(), //
CreatePtrsFor<i32>(), //
CreatePtrsFor<f32>(), //
CreatePtrsFor<vec3<bool>>(), //
CreatePtrsFor<vec3<i32>>(), //
CreatePtrsFor<vec3<u32>>(), //
CreatePtrsFor<vec3<f32>>(), //
CreatePtrsFor<mat3x3<i32>>(), //
CreatePtrsFor<mat3x3<u32>>(), //
CreatePtrsFor<mat3x3<f32>>(), //
CreatePtrsFor<mat2x3<i32>>(), //
CreatePtrsFor<mat2x3<u32>>(), //
CreatePtrsFor<mat2x3<f32>>(), //
CreatePtrsFor<mat3x2<i32>>(), //
CreatePtrsFor<mat3x2<u32>>(), //
CreatePtrsFor<mat3x2<f32>>() //
};
using ConversionConstructorInvalidTest =
ResolverTestWithParam<std::tuple<CreatePtrs, // lhs
CreatePtrs // rhs
>>;
TEST_P(ConversionConstructorInvalidTest, All) {
auto& params = GetParam();
auto& lhs_params = std::get<0>(params);
auto& rhs_params = std::get<1>(params);
// Skip test for valid cases
for (auto& v : valid_cases) {
if (v.lhs_type == lhs_params.ast && v.rhs_type == rhs_params.ast &&
v.rhs_value_expr == rhs_params.expr) {
return;
}
}
// Skip non-conversions
if (lhs_params.ast == rhs_params.ast) {
return;
}
// var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
auto* lhs_type1 = lhs_params.ast(*this);
auto* lhs_type2 = lhs_params.ast(*this);
auto* rhs_type = rhs_params.ast(*this);
auto* rhs_value_expr = rhs_params.expr(*this, 0);
std::stringstream ss;
ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
<< FriendlyName(rhs_type) << "(<rhs value expr>))";
SCOPED_TRACE(ss.str());
auto* a = Var("a", lhs_type1, ast::StorageClass::kNone,
Construct(lhs_type2, Construct(rhs_type, rhs_value_expr)));
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
auto* a_ident = Expr("a");
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_FALSE(r()->Resolve());
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
ConversionConstructorInvalidTest,
testing::Combine(testing::ValuesIn(all_types),
testing::ValuesIn(all_types)));
TEST_F(ResolverTypeConstructorValidationTest,
ConversionConstructorInvalid_TooManyInitializers) {
auto* a = Var("a", ty.f32(), ast::StorageClass::kNone,
Construct(Source{{12, 34}}, ty.f32(), Expr(1.0f), Expr(2.0f)));
WrapInFunction(a);
ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: expected zero or one value in constructor, got 2");
}
TEST_F(ResolverTypeConstructorValidationTest,
ConversionConstructorInvalid_InvalidInitializer) {
auto* a = Var("a", ty.f32(), ast::StorageClass::kNone,
Construct(Source{{12, 34}}, ty.f32(), Expr(true)));
WrapInFunction(a);
ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: cannot construct 'f32' with a value of type 'bool'");
}
} // namespace ConversionConstructorTest
} // namespace
} // namespace resolver
} // namespace tint