// dawn-cmake/src/writer/spirv/builder_binary_expression_test.cc
// David Neto 4c32dd9735 spirv-writer: Fix phi for short-circuiting operators
// The Phi in the merge block was taking the value of the RHS
// from the wrong basic block ID. Instead of taking it from
// the first block of the expression for the RHS, take it from
// the last block of the expression for the RHS.
//
// Bug: tint:355
// Change-Id: I1b79a1b107459fd420e39963ad7ab2e89bc4494f
// Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33640
// Commit-Queue: David Neto <dneto@google.com>
// Auto-Submit: David Neto <dneto@google.com>
// Reviewed-by: dan sinclair <dsinclair@chromium.org>
// 2020-11-23 17:17:35 +00:00
//
// 1096 lines
// 37 KiB
// C++

// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "gtest/gtest.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bool_literal.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/context.h"
#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
#include "src/writer/spirv/test_helper.h"
namespace tint {
namespace writer {
namespace spirv {
namespace {
using BuilderTest = TestHelper;
struct BinaryData {
ast::BinaryOp op;
std::string name;
};
inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
out << data.op;
return out;
}
using BinaryArithSignedIntegerTest = TestParamHelper<BinaryData>;

// Scalar i32 operands: each literal becomes an OpConstant and the operator is
// emitted as one instruction on those constants.
TEST_P(BinaryArithSignedIntegerTest, Scalar) {
  const auto& param = GetParam();
  ast::type::I32Type i32;

  // Wraps an i32 literal in a scalar constructor expression.
  auto sint = [&](int value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::SintLiteral>(&i32, value));
  };
  auto* lhs = sint(3);
  auto* rhs = sint(4);
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 4u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeInt 32 1\n"
      "%2 = OpConstant %1 3\n"
      "%3 = OpConstant %1 4\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%4 = " + param.name + " %1 %2 %3\n");
}
// vec3<i32> operands: both sides fold to the same OpConstantComposite (%4),
// which is then used for both operands of the emitted instruction.
TEST_P(BinaryArithSignedIntegerTest, Vector) {
  const auto& param = GetParam();
  ast::type::I32Type i32;
  ast::type::VectorType vec3(&i32, 3);

  // Builds the vector constructor vec3<i32>(1, 1, 1).
  auto ones = [&]() {
    ast::ExpressionList elements;
    for (int i = 0; i < 3; ++i) {
      elements.push_back(create<ast::ScalarConstructorExpression>(
          create<ast::SintLiteral>(&i32, 1)));
    }
    return create<ast::TypeConstructorExpression>(&vec3, elements);
  };
  auto* lhs = ones();
  auto* rhs = ones();
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeInt 32 1\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = " + param.name + " %1 %4 %4\n");
}
// Both operands name the same function-scope variable. Each use gets its own
// OpLoad, and the operator is applied to the two loaded values.
TEST_P(BinaryArithSignedIntegerTest, Scalar_Loads) {
  auto param = GetParam();
  ast::type::I32Type i32;
  ast::Variable var("param", ast::StorageClass::kFunction, &i32);
  auto* lhs = create<ast::IdentifierExpression>("param");
  auto* rhs = create<ast::IdentifierExpression>("param");
  ast::BinaryExpression expr(param.op, lhs, rhs);
  td.RegisterVariableForTesting(&var);
  // ASSERT rather than EXPECT, consistent with the other tests in this file.
  // Every later step needs the resolved types and the generated variable, so
  // there is nothing meaningful to check after either of these fails.
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
  b.push_function(Function{});
  ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 7u) << b.error();
  ASSERT_FALSE(b.has_error()) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
%2 = OpTypePointer Function %3
%4 = OpConstantNull %3
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
            R"(%1 = OpVariable %2 Function %4
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%5 = OpLoad %3 %1
%6 = OpLoad %3 %1
%7 = )" + param.name +
                R"( %3 %5 %6
)");
}
// Signed i32 arithmetic. Division, modulo and right shift select the signed
// opcodes (OpSDiv, OpSMod, OpShiftRightArithmetic); the remaining operators
// use the same opcodes as the unsigned suite.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryArithSignedIntegerTest,
    ::testing::Values(
        BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
        BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
        BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
        BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
        BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
        BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
        BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
        BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightArithmetic"},
        BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
        BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
using BinaryArithUnsignedIntegerTest = TestParamHelper<BinaryData>;

// Scalar u32 operands: each literal becomes an OpConstant and the operator is
// emitted as one instruction on those constants.
TEST_P(BinaryArithUnsignedIntegerTest, Scalar) {
  const auto& param = GetParam();
  ast::type::U32Type u32;

  // Wraps a u32 literal in a scalar constructor expression.
  auto uint = [&](unsigned value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::UintLiteral>(&u32, value));
  };
  auto* lhs = uint(3);
  auto* rhs = uint(4);
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 4u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeInt 32 0\n"
      "%2 = OpConstant %1 3\n"
      "%3 = OpConstant %1 4\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%4 = " + param.name + " %1 %2 %3\n");
}
// vec3<u32> operands: both sides fold to the same OpConstantComposite (%4),
// which is then used for both operands of the emitted instruction.
TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
  const auto& param = GetParam();
  ast::type::U32Type u32;
  ast::type::VectorType vec3(&u32, 3);

  // Builds the vector constructor vec3<u32>(1, 1, 1).
  auto ones = [&]() {
    ast::ExpressionList elements;
    for (int i = 0; i < 3; ++i) {
      elements.push_back(create<ast::ScalarConstructorExpression>(
          create<ast::UintLiteral>(&u32, 1)));
    }
    return create<ast::TypeConstructorExpression>(&vec3, elements);
  };
  auto* lhs = ones();
  auto* rhs = ones();
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeInt 32 0\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = " + param.name + " %1 %4 %4\n");
}
// Unsigned u32 arithmetic. Division, modulo and right shift select the
// unsigned opcodes (OpUDiv, OpUMod, OpShiftRightLogical); the remaining
// operators use the same opcodes as the signed suite.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryArithUnsignedIntegerTest,
    ::testing::Values(
        BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
        BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
        BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
        BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
        BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
        BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
        BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
        BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
        BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
        BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
using BinaryArithFloatTest = TestParamHelper<BinaryData>;

// Scalar f32 operands: each literal becomes an OpConstant and the operator is
// emitted as one instruction on those constants.
TEST_P(BinaryArithFloatTest, Scalar) {
  const auto& param = GetParam();
  ast::type::F32Type f32;

  // Wraps an f32 literal in a scalar constructor expression.
  auto flt = [&](float value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::FloatLiteral>(&f32, value));
  };
  auto* lhs = flt(3.2f);
  auto* rhs = flt(4.5f);
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 4u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeFloat 32\n"
      "%2 = OpConstant %1 3.20000005\n"
      "%3 = OpConstant %1 4.5\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%4 = " + param.name + " %1 %2 %3\n");
}
// vec3<f32> operands: both sides fold to the same OpConstantComposite (%4),
// which is then used for both operands of the emitted instruction.
TEST_P(BinaryArithFloatTest, Vector) {
  const auto& param = GetParam();
  ast::type::F32Type f32;
  ast::type::VectorType vec3(&f32, 3);

  // Builds the vector constructor vec3<f32>(1.0, 1.0, 1.0).
  auto ones = [&]() {
    ast::ExpressionList elements;
    for (int i = 0; i < 3; ++i) {
      elements.push_back(create<ast::ScalarConstructorExpression>(
          create<ast::FloatLiteral>(&f32, 1.f)));
    }
    return create<ast::TypeConstructorExpression>(&vec3, elements);
  };
  auto* lhs = ones();
  auto* rhs = ones();
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeFloat 32\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = " + param.name + " %1 %4 %4\n");
}
// f32 arithmetic: each operator maps to its OpF* opcode. Bitwise and shift
// operators are not part of this suite.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryArithFloatTest,
    ::testing::Values(
        BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
        BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
        BinaryData{ast::BinaryOp::kModulo, "OpFMod"},
        BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
        BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryCompareUnsignedIntegerTest = TestParamHelper<BinaryData>;

// Scalar u32 comparison: the result type is bool (%5), which is emitted after
// the operand constants.
TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
  const auto& param = GetParam();
  ast::type::U32Type u32;

  // Wraps a u32 literal in a scalar constructor expression.
  auto uint = [&](unsigned value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::UintLiteral>(&u32, value));
  };
  auto* lhs = uint(3);
  auto* rhs = uint(4);
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 4u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeInt 32 0\n"
      "%2 = OpConstant %1 3\n"
      "%3 = OpConstant %1 4\n"
      "%5 = OpTypeBool\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%4 = " + param.name + " %5 %2 %3\n");
}
// vec3<u32> comparison: the result type is vec3<bool> (%6). Both operands
// share the constant composite %4.
TEST_P(BinaryCompareUnsignedIntegerTest, Vector) {
  const auto& param = GetParam();
  ast::type::U32Type u32;
  ast::type::VectorType vec3(&u32, 3);

  // Builds the vector constructor vec3<u32>(1, 1, 1).
  auto ones = [&]() {
    ast::ExpressionList elements;
    for (int i = 0; i < 3; ++i) {
      elements.push_back(create<ast::ScalarConstructorExpression>(
          create<ast::UintLiteral>(&u32, 1)));
    }
    return create<ast::TypeConstructorExpression>(&vec3, elements);
  };
  auto* lhs = ones();
  auto* rhs = ones();
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeInt 32 0\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n"
      "%7 = OpTypeBool\n"
      "%6 = OpTypeVector %7 3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = " + param.name + " %6 %4 %4\n");
}
// Unsigned comparisons. Equality and inequality use the sign-agnostic
// OpIEqual / OpINotEqual; the ordering operators use the OpU* forms.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryCompareUnsignedIntegerTest,
    ::testing::Values(
        BinaryData{ast::BinaryOp::kEqual, "OpIEqual"},
        BinaryData{ast::BinaryOp::kGreaterThan, "OpUGreaterThan"},
        BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpUGreaterThanEqual"},
        BinaryData{ast::BinaryOp::kLessThan, "OpULessThan"},
        BinaryData{ast::BinaryOp::kLessThanEqual, "OpULessThanEqual"},
        BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
using BinaryCompareSignedIntegerTest = TestParamHelper<BinaryData>;

// Scalar i32 comparison: the result type is bool (%5), which is emitted after
// the operand constants.
TEST_P(BinaryCompareSignedIntegerTest, Scalar) {
  const auto& param = GetParam();
  ast::type::I32Type i32;

  // Wraps an i32 literal in a scalar constructor expression.
  auto sint = [&](int value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::SintLiteral>(&i32, value));
  };
  auto* lhs = sint(3);
  auto* rhs = sint(4);
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 4u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeInt 32 1\n"
      "%2 = OpConstant %1 3\n"
      "%3 = OpConstant %1 4\n"
      "%5 = OpTypeBool\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%4 = " + param.name + " %5 %2 %3\n");
}
// vec3<i32> comparison: the result type is vec3<bool> (%6). Both operands
// share the constant composite %4.
TEST_P(BinaryCompareSignedIntegerTest, Vector) {
  const auto& param = GetParam();
  ast::type::I32Type i32;
  ast::type::VectorType vec3(&i32, 3);

  // Builds the vector constructor vec3<i32>(1, 1, 1).
  auto ones = [&]() {
    ast::ExpressionList elements;
    for (int i = 0; i < 3; ++i) {
      elements.push_back(create<ast::ScalarConstructorExpression>(
          create<ast::SintLiteral>(&i32, 1)));
    }
    return create<ast::TypeConstructorExpression>(&vec3, elements);
  };
  auto* lhs = ones();
  auto* rhs = ones();
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeInt 32 1\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n"
      "%7 = OpTypeBool\n"
      "%6 = OpTypeVector %7 3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = " + param.name + " %6 %4 %4\n");
}
// Signed comparisons. Equality and inequality use the sign-agnostic
// OpIEqual / OpINotEqual; the ordering operators use the OpS* forms.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryCompareSignedIntegerTest,
    ::testing::Values(
        BinaryData{ast::BinaryOp::kEqual, "OpIEqual"},
        BinaryData{ast::BinaryOp::kGreaterThan, "OpSGreaterThan"},
        BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpSGreaterThanEqual"},
        BinaryData{ast::BinaryOp::kLessThan, "OpSLessThan"},
        BinaryData{ast::BinaryOp::kLessThanEqual, "OpSLessThanEqual"},
        BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
using BinaryCompareFloatTest = TestParamHelper<BinaryData>;

// Scalar f32 comparison: the result type is bool (%5), which is emitted after
// the operand constants.
TEST_P(BinaryCompareFloatTest, Scalar) {
  const auto& param = GetParam();
  ast::type::F32Type f32;

  // Wraps an f32 literal in a scalar constructor expression.
  auto flt = [&](float value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::FloatLiteral>(&f32, value));
  };
  auto* lhs = flt(3.2f);
  auto* rhs = flt(4.5f);
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 4u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeFloat 32\n"
      "%2 = OpConstant %1 3.20000005\n"
      "%3 = OpConstant %1 4.5\n"
      "%5 = OpTypeBool\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%4 = " + param.name + " %5 %2 %3\n");
}
// vec3<f32> comparison: the result type is vec3<bool> (%6). Both operands
// share the constant composite %4.
TEST_P(BinaryCompareFloatTest, Vector) {
  const auto& param = GetParam();
  ast::type::F32Type f32;
  ast::type::VectorType vec3(&f32, 3);

  // Builds the vector constructor vec3<f32>(1.0, 1.0, 1.0).
  auto ones = [&]() {
    ast::ExpressionList elements;
    for (int i = 0; i < 3; ++i) {
      elements.push_back(create<ast::ScalarConstructorExpression>(
          create<ast::FloatLiteral>(&f32, 1.f)));
    }
    return create<ast::TypeConstructorExpression>(&vec3, elements);
  };
  auto* lhs = ones();
  auto* rhs = ones();
  ast::BinaryExpression expr(param.op, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeFloat 32\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n"
      "%7 = OpTypeBool\n"
      "%6 = OpTypeVector %7 3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = " + param.name + " %6 %4 %4\n");
}
// f32 comparisons all use the ordered OpFOrd* forms.
// NOTE(review): OpFOrdNotEqual yields false when either operand is NaN.
// Confirm this matches the intended WGSL semantics of `!=` (an unordered
// OpFUnordNotEqual would yield true). Any change must be made in the builder
// and this expectation together.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryCompareFloatTest,
    ::testing::Values(
        BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
        BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
        BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
        BinaryData{ast::BinaryOp::kLessThan, "OpFOrdLessThan"},
        BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
        BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
// vec3<f32> * f32 lowers to OpVectorTimesScalar with the vector first.
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
  ast::type::F32Type f32;
  ast::type::VectorType vec3(&f32, 3);

  // Wraps the f32 literal 1.0 in a scalar constructor expression.
  auto one = [&]() {
    return create<ast::ScalarConstructorExpression>(
        create<ast::FloatLiteral>(&f32, 1.f));
  };
  auto* lhs = create<ast::TypeConstructorExpression>(
      &vec3, ast::ExpressionList{one(), one(), one()});
  auto* rhs = one();
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeFloat 32\n"
      "%1 = OpTypeVector %2 3\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstantComposite %1 %3 %3 %3\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = OpVectorTimesScalar %1 %4 %3\n");
}
// f32 * vec3<f32> also lowers to OpVectorTimesScalar. The builder swaps the
// operands so the vector (%4) comes first and the scalar (%2) second.
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
  ast::type::F32Type f32;
  ast::type::VectorType vec3(&f32, 3);

  // Wraps the f32 literal 1.0 in a scalar constructor expression.
  auto one = [&]() {
    return create<ast::ScalarConstructorExpression>(
        create<ast::FloatLiteral>(&f32, 1.f));
  };
  auto* lhs = one();
  auto* rhs = create<ast::TypeConstructorExpression>(
      &vec3, ast::ExpressionList{one(), one(), one()});
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 5u) << b.error();

  const std::string expected_types =
      "%1 = OpTypeFloat 32\n"
      "%2 = OpConstant %1 1\n"
      "%3 = OpTypeVector %1 3\n"
      "%4 = OpConstantComposite %3 %2 %2 %2\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            "%5 = OpVectorTimesScalar %3 %4 %2\n");
}
// 3x3 f32 matrix * f32: the matrix is loaded, then OpMatrixTimesScalar is
// emitted. The variable is generated through GenerateGlobalVariable, so its
// OpVariable appears in the types/globals dump.
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
  ast::type::F32Type f32;
  ast::type::MatrixType mat3(&f32, 3, 3);

  auto* mat_var =
      create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
  td.RegisterVariableForTesting(mat_var);

  auto* lhs = create<ast::IdentifierExpression>("mat");
  auto* rhs = create<ast::ScalarConstructorExpression>(
      create<ast::FloatLiteral>(&f32, 1.f));
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  ASSERT_TRUE(b.GenerateGlobalVariable(mat_var)) << b.error();
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 8u) << b.error();

  const std::string expected_types =
      "%5 = OpTypeFloat 32\n"
      "%4 = OpTypeVector %5 3\n"
      "%3 = OpTypeMatrix %4 3\n"
      "%2 = OpTypePointer Function %3\n"
      "%1 = OpVariable %2 Function\n"
      "%7 = OpConstant %5 1\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%6 = OpLoad %3 %1\n"
      "%8 = OpMatrixTimesScalar %3 %6 %7\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// f32 * 3x3 f32 matrix also lowers to OpMatrixTimesScalar, with the loaded
// matrix (%7) moved to the first operand and the scalar (%6) second.
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
  ast::type::F32Type f32;
  ast::type::MatrixType mat3(&f32, 3, 3);

  auto* mat_var =
      create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
  td.RegisterVariableForTesting(mat_var);

  auto* lhs = create<ast::ScalarConstructorExpression>(
      create<ast::FloatLiteral>(&f32, 1.f));
  auto* rhs = create<ast::IdentifierExpression>("mat");
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  ASSERT_TRUE(b.GenerateGlobalVariable(mat_var)) << b.error();
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 8u) << b.error();

  const std::string expected_types =
      "%5 = OpTypeFloat 32\n"
      "%4 = OpTypeVector %5 3\n"
      "%3 = OpTypeMatrix %4 3\n"
      "%2 = OpTypePointer Function %3\n"
      "%1 = OpVariable %2 Function\n"
      "%6 = OpConstant %5 1\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%7 = OpLoad %3 %1\n"
      "%8 = OpMatrixTimesScalar %3 %7 %6\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// 3x3 f32 matrix * vec3<f32> lowers to OpMatrixTimesVector, producing a
// vec3<f32> (%4).
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
  ast::type::F32Type f32;
  ast::type::VectorType vec3(&f32, 3);
  ast::type::MatrixType mat3(&f32, 3, 3);

  auto* mat_var =
      create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
  td.RegisterVariableForTesting(mat_var);

  // Wraps the f32 literal 1.0 in a scalar constructor expression.
  auto one = [&]() {
    return create<ast::ScalarConstructorExpression>(
        create<ast::FloatLiteral>(&f32, 1.f));
  };
  auto* lhs = create<ast::IdentifierExpression>("mat");
  auto* rhs = create<ast::TypeConstructorExpression>(
      &vec3, ast::ExpressionList{one(), one(), one()});
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  ASSERT_TRUE(b.GenerateGlobalVariable(mat_var)) << b.error();
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 9u) << b.error();

  const std::string expected_types =
      "%5 = OpTypeFloat 32\n"
      "%4 = OpTypeVector %5 3\n"
      "%3 = OpTypeMatrix %4 3\n"
      "%2 = OpTypePointer Function %3\n"
      "%1 = OpVariable %2 Function\n"
      "%7 = OpConstant %5 1\n"
      "%8 = OpConstantComposite %4 %7 %7 %7\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%6 = OpLoad %3 %1\n"
      "%9 = OpMatrixTimesVector %4 %6 %8\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// vec3<f32> * 3x3 f32 matrix lowers to OpVectorTimesMatrix, with the vector
// (%7) first and the loaded matrix (%8) second.
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
  ast::type::F32Type f32;
  ast::type::VectorType vec3(&f32, 3);
  ast::type::MatrixType mat3(&f32, 3, 3);

  auto* mat_var =
      create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
  td.RegisterVariableForTesting(mat_var);

  // Wraps the f32 literal 1.0 in a scalar constructor expression.
  auto one = [&]() {
    return create<ast::ScalarConstructorExpression>(
        create<ast::FloatLiteral>(&f32, 1.f));
  };
  auto* lhs = create<ast::TypeConstructorExpression>(
      &vec3, ast::ExpressionList{one(), one(), one()});
  auto* rhs = create<ast::IdentifierExpression>("mat");
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  ASSERT_TRUE(b.GenerateGlobalVariable(mat_var)) << b.error();
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 9u) << b.error();

  const std::string expected_types =
      "%5 = OpTypeFloat 32\n"
      "%4 = OpTypeVector %5 3\n"
      "%3 = OpTypeMatrix %4 3\n"
      "%2 = OpTypePointer Function %3\n"
      "%1 = OpVariable %2 Function\n"
      "%6 = OpConstant %5 1\n"
      "%7 = OpConstantComposite %4 %6 %6 %6\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%8 = OpLoad %3 %1\n"
      "%9 = OpVectorTimesMatrix %4 %7 %8\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// mat * mat: the matrix variable is loaded once per operand (%6, %7), then
// OpMatrixTimesMatrix is emitted. No separate vector type is declared here;
// the column vector type (%4) is emitted from the matrix type itself.
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
  ast::type::F32Type f32;
  ast::type::MatrixType mat3(&f32, 3, 3);
  auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
  auto* lhs = create<ast::IdentifierExpression>("mat");
  auto* rhs = create<ast::IdentifierExpression>("mat");
  td.RegisterVariableForTesting(var);
  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
  b.push_function(Function{});
  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8u) << b.error();
  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
)");
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            R"(%6 = OpLoad %3 %1
%7 = OpLoad %3 %1
%8 = OpMatrixTimesMatrix %3 %6 %7
)");
}
// (1 == 2) && (3 == 4) short-circuits. The RHS is evaluated in its own block
// (%8), which is entered only when the LHS (%5) is true. The merge block's
// OpPhi takes the LHS value from the entry block (%1) and the RHS value from
// %8.
TEST_F(BuilderTest, Binary_LogicalAnd) {
  ast::type::I32Type i32;

  // Builds the i32 comparison `left == right`.
  auto int_equal = [&](int left, int right) {
    auto* l = create<ast::ScalarConstructorExpression>(
        create<ast::SintLiteral>(&i32, left));
    auto* r = create<ast::ScalarConstructorExpression>(
        create<ast::SintLiteral>(&i32, right));
    return create<ast::BinaryExpression>(ast::BinaryOp::kEqual, l, r);
  };
  auto* lhs = int_equal(1, 2);
  auto* rhs = int_equal(3, 4);
  ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  b.GenerateLabel(b.next_id());
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 12u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeInt 32 1\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstant %2 2\n"
      "%6 = OpTypeBool\n"
      "%9 = OpConstant %2 3\n"
      "%10 = OpConstant %2 4\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%1 = OpLabel\n"
      "%5 = OpIEqual %6 %3 %4\n"
      "OpSelectionMerge %7 None\n"
      "OpBranchConditional %5 %8 %7\n"
      "%8 = OpLabel\n"
      "%11 = OpIEqual %6 %9 %10\n"
      "OpBranch %7\n"
      "%7 = OpLabel\n"
      "%12 = OpPhi %6 %5 %1 %11 %8\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// a && b with both operands loaded from bool variables. The load of `b`
// (%11) happens only inside the RHS block (%10), which is entered when `a`
// (%8) is true.
TEST_F(BuilderTest, Binary_LogicalAnd_WithLoads) {
  ast::type::BoolType bool_type;

  // Declares a function-scope bool variable initialized to `value`.
  auto make_bool_var = [&](const char* name, bool value) {
    auto* v = create<ast::Variable>(name, ast::StorageClass::kFunction,
                                    &bool_type);
    v->set_constructor(create<ast::ScalarConstructorExpression>(
        create<ast::BoolLiteral>(&bool_type, value)));
    return v;
  };
  auto* a_var = make_bool_var("a", true);
  auto* b_var = make_bool_var("b", false);

  auto* lhs = create<ast::IdentifierExpression>("a");
  auto* rhs = create<ast::IdentifierExpression>("b");
  td.RegisterVariableForTesting(a_var);
  td.RegisterVariableForTesting(b_var);
  ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  b.GenerateLabel(b.next_id());
  ASSERT_TRUE(b.GenerateGlobalVariable(a_var)) << b.error();
  ASSERT_TRUE(b.GenerateGlobalVariable(b_var)) << b.error();
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 12u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeBool\n"
      "%3 = OpConstantTrue %2\n"
      "%5 = OpTypePointer Function %2\n"
      "%4 = OpVariable %5 Function %3\n"
      "%6 = OpConstantFalse %2\n"
      "%7 = OpVariable %5 Function %6\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%1 = OpLabel\n"
      "%8 = OpLoad %2 %4\n"
      "OpSelectionMerge %9 None\n"
      "OpBranchConditional %8 %10 %9\n"
      "%10 = OpLabel\n"
      "%11 = OpLoad %2 %7\n"
      "OpBranch %9\n"
      "%9 = OpLabel\n"
      "%12 = OpPhi %2 %8 %1 %11 %10\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// Test an expression like
//    a || (b && c)
// From: crbug.com/tint/355
// The RHS is itself a short-circuit, so it spans several blocks (%5, %7, %6).
// The outer OpPhi in %4 must take the RHS value (%9) from %6, the last block
// of the RHS, not from %5 where the RHS starts.
TEST_F(BuilderTest, Binary_logicalOr_Nested_LogicalAnd) {
  ast::type::BoolType bool_ty;

  // Wraps a bool literal in a scalar constructor expression.
  auto boolean = [&](bool value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::BoolLiteral>(&bool_ty, value));
  };
  auto* and_lhs = boolean(true);
  auto* and_rhs = boolean(false);
  auto* logical_and_expr = create<ast::BinaryExpression>(
      ast::BinaryOp::kLogicalAnd, and_lhs, and_rhs);
  auto* or_lhs = boolean(true);
  ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, or_lhs,
                             logical_and_expr);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  b.GenerateLabel(b.next_id());
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 10u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeBool\n"
      "%3 = OpConstantTrue %2\n"
      "%8 = OpConstantFalse %2\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%1 = OpLabel\n"
      "OpSelectionMerge %4 None\n"
      "OpBranchConditional %3 %4 %5\n"
      "%5 = OpLabel\n"
      "OpSelectionMerge %6 None\n"
      "OpBranchConditional %3 %7 %6\n"
      "%7 = OpLabel\n"
      "OpBranch %6\n"
      "%6 = OpLabel\n"
      "%9 = OpPhi %2 %3 %5 %8 %7\n"
      "OpBranch %4\n"
      "%4 = OpLabel\n"
      "%10 = OpPhi %2 %3 %1 %9 %6\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// Test an expression like
//    a && (b || c)
// From: crbug.com/tint/355
// The RHS is itself a short-circuit, so it spans several blocks (%5, %7, %6).
// The outer OpPhi in %4 must take the RHS value (%9) from %6, the last block
// of the RHS, not from %5 where the RHS starts.
TEST_F(BuilderTest, Binary_logicalAnd_Nested_LogicalOr) {
  ast::type::BoolType bool_ty;

  // Wraps a bool literal in a scalar constructor expression.
  auto boolean = [&](bool value) {
    return create<ast::ScalarConstructorExpression>(
        create<ast::BoolLiteral>(&bool_ty, value));
  };
  auto* or_lhs = boolean(true);
  auto* or_rhs = boolean(false);
  auto* logical_or_expr = create<ast::BinaryExpression>(
      ast::BinaryOp::kLogicalOr, or_lhs, or_rhs);
  auto* and_lhs = boolean(true);
  ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, and_lhs,
                             logical_or_expr);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  b.GenerateLabel(b.next_id());
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 10u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeBool\n"
      "%3 = OpConstantTrue %2\n"
      "%8 = OpConstantFalse %2\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%1 = OpLabel\n"
      "OpSelectionMerge %4 None\n"
      "OpBranchConditional %3 %5 %4\n"
      "%5 = OpLabel\n"
      "OpSelectionMerge %6 None\n"
      "OpBranchConditional %3 %6 %7\n"
      "%7 = OpLabel\n"
      "OpBranch %6\n"
      "%6 = OpLabel\n"
      "%9 = OpPhi %2 %3 %5 %8 %7\n"
      "OpBranch %4\n"
      "%4 = OpLabel\n"
      "%10 = OpPhi %2 %3 %1 %9 %6\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// (1 == 2) || (3 == 4) short-circuits. Compared with the && case, the branch
// targets are swapped: the RHS block (%8) is entered only when the LHS (%5) is
// false. The merge OpPhi takes the LHS value from %1 and the RHS value from
// %8.
TEST_F(BuilderTest, Binary_LogicalOr) {
  ast::type::I32Type i32;

  // Builds the i32 comparison `left == right`.
  auto int_equal = [&](int left, int right) {
    auto* l = create<ast::ScalarConstructorExpression>(
        create<ast::SintLiteral>(&i32, left));
    auto* r = create<ast::ScalarConstructorExpression>(
        create<ast::SintLiteral>(&i32, right));
    return create<ast::BinaryExpression>(ast::BinaryOp::kEqual, l, r);
  };
  auto* lhs = int_equal(1, 2);
  auto* rhs = int_equal(3, 4);
  ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  b.GenerateLabel(b.next_id());
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 12u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeInt 32 1\n"
      "%3 = OpConstant %2 1\n"
      "%4 = OpConstant %2 2\n"
      "%6 = OpTypeBool\n"
      "%9 = OpConstant %2 3\n"
      "%10 = OpConstant %2 4\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%1 = OpLabel\n"
      "%5 = OpIEqual %6 %3 %4\n"
      "OpSelectionMerge %7 None\n"
      "OpBranchConditional %5 %7 %8\n"
      "%8 = OpLabel\n"
      "%11 = OpIEqual %6 %9 %10\n"
      "OpBranch %7\n"
      "%7 = OpLabel\n"
      "%12 = OpPhi %6 %5 %1 %11 %8\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
// a || b with both operands loaded from bool variables. The load of `b`
// (%11) happens only inside the RHS block (%10), which is entered when `a`
// (%8) is false.
TEST_F(BuilderTest, Binary_LogicalOr_WithLoads) {
  ast::type::BoolType bool_type;

  // Declares a function-scope bool variable initialized to `value`.
  auto make_bool_var = [&](const char* name, bool value) {
    auto* v = create<ast::Variable>(name, ast::StorageClass::kFunction,
                                    &bool_type);
    v->set_constructor(create<ast::ScalarConstructorExpression>(
        create<ast::BoolLiteral>(&bool_type, value)));
    return v;
  };
  auto* a_var = make_bool_var("a", true);
  auto* b_var = make_bool_var("b", false);

  auto* lhs = create<ast::IdentifierExpression>("a");
  auto* rhs = create<ast::IdentifierExpression>("b");
  td.RegisterVariableForTesting(a_var);
  td.RegisterVariableForTesting(b_var);
  ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, lhs, rhs);
  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();

  b.push_function(Function{});
  b.GenerateLabel(b.next_id());
  ASSERT_TRUE(b.GenerateGlobalVariable(a_var)) << b.error();
  ASSERT_TRUE(b.GenerateGlobalVariable(b_var)) << b.error();
  const auto result_id = b.GenerateBinaryExpression(&expr);
  EXPECT_EQ(result_id, 12u) << b.error();

  const std::string expected_types =
      "%2 = OpTypeBool\n"
      "%3 = OpConstantTrue %2\n"
      "%5 = OpTypePointer Function %2\n"
      "%4 = OpVariable %5 Function %3\n"
      "%6 = OpConstantFalse %2\n"
      "%7 = OpVariable %5 Function %6\n";
  EXPECT_EQ(DumpInstructions(b.types()), expected_types);

  const std::string expected_instructions =
      "%1 = OpLabel\n"
      "%8 = OpLoad %2 %4\n"
      "OpSelectionMerge %9 None\n"
      "OpBranchConditional %8 %9 %10\n"
      "%10 = OpLabel\n"
      "%11 = OpLoad %2 %7\n"
      "OpBranch %9\n"
      "%9 = OpLabel\n"
      "%12 = OpPhi %2 %8 %1 %11 %10\n";
  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
            expected_instructions);
}
} // namespace
} // namespace spirv
} // namespace writer
} // namespace tint