[spirv-writer] Allow casting of vectors along with scalars.

The current `cast` conversion code only handles scalar types and fails
if provided with vectors. This CL updates the logic to accept scalars
along with the provided scalar cases.

Bug: tint:96
Change-Id: I60772e75286fc3ee7a9dfba6634db069062b22d0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23820
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-06-29 17:58:09 +00:00 committed by David Neto
parent 5b46d71ae7
commit 48bb366991
2 changed files with 241 additions and 6 deletions

View File

@ -1527,20 +1527,36 @@ uint32_t Builder::GenerateCastExpression(ast::CastExpression* cast) {
auto* from_type = cast->expr()->result_type()->UnwrapPtrIfNeeded();
spv::Op op = spv::Op::OpNop;
if (from_type->IsI32() && to_type->IsF32()) {
if ((from_type->IsI32() && to_type->IsF32()) ||
(from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
op = spv::Op::OpConvertSToF;
} else if (from_type->IsU32() && to_type->IsF32()) {
} else if ((from_type->IsU32() && to_type->IsF32()) ||
(from_type->is_unsigned_integer_vector() &&
to_type->is_float_vector())) {
op = spv::Op::OpConvertUToF;
} else if (from_type->IsF32() && to_type->IsI32()) {
} else if ((from_type->IsF32() && to_type->IsI32()) ||
(from_type->is_float_vector() &&
to_type->is_signed_integer_vector())) {
op = spv::Op::OpConvertFToS;
} else if (from_type->IsF32() && to_type->IsU32()) {
} else if ((from_type->IsF32() && to_type->IsU32()) ||
(from_type->is_float_vector() &&
to_type->is_unsigned_integer_vector())) {
op = spv::Op::OpConvertFToU;
} else if ((from_type->IsU32() && to_type->IsU32()) ||
(from_type->IsI32() && to_type->IsI32()) ||
(from_type->IsF32() && to_type->IsF32())) {
(from_type->IsF32() && to_type->IsF32()) ||
(from_type->is_unsigned_integer_vector() &&
to_type->is_unsigned_integer_vector()) ||
(from_type->is_signed_integer_vector() &&
to_type->is_signed_integer_vector()) ||
(from_type->is_float_vector() && to_type->is_float_vector())) {
op = spv::Op::OpCopyObject;
} else if ((from_type->IsI32() && to_type->IsU32()) ||
(from_type->IsU32() && to_type->IsI32())) {
(from_type->IsU32() && to_type->IsI32()) ||
(from_type->is_signed_integer_vector() &&
to_type->is_unsigned_integer_vector()) ||
(from_type->is_unsigned_integer_vector() &&
to_type->is_integer_scalar_or_vector())) {
op = spv::Op::OpBitcast;
}

View File

@ -22,6 +22,7 @@
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/uint_literal.h"
#include "src/context.h"
#include "src/type_determiner.h"
@ -329,6 +330,224 @@ TEST_F(BuilderTest, Cast_F32ToF32) {
)");
}
TEST_F(BuilderTest, Cast_Vectors_I32_to_F32) {
ast::type::I32Type i32;
ast::type::VectorType ivec3(&i32, 3);
ast::type::F32Type f32;
ast::type::VectorType fvec3(&f32, 3);
auto var =
std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &ivec3);
ast::CastExpression cast(&fvec3,
std::make_unique<ast::IdentifierExpression>("i"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
%8 = OpTypeFloat 32
%7 = OpTypeVector %8 3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%9 = OpLoad %3 %1
%6 = OpConvertSToF %7 %9
)");
}
TEST_F(BuilderTest, Cast_Vectors_U32_to_F32) {
ast::type::U32Type u32;
ast::type::VectorType uvec3(&u32, 3);
ast::type::F32Type f32;
ast::type::VectorType fvec3(&f32, 3);
auto var =
std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &uvec3);
ast::CastExpression cast(&fvec3,
std::make_unique<ast::IdentifierExpression>("i"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
%8 = OpTypeFloat 32
%7 = OpTypeVector %8 3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%9 = OpLoad %3 %1
%6 = OpConvertUToF %7 %9
)");
}
TEST_F(BuilderTest, Cast_Vectors_F32_to_I32) {
ast::type::I32Type i32;
ast::type::VectorType ivec3(&i32, 3);
ast::type::F32Type f32;
ast::type::VectorType fvec3(&f32, 3);
auto var =
std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &fvec3);
ast::CastExpression cast(&ivec3,
std::make_unique<ast::IdentifierExpression>("i"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
%8 = OpTypeInt 32 1
%7 = OpTypeVector %8 3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%9 = OpLoad %3 %1
%6 = OpConvertFToS %7 %9
)");
}
TEST_F(BuilderTest, Cast_Vectors_F32_to_U32) {
ast::type::U32Type u32;
ast::type::VectorType uvec3(&u32, 3);
ast::type::F32Type f32;
ast::type::VectorType fvec3(&f32, 3);
auto var =
std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &fvec3);
ast::CastExpression cast(&uvec3,
std::make_unique<ast::IdentifierExpression>("i"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
%8 = OpTypeInt 32 0
%7 = OpTypeVector %8 3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%9 = OpLoad %3 %1
%6 = OpConvertFToU %7 %9
)");
}
TEST_F(BuilderTest, Cast_Vectors_U32_to_U32) {
ast::type::U32Type u32;
ast::type::VectorType uvec3(&u32, 3);
auto var =
std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &uvec3);
ast::CastExpression cast(&uvec3,
std::make_unique<ast::IdentifierExpression>("i"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%7 = OpLoad %3 %1
%6 = OpCopyObject %3 %7
)");
}
TEST_F(BuilderTest, Cast_Vectors_I32_to_U32) {
ast::type::U32Type u32;
ast::type::VectorType uvec3(&u32, 3);
ast::type::I32Type i32;
ast::type::VectorType ivec3(&i32, 3);
auto var =
std::make_unique<ast::Variable>("i", ast::StorageClass::kPrivate, &ivec3);
ast::CastExpression cast(&uvec3,
std::make_unique<ast::IdentifierExpression>("i"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
%8 = OpTypeInt 32 0
%7 = OpTypeVector %8 3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%9 = OpLoad %3 %1
%6 = OpBitcast %7 %9
)");
}
} // namespace
} // namespace spirv
} // namespace writer