Validate scalar constructor and implement conversion to vecN<bool> in spir-v backend

After implementing validation and fairly exhaustive tests, discovered
that conversion of scalar vector to bool vector did not work in the
spir-v backend. For module scope variables, we use and rely on the
FoldConstants transform to ensure no conversion needs to take place.
This is necessary because we cannot easily introduce temporary values
and refer to them when casting at module scope. Note that for the same
reason, module-level conversions are always constant foldable, so this
works. For function-level conversions, implemented support to emit a
comparison against a zero value, and store the result in the bool
vector.

Bug: tint:865
Change-Id: I0528045e803f176e03428bc7eac31ae06920bbd7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54744
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2021-06-18 15:32:21 +00:00
committed by Tint LUCI CQ
parent 0507273047
commit adbbd0ba66
19 changed files with 851 additions and 56 deletions

View File

@@ -1988,6 +1988,9 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
if (auto* mat_type = type->As<sem::Matrix>()) {
return ValidateMatrixConstructor(type_ctor, mat_type);
}
if (type->is_scalar()) {
return ValidateScalarConstructor(type_ctor, type);
}
if (auto* arr_type = type->As<sem::Array>()) {
return ValidateArrayConstructor(type_ctor, arr_type);
}
@@ -2154,6 +2157,45 @@ bool Resolver::ValidateMatrixConstructor(
return true;
}
bool Resolver::ValidateScalarConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Type* type) {
if (ctor->values().size() == 0) {
return true;
}
if (ctor->values().size() > 1) {
diagnostics_.add_error("expected zero or one value in constructor, got " +
std::to_string(ctor->values().size()),
ctor->source());
return false;
}
// Validate constructor
auto* value = ctor->values()[0];
auto* value_type = TypeOf(value)->UnwrapRef();
using Bool = sem::Bool;
using I32 = sem::I32;
using U32 = sem::U32;
using F32 = sem::F32;
const bool is_valid =
(type->Is<Bool>() && value_type->IsAnyOf<Bool, I32, U32, F32>()) ||
(type->Is<I32>() && value_type->IsAnyOf<I32, U32, F32>()) ||
(type->Is<U32>() && value_type->IsAnyOf<I32, U32, F32>()) ||
(type->Is<F32>() && value_type->IsAnyOf<I32, U32, F32>());
if (!is_valid) {
diagnostics_.add_error("cannot construct '" + TypeNameOf(ctor) +
"' with a value of type '" + TypeNameOf(value) +
"'",
ctor->source());
return false;
}
return true;
}
bool Resolver::Identifier(ast::IdentifierExpression* expr) {
auto symbol = expr->symbol();
VariableInfo* var;

View File

@@ -285,6 +285,8 @@ class Resolver {
const std::string& rhs_type_name);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Type* type);
bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Array* arr_type);
bool ValidateTypeDecl(const ast::TypeDecl* named_type) const;

View File

@@ -384,6 +384,23 @@ struct DataType<alias<T, ID>> {
}
};
/// Struct of all creation pointer types
struct CreatePtrs {
/// ast node type create function
ast_type_func_ptr ast;
/// ast expression type create function
ast_expr_func_ptr expr;
/// sem type create function
sem_type_func_ptr sem;
};
/// Returns a CreatePtrs struct instance with all creation pointer types for
/// type `T`
template <typename T>
constexpr CreatePtrs CreatePtrsFor() {
return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
}
} // namespace builder
} // namespace resolver

View File

@@ -20,31 +20,24 @@ namespace resolver {
namespace {
// Helpers and typedefs
template <typename T>
using DataType = builder::DataType<T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T>
using alias = builder::alias<T>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
using builder::alias;
using builder::alias1;
using builder::alias2;
using builder::alias3;
using builder::CreatePtrs;
using builder::CreatePtrsFor;
using builder::DataType;
using builder::f32;
using builder::i32;
using builder::mat2x2;
using builder::mat2x3;
using builder::mat3x2;
using builder::mat3x3;
using builder::mat4x4;
using builder::u32;
using builder::vec2;
using builder::vec3;
using builder::vec4;
class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
public testing::Test {};
@@ -235,6 +228,185 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
} // namespace InferTypeTest
namespace ConversionConstructorTest {
struct Params {
builder::ast_type_func_ptr lhs_type;
builder::ast_type_func_ptr rhs_type;
builder::ast_expr_func_ptr rhs_value_expr;
};
template <typename LhsType, typename RhsType>
constexpr Params ParamsFor() {
return Params{DataType<LhsType>::AST, DataType<RhsType>::AST,
DataType<RhsType>::Expr};
}
static constexpr Params valid_cases[] = {
// Direct init (non-conversions)
ParamsFor<bool, bool>(), //
ParamsFor<i32, i32>(), //
ParamsFor<u32, u32>(), //
ParamsFor<f32, f32>(), //
ParamsFor<vec3<bool>, vec3<bool>>(), //
ParamsFor<vec3<i32>, vec3<i32>>(), //
ParamsFor<vec3<u32>, vec3<u32>>(), //
ParamsFor<vec3<f32>, vec3<f32>>(), //
// Splat
ParamsFor<vec3<bool>, bool>(), //
ParamsFor<vec3<i32>, i32>(), //
ParamsFor<vec3<u32>, u32>(), //
ParamsFor<vec3<f32>, f32>(), //
// Conversion
ParamsFor<bool, u32>(), //
ParamsFor<bool, i32>(), //
ParamsFor<bool, f32>(), //
ParamsFor<i32, u32>(), //
ParamsFor<i32, f32>(), //
ParamsFor<u32, i32>(), //
ParamsFor<u32, f32>(), //
ParamsFor<f32, u32>(), //
ParamsFor<f32, i32>(), //
ParamsFor<vec3<bool>, vec3<u32>>(), //
ParamsFor<vec3<bool>, vec3<i32>>(), //
ParamsFor<vec3<bool>, vec3<f32>>(), //
ParamsFor<vec3<i32>, vec3<u32>>(), //
ParamsFor<vec3<i32>, vec3<f32>>(), //
ParamsFor<vec3<u32>, vec3<i32>>(), //
ParamsFor<vec3<u32>, vec3<f32>>(), //
ParamsFor<vec3<f32>, vec3<u32>>(), //
ParamsFor<vec3<f32>, vec3<i32>>(), //
};
using ConversionConstructorValidTest = ResolverTestWithParam<Params>;
TEST_P(ConversionConstructorValidTest, All) {
auto& params = GetParam();
// var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
auto* lhs_type1 = params.lhs_type(*this);
auto* lhs_type2 = params.lhs_type(*this);
auto* rhs_type = params.rhs_type(*this);
auto* rhs_value_expr = params.rhs_value_expr(*this, 0);
std::stringstream ss;
ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
<< FriendlyName(rhs_type) << "(<rhs value expr>))";
SCOPED_TRACE(ss.str());
auto* a = Var("a", lhs_type1, ast::StorageClass::kNone,
Construct(lhs_type2, Construct(rhs_type, rhs_value_expr)));
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
auto* a_ident = Expr("a");
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
ConversionConstructorValidTest,
testing::ValuesIn(valid_cases));
constexpr CreatePtrs all_types[] = {
CreatePtrsFor<bool>(), //
CreatePtrsFor<u32>(), //
CreatePtrsFor<i32>(), //
CreatePtrsFor<f32>(), //
CreatePtrsFor<vec3<bool>>(), //
CreatePtrsFor<vec3<i32>>(), //
CreatePtrsFor<vec3<u32>>(), //
CreatePtrsFor<vec3<f32>>(), //
CreatePtrsFor<mat3x3<i32>>(), //
CreatePtrsFor<mat3x3<u32>>(), //
CreatePtrsFor<mat3x3<f32>>(), //
CreatePtrsFor<mat2x3<i32>>(), //
CreatePtrsFor<mat2x3<u32>>(), //
CreatePtrsFor<mat2x3<f32>>(), //
CreatePtrsFor<mat3x2<i32>>(), //
CreatePtrsFor<mat3x2<u32>>(), //
CreatePtrsFor<mat3x2<f32>>() //
};
using ConversionConstructorInvalidTest =
ResolverTestWithParam<std::tuple<CreatePtrs, // lhs
CreatePtrs // rhs
>>;
TEST_P(ConversionConstructorInvalidTest, All) {
auto& params = GetParam();
auto& lhs_params = std::get<0>(params);
auto& rhs_params = std::get<1>(params);
// Skip test for valid cases
for (auto& v : valid_cases) {
if (v.lhs_type == lhs_params.ast && v.rhs_type == rhs_params.ast &&
v.rhs_value_expr == rhs_params.expr) {
return;
}
}
// Skip non-conversions
if (lhs_params.ast == rhs_params.ast) {
return;
}
// var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
auto* lhs_type1 = lhs_params.ast(*this);
auto* lhs_type2 = lhs_params.ast(*this);
auto* rhs_type = rhs_params.ast(*this);
auto* rhs_value_expr = rhs_params.expr(*this, 0);
std::stringstream ss;
ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
<< FriendlyName(rhs_type) << "(<rhs value expr>))";
SCOPED_TRACE(ss.str());
auto* a = Var("a", lhs_type1, ast::StorageClass::kNone,
Construct(lhs_type2, Construct(rhs_type, rhs_value_expr)));
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
auto* a_ident = Expr("a");
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_FALSE(r()->Resolve());
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
ConversionConstructorInvalidTest,
testing::Combine(testing::ValuesIn(all_types),
testing::ValuesIn(all_types)));
TEST_F(ResolverTypeConstructorValidationTest,
ConversionConstructorInvalid_TooManyInitializers) {
auto* a = Var("a", ty.f32(), ast::StorageClass::kNone,
Construct(Source{{12, 34}}, ty.f32(), Expr(1.0f), Expr(2.0f)));
WrapInFunction(a);
ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: expected zero or one value in constructor, got 2");
}
TEST_F(ResolverTypeConstructorValidationTest,
ConversionConstructorInvalid_InvalidInitializer) {
auto* a = Var("a", ty.f32(), ast::StorageClass::kNone,
Construct(Source{{12, 34}}, ty.f32(), Expr(true)));
WrapInFunction(a);
ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: cannot construct 'f32' with a value of type 'bool'");
}
} // namespace ConversionConstructorTest
} // namespace
} // namespace resolver
} // namespace tint

View File

@@ -27,6 +27,7 @@
#include "src/sem/struct.h"
#include "src/sem/variable.h"
#include "src/transform/external_texture_transform.h"
#include "src/transform/fold_constants.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
#include "src/transform/simplify.h"
@@ -43,6 +44,7 @@ Output Spirv::Run(const Program* in, const DataMap& data) {
Manager manager;
manager.Add<InlinePointerLets>(); // Required for arrayLength()
manager.Add<Simplify>(); // Required for arrayLength()
manager.Add<FoldConstants>();
manager.Add<ExternalTextureTransform>();
auto transformedInput = manager.Run(in, data);

View File

@@ -1317,7 +1317,8 @@ uint32_t Builder::GenerateTypeConstructorExpression(
}
if (can_cast_or_copy) {
return GenerateCastOrCopyOrPassthrough(result_type, values[0]);
return GenerateCastOrCopyOrPassthrough(result_type, values[0],
is_global_init);
}
auto type_id = GenerateTypeIfNeeded(result_type);
@@ -1361,7 +1362,8 @@ uint32_t Builder::GenerateTypeConstructorExpression(
// Both scalars, but not the same type so we need to generate a conversion
// of the value.
if (value_type->is_scalar() && result_type->is_scalar()) {
id = GenerateCastOrCopyOrPassthrough(result_type, values[0]);
id = GenerateCastOrCopyOrPassthrough(result_type, values[0],
is_global_init);
out << "_" << id;
ops.push_back(Operand::Int(id));
continue;
@@ -1458,7 +1460,27 @@ uint32_t Builder::GenerateTypeConstructorExpression(
}
uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
ast::Expression* from_expr) {
ast::Expression* from_expr,
bool is_global_init) {
// This should not happen as we rely on constant folding to obviate
// casts/conversions for module-scope variables
if (is_global_init) {
TINT_ICE(builder_.Diagnostics())
<< "Module-level conversions are not supported. Conversions should "
"have already been constant-folded by the FoldConstants transform.";
return 0;
}
auto elem_type_of = [](const sem::Type* t) -> const sem::Type* {
if (t->is_scalar()) {
return t;
}
if (auto* v = t->As<sem::Vector>()) {
return v->type();
}
return nullptr;
};
auto result = result_op();
auto result_id = result.to_i();
@@ -1504,7 +1526,27 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
(from_type->is_unsigned_integer_vector() &&
to_type->is_integer_scalar_or_vector())) {
op = spv::Op::OpBitcast;
} else if ((from_type->is_numeric_scalar() && to_type->Is<sem::Bool>()) ||
(from_type->is_numeric_vector() && to_type->is_bool_vector())) {
// Convert scalar (vector) to bool (vector)
// Return the result of comparing from_expr with zero
uint32_t zero = GenerateConstantNullIfNeeded(from_type);
const auto* from_elem_type = elem_type_of(from_type);
op = from_elem_type->is_integer_scalar() ? spv::Op::OpINotEqual
: spv::Op::OpFUnordNotEqual;
if (!push_function_inst(
op, {Operand::Int(result_type_id), Operand::Int(result_id),
Operand::Int(val_id), Operand::Int(zero)})) {
return 0;
}
return result_id;
} else {
TINT_ICE(builder_.Diagnostics()) << "Invalid from_type";
}
if (op == spv::Op::OpNop) {
error_ = "unable to determine conversion type for cast, from: " +
from_type->type_name() + " to: " + to_type->type_name();

View File

@@ -382,9 +382,11 @@ class Builder {
/// of the right type.
/// @param to_type the type we're casting too
/// @param from_expr the expression to cast
/// @param is_global_init if this is a global initializer
/// @returns the expression ID on success or 0 otherwise
uint32_t GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
ast::Expression* from_expr);
ast::Expression* from_expr,
bool is_global_init);
/// Generates a loop statement
/// @param stmt the statement to generate
/// @returns true on successful generation

View File

@@ -637,50 +637,56 @@ TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec2_With_F32) {
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec2_With_Vec2) {
auto* cast = vec2<f32>(vec2<f32>(2.0f, 2.0f));
WrapInFunction(cast);
GlobalConst("a", ty.vec2<f32>(), cast);
spirv::Builder& b = Build();
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
b.push_function(Function{});
EXPECT_EQ(b.GenerateConstructorExpression(nullptr, cast, true), 5u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
%2 = OpTypeVector %3 2
%4 = OpConstant %3 2
%5 = OpConstantComposite %2 %4 %4
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 2
%3 = OpConstant %2 2
%4 = OpConstantComposite %1 %3 %3
%6 = OpTypeVoid
%5 = OpTypeFunction %6
)");
Validate(b);
}
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec3_With_Vec3) {
auto* cast = vec3<f32>(vec3<f32>(2.0f, 2.0f, 2.0f));
WrapInFunction(cast);
GlobalConst("a", ty.vec3<f32>(), cast);
spirv::Builder& b = Build();
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
b.push_function(Function{});
EXPECT_EQ(b.GenerateConstructorExpression(nullptr, cast, true), 5u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
%2 = OpTypeVector %3 3
%4 = OpConstant %3 2
%5 = OpConstantComposite %2 %4 %4 %4
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
%3 = OpConstant %2 2
%4 = OpConstantComposite %1 %3 %3 %3
%6 = OpTypeVoid
%5 = OpTypeFunction %6
)");
Validate(b);
}
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec4_With_Vec4) {
auto* cast = vec4<f32>(vec4<f32>(2.0f, 2.0f, 2.0f, 2.0f));
WrapInFunction(cast);
GlobalConst("a", ty.vec4<f32>(), cast);
spirv::Builder& b = Build();
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
b.push_function(Function{});
EXPECT_EQ(b.GenerateConstructorExpression(nullptr, cast, true), 5u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
%2 = OpTypeVector %3 4
%4 = OpConstant %3 2
%5 = OpConstantComposite %2 %4 %4 %4 %4
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
%3 = OpConstant %2 2
%4 = OpConstantComposite %1 %3 %3 %3 %3
%6 = OpTypeVoid
%5 = OpTypeFunction %6
)");
Validate(b);
}
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec3_With_F32) {

View File

@@ -65,6 +65,24 @@ struct ScalarConstant {
return c;
}
/// @param value the value of the constant
/// @returns a new ScalarConstant with the provided value and kind Kind::kI32
static inline ScalarConstant I32(int32_t value) {
ScalarConstant c;
c.value.i32 = value;
c.kind = Kind::kI32;
return c;
}
/// @param value the value of the constant
/// @returns a new ScalarConstant with the provided value and kind Kind::kI32
static inline ScalarConstant F32(float value) {
ScalarConstant c;
c.value.f32 = value;
c.kind = Kind::kF32;
return c;
}
/// Equality operator
/// @param rhs the ScalarConstant to compare against
/// @returns true if this ScalarConstant is equal to `rhs`