tint: impement short-circuiting for const eval of logical and/or

Bug: tint:1581
Change-Id: I44852bfeb0e55771009a89ed199ea60ca51e8477
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113431
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2022-12-09 21:28:51 +00:00 committed by Dawn LUCI CQ
parent 24c8440eb6
commit 28779af91c
13 changed files with 1130 additions and 45 deletions

View File

@ -1977,6 +1977,16 @@ class ProgramBuilder {
return create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, Expr(std::forward<EXPR>(expr)));
}
/// @param source the source information
/// @param expr the expression to perform a unary not on
/// @return an ast::UnaryOpExpression that is the unary not of the input
/// expression
template <typename EXPR>
const ast::UnaryOpExpression* Not(const Source& source, EXPR&& expr) {
return create<ast::UnaryOpExpression>(source, ast::UnaryOp::kNot,
Expr(std::forward<EXPR>(expr)));
}
/// @param expr the expression to perform a unary complement on
/// @return an ast::UnaryOpExpression that is the unary complement of the
/// input expression
@ -2121,6 +2131,17 @@ class ProgramBuilder {
Expr(std::forward<RHS>(rhs)));
}
/// @param source the source information
/// @param lhs the left hand argument to the division operation
/// @param rhs the right hand argument to the division operation
/// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* Div(const Source& source, LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(source, ast::BinaryOp::kDivide,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the modulo operation
/// @param rhs the right hand argument to the modulo operation
/// @returns a `ast::BinaryExpression` applying modulo of `lhs` by `rhs`
@ -2177,6 +2198,17 @@ class ProgramBuilder {
ast::BinaryOp::kLogicalAnd, Expr(std::forward<LHS>(lhs)), Expr(std::forward<RHS>(rhs)));
}
/// @param source the source information
/// @param lhs the left hand argument to the logical and operation
/// @param rhs the right hand argument to the logical and operation
/// @returns a `ast::BinaryExpression` of `lhs` && `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* LogicalAnd(const Source& source, LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalAnd,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the logical or operation
/// @param rhs the right hand argument to the logical or operation
/// @returns a `ast::BinaryExpression` of `lhs` || `rhs`
@ -2186,6 +2218,17 @@ class ProgramBuilder {
ast::BinaryOp::kLogicalOr, Expr(std::forward<LHS>(lhs)), Expr(std::forward<RHS>(rhs)));
}
/// @param source the source information
/// @param lhs the left hand argument to the logical or operation
/// @param rhs the right hand argument to the logical or operation
/// @returns a `ast::BinaryExpression` of `lhs` || `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* LogicalOr(const Source& source, LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalOr,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the greater than operation
/// @param rhs the right hand argument to the greater than operation
/// @returns a `ast::BinaryExpression` of `lhs` > `rhs`
@ -2234,6 +2277,17 @@ class ProgramBuilder {
Expr(std::forward<RHS>(rhs)));
}
/// @param source the source information
/// @param lhs the left hand argument to the equal expression
/// @param rhs the right hand argument to the equal expression
/// @returns a `ast::BinaryExpression` comparing `lhs` equal to `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* Equal(const Source& source, LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(source, ast::BinaryOp::kEqual,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the not-equal expression
/// @param rhs the right hand argument to the not-equal expression
/// @returns a `ast::BinaryExpression` comparing `lhs` equal to `rhs` for

View File

@ -1814,13 +1814,17 @@ ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty,
ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is true, so we could
// technically only return the value of the rhs.
return CreateElement(builder, source, ty, args[0]->As<bool>() && args[1]->As<bool>());
}
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
return CreateElement(builder, source, ty, args[0]->As<bool>() || args[1]->As<bool>());
// Note: Due to short-circuiting, this function is only called if lhs is false, so we could
// technically only return the value of the rhs.
return CreateElement(builder, source, ty, args[1]->As<bool>());
}
ConstEval::Result ConstEval::OpAnd(const type::Type* ty,

View File

@ -14,6 +14,7 @@
#include "src/tint/resolver/const_eval_test.h"
#include "src/tint/reader/wgsl/parser.h"
#include "src/tint/utils/result.h"
using namespace tint::number_suffixes; // NOLINT
@ -1366,5 +1367,917 @@ INSTANTIATE_TEST_SUITE_P(ShiftRight,
ShiftRightCases<i32>(), //
ShiftRightCases<u32>()))));
namespace LogicalShortCircuit {
/// Validates that `binary` is a short-circuiting logical and expression
static void ValidateAnd(const sem::Info& sem, const ast::BinaryExpression* binary) {
auto* lhs = binary->lhs;
auto* rhs = binary->rhs;
auto* lhs_sem = sem.Get(lhs);
ASSERT_TRUE(lhs_sem->ConstantValue());
EXPECT_EQ(lhs_sem->ConstantValue()->As<bool>(), false);
EXPECT_EQ(lhs_sem->Stage(), sem::EvaluationStage::kConstant);
auto* rhs_sem = sem.Get(rhs);
EXPECT_EQ(rhs_sem->ConstantValue(), nullptr);
EXPECT_EQ(rhs_sem->Stage(), sem::EvaluationStage::kNotEvaluated);
auto* binary_sem = sem.Get(binary);
ASSERT_TRUE(binary_sem->ConstantValue());
EXPECT_EQ(binary_sem->ConstantValue()->As<bool>(), false);
EXPECT_EQ(binary_sem->Stage(), sem::EvaluationStage::kConstant);
}
/// Validates that `binary` is a short-circuiting logical or expression
static void ValidateOr(const sem::Info& sem, const ast::BinaryExpression* binary) {
auto* lhs = binary->lhs;
auto* rhs = binary->rhs;
auto* lhs_sem = sem.Get(lhs);
ASSERT_TRUE(lhs_sem->ConstantValue());
EXPECT_EQ(lhs_sem->ConstantValue()->As<bool>(), true);
EXPECT_EQ(lhs_sem->Stage(), sem::EvaluationStage::kConstant);
auto* rhs_sem = sem.Get(rhs);
EXPECT_EQ(rhs_sem->ConstantValue(), nullptr);
EXPECT_EQ(rhs_sem->Stage(), sem::EvaluationStage::kNotEvaluated);
auto* binary_sem = sem.Get(binary);
ASSERT_TRUE(binary_sem->ConstantValue());
EXPECT_EQ(binary_sem->ConstantValue()->As<bool>(), true);
EXPECT_EQ(binary_sem->Stage(), sem::EvaluationStage::kConstant);
}
// Naming convention for tests below:
//
// [Non]ShortCircuit_[And|Or]_[Error|Invalid]_<Op>
//
// Where:
// ShortCircuit: the rhs will not be const-evaluated
// NonShortCircuitL the rhs will be const-evaluated
//
// And/Or: type of binary expression
//
// Error: a non-const evaluation error (e.g. parser or validation error)
// Invalid: a const-evaluation error
//
// <Op> the type of operation on the rhs that may or may not be short-circuited.
////////////////////////////////////////////////
// Short-Circuit Unary
////////////////////////////////////////////////
// NOTE: Cannot demonstrate short-circuiting an invalid unary op as const eval of unary does not
// fail.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Unary) {
// const one = 1;
// const result = (one == 0) && (!0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Not(Source{{12, 34}}, 0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator ! (abstract-int)
2 candidate operators:
operator ! (bool) -> bool
operator ! (vecN<bool>) -> vecN<bool>
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Unary) {
// const one = 1;
// const result = (one == 1) || (!0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Not(Source{{12, 34}}, 0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator ! (abstract-int)
2 candidate operators:
operator ! (bool) -> bool
operator ! (vecN<bool>) -> vecN<bool>
)");
}
////////////////////////////////////////////////
// Short-Circuit Binary
////////////////////////////////////////////////
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Binary) {
// const one = 1;
// const result = (one == 0) && ((2 / 0) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Div(2_a, 0_a), 0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateAnd(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Binary) {
// const one = 1;
// const result = (one == 1) && ((2 / 0) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Div(Source{{12, 34}}, 2_a, 0_a), 0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: '2 / 0' cannot be represented as 'abstract-int'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Binary) {
// const one = 1;
// const result = (one == 0) && (2 / 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Div(2_a, 0_a);
auto* binary = LogicalAnd(Source{{12, 34}}, lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator && (bool, abstract-int)
1 candidate operator:
operator && (bool, bool) -> bool
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Binary) {
// const one = 1;
// const result = (one == 1) || ((2 / 0) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Div(2_a, 0_a), 0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateOr(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Binary) {
// const one = 1;
// const result = (one == 0) || ((2 / 0) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Div(Source{{12, 34}}, 2_a, 0_a), 0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: '2 / 0' cannot be represented as 'abstract-int'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Binary) {
// const one = 1;
// const result = (one == 1) || (2 / 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Div(2_a, 0_a);
auto* binary = LogicalOr(Source{{12, 34}}, lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator || (bool, abstract-int)
1 candidate operator:
operator || (bool, bool) -> bool
)");
}
////////////////////////////////////////////////
// Short-Circuit Materialize
////////////////////////////////////////////////
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Materialize) {
// const one = 1;
// const result = (one == 0) && (1.7976931348623157e+308 == 0.0f);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Expr(1.7976931348623157e+308_a), 0_f);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateAnd(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Materialize) {
// const one = 1;
// const result = (one == 1) && (1.7976931348623157e+308 == 0.0f);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Expr(Source{{12, 34}}, 1.7976931348623157e+308_a), 0_f);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: value 1.7976931348623157081e+308 cannot be represented as 'f32'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Materialize) {
// const one = 1;
// const result = (one == 0) && (1.7976931348623157e+308 == 0i);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Source{{12, 34}}, 1.7976931348623157e+308_a, 0_i);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator == (abstract-float, i32)
2 candidate operators:
operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Materialize) {
// const one = 1;
// const result = (one == 1) || (1.7976931348623157e+308 == 0.0f);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(1.7976931348623157e+308_a, 0_f);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateOr(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Materialize) {
// const one = 1;
// const result = (one == 0) || (1.7976931348623157e+308 == 0.0f);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Expr(Source{{12, 34}}, 1.7976931348623157e+308_a), 0_f);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: value 1.7976931348623157081e+308 cannot be represented as 'f32'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Materialize) {
// const one = 1;
// const result = (one == 1) || (1.7976931348623157e+308 == 0i);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Source{{12, 34}}, Expr(1.7976931348623157e+308_a), 0_i);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator == (abstract-float, i32)
2 candidate operators:
operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
////////////////////////////////////////////////
// Short-Circuit Index
////////////////////////////////////////////////
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Index) {
// const one = 1;
// const a = array(1i, 2i, 3i);
// const i = 4;
// const result = (one == 0) && (a[i] == 0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
GlobalConst("i", Expr(4_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(IndexAccessor("a", "i"), 0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateAnd(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Index) {
// const one = 1;
// const a = array(1i, 2i, 3i);
// const i = 3;
// const result = (one == 1) && (a[i] == 0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
GlobalConst("i", Expr(3_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(IndexAccessor("a", Expr(Source{{12, 34}}, "i")), 0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: index 3 out of bounds [0..2]");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Index) {
// const one = 1;
// const a = array(1i, 2i, 3i);
// const i = 3;
// const result = (one == 0) && (a[i] == 0.0f);
GlobalConst("one", Expr(1_a));
GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
GlobalConst("i", Expr(3_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Source{{12, 34}}, IndexAccessor("a", "i"), 0.0_f);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator == (i32, f32)
2 candidate operators:
operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Index) {
// const one = 1;
// const a = array(1i, 2i, 3i);
// const i = 4;
// const result = (one == 1) || (a[i] == 0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
GlobalConst("i", Expr(4_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(IndexAccessor("a", "i"), 0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateOr(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Index) {
// const one = 1;
// const a = array(1i, 2i, 3i);
// const i = 3;
// const result = (one == 0) || (a[i] == 0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
GlobalConst("i", Expr(3_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(IndexAccessor("a", Expr(Source{{12, 34}}, "i")), 0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: index 3 out of bounds [0..2]");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Index) {
// const one = 1;
// const a = array(1i, 2i, 3i);
// const i = 3;
// const result = (one == 1) || (a[i] == 0.0f);
GlobalConst("one", Expr(1_a));
GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
GlobalConst("i", Expr(3_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Source{{12, 34}}, IndexAccessor("a", "i"), 0.0_f);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator == (i32, f32)
2 candidate operators:
operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
////////////////////////////////////////////////
// Short-Circuit Bitcast
////////////////////////////////////////////////
// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_And_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 0) && (bitcast<f32>(a) == 0.0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", Expr(0x7F800000_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Bitcast<f32>("a"), 0.0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateAnd(Sem(), binary);
}
// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
TEST_F(ResolverConstEvalTest, DISABLED_NonShortCircuit_And_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 1) && (bitcast<f32>(a) == 0.0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", Expr(0x7F800000_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Bitcast(Source{{12, 34}}, ty.f32(), "a"), 0.0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: value not representable as f32 message here");
}
// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_And_Error_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 0) && (bitcast<f32>(a) == 0i);
GlobalConst("one", Expr(1_a));
GlobalConst("a", Expr(0x7F800000_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Source{{12, 34}}, Bitcast(ty.f32(), "a"), 0_i);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload message here)");
}
// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_Or_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 1) || (bitcast<f32>(a) == 0.0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", Expr(0x7F800000_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Bitcast<f32>("a"), 0.0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateOr(Sem(), binary);
}
// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
TEST_F(ResolverConstEvalTest, DISABLED_NonShortCircuit_Or_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 0) || (bitcast<f32>(a) == 0.0);
GlobalConst("one", Expr(1_a));
GlobalConst("a", Expr(0x7F800000_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Bitcast(Source{{12, 34}}, ty.f32(), "a"), 0.0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: value not representable as f32 message here");
}
// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_Or_Error_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 1) || (bitcast<f32>(a) == 0i);
GlobalConst("one", Expr(1_a));
GlobalConst("a", Expr(0x7F800000_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Source{{12, 34}}, Bitcast(ty.f32(), "a"), 0_i);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload message here)");
}
////////////////////////////////////////////////
// Short-Circuit Type Init/Convert
////////////////////////////////////////////////
// NOTE: Cannot demonstrate short-circuiting an invalid init/convert as const eval of init/convert
// always succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Init) {
// const one = 1;
// const result = (one == 0) && (vec2<f32>(1.0, true).x == 0.0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(MemberAccessor(vec2<f32>(Source{{12, 34}}, 1.0_a, Expr(true)), "x"), 0.0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching initializer for vec2<f32>(abstract-float, bool)
4 candidate initializers:
vec2(x: T, y: T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
vec2(T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
vec2(vec2<T>) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
vec2<T>() -> vec2<T> where: T is f32, f16, i32, u32 or bool
5 candidate conversions:
vec2<T>(vec2<U>) -> vec2<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
vec2<T>(vec2<U>) -> vec2<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
vec2<T>(vec2<U>) -> vec2<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
vec2<T>(vec2<U>) -> vec2<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
vec2<T>(vec2<U>) -> vec2<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Init) {
// const one = 1;
// const result = (one == 1) || (vec2<f32>(1.0, true).x == 0.0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(MemberAccessor(vec2<f32>(Source{{12, 34}}, 1.0_a, Expr(true)), "x"), 0.0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching initializer for vec2<f32>(abstract-float, bool)
4 candidate initializers:
vec2(x: T, y: T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
vec2(T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
vec2(vec2<T>) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
vec2<T>() -> vec2<T> where: T is f32, f16, i32, u32 or bool
5 candidate conversions:
vec2<T>(vec2<U>) -> vec2<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
vec2<T>(vec2<U>) -> vec2<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
vec2<T>(vec2<U>) -> vec2<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
vec2<T>(vec2<U>) -> vec2<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
vec2<T>(vec2<U>) -> vec2<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
)");
}
////////////////////////////////////////////////
// Short-Circuit Array/Struct Init
////////////////////////////////////////////////
// NOTE: Cannot demonstrate short-circuiting an invalid array/struct init as const eval of
// array/struct init always succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_StructInit) {
// struct S {
// a : i32,
// b : f32,
// }
// const one = 1;
// const result = (one == 0) && Foo(1, true).a == 0;
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(
MemberAccessor(Construct(ty.type_name("S"), Expr(1_a), Expr(Source{{12, 34}}, true)), "a"),
0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type in struct initializer does not match struct member type: "
"expected 'f32', found 'bool'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_StructInit) {
// struct S {
// a : i32,
// b : f32,
// }
// const one = 1;
// const result = (one == 1) || Foo(1, true).a == 0;
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(
MemberAccessor(Construct(ty.type_name("S"), Expr(1_a), Expr(Source{{12, 34}}, true)), "a"),
0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type in struct initializer does not match struct member type: "
"expected 'f32', found 'bool'");
}
////////////////////////////////////////////////
// Short-Circuit Builtin Call
////////////////////////////////////////////////
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_BuiltinCall) {
// const one = 1;
// return (one == 0) && (extractBits(1, 0, 99) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Call("extractBits", 1_a, 0_a, 99_a), 0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateAnd(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_BuiltinCall) {
// const one = 1;
// return (one == 1) && (extractBits(1, 0, 99) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Call(Source{{12, 34}}, "extractBits", 1_a, 0_a, 99_a), 0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_BuiltinCall) {
// const one = 1;
// return (one == 0) && (extractBits(1, 0, 99) == 0.0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Source{{12, 34}}, Call("extractBits", 1_a, 0_a, 99_a), 0.0_a);
auto* binary = LogicalAnd(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator == (i32, abstract-float)
2 candidate operators:
operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_BuiltinCall) {
// const one = 1;
// return (one == 1) || (extractBits(1, 0, 99) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Call("extractBits", 1_a, 0_a, 99_a), 0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateOr(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_BuiltinCall) {
// const one = 1;
// return (one == 0) || (extractBits(1, 0, 99) == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(Call(Source{{12, 34}}, "extractBits", 1_a, 0_a, 99_a), 0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_BuiltinCall) {
// const one = 1;
// return (one == 1) || (extractBits(1, 0, 99) == 0.0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(Source{{12, 34}}, Call("extractBits", 1_a, 0_a, 99_a), 0.0_a);
auto* binary = LogicalOr(lhs, rhs);
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: no matching overload for operator == (i32, abstract-float)
2 candidate operators:
operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
////////////////////////////////////////////////
// Short-Circuit Literal
////////////////////////////////////////////////
// NOTE: Cannot demonstrate short-circuiting an invalid literal as const eval of a literal does not
// fail.
#if TINT_BUILD_WGSL_READER
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Literal) {
// NOTE: This fails parsing rather than resolving, which is why we can't use the ProgramBuilder
// for this test.
auto src = R"(
const one = 1;
const result = (one == 0) && (1111111111111111111111111111111i == 0);
)";
auto file = std::make_unique<Source::File>("test", src);
auto program = reader::wgsl::Parse(file.get());
EXPECT_FALSE(program.IsValid());
diag::Formatter::Style style;
style.print_newline_at_end = false;
auto error = diag::Formatter(style).format(program.Diagnostics());
EXPECT_EQ(error, R"(test:3:31 error: value cannot be represented as 'i32'
const result = (one == 0) && (1111111111111111111111111111111i == 0);
^
)");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Literal) {
// NOTE: This fails parsing rather than resolving, which is why we can't use the ProgramBuilder
// for this test.
auto src = R"(
const one = 1;
const result = (one == 1) || (1111111111111111111111111111111i == 0);
)";
auto file = std::make_unique<Source::File>("test", src);
auto program = reader::wgsl::Parse(file.get());
EXPECT_FALSE(program.IsValid());
diag::Formatter::Style style;
style.print_newline_at_end = false;
auto error = diag::Formatter(style).format(program.Diagnostics());
EXPECT_EQ(error, R"(test:3:31 error: value cannot be represented as 'i32'
const result = (one == 1) || (1111111111111111111111111111111i == 0);
^
)");
}
#endif // TINT_BUILD_WGSL_READER
////////////////////////////////////////////////
// Short-Circuit Member Access
////////////////////////////////////////////////
// NOTE: Cannot demonstrate short-circuiting an invalid member access as const eval of member access
// always succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_MemberAccess) {
// struct S {
// a : i32,
// b : f32,
// }
// const s = S(1, 2.0);
// const one = 1;
// const result = (one == 0) && (s.c == 0);
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
GlobalConst("s", Construct(ty.type_name("S"), Expr(1_a), Expr(2.0_a)));
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(MemberAccessor(Source{{12, 34}}, "s", Expr("c")), 0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: struct member c not found");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_MemberAccess) {
// struct S {
// a : i32,
// b : f32,
// }
// const s = S(1, 2.0);
// const one = 1;
// const result = (one == 1) || (s.c == 0);
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
GlobalConst("s", Construct(ty.type_name("S"), Expr(1_a), Expr(2.0_a)));
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(MemberAccessor(Source{{12, 34}}, "s", Expr("c")), 0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: struct member c not found");
}
////////////////////////////////////////////////
// Short-Circuit Swizzle
////////////////////////////////////////////////
// NOTE: Cannot demonstrate short-circuiting an invalid swizzle as const eval of swizzle always
// succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Swizzle) {
// const one = 1;
// const result = (one == 0) && (vec2(1, 2).z == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(MemberAccessor(vec2<AInt>(1_a, 2_a), Expr(Source{{12, 34}}, "z")), 0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member");
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Swizzle) {
// const one = 1;
// const result = (one == 1) || (vec2(1, 2).z == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(MemberAccessor(vec2<AInt>(1_a, 2_a), Expr(Source{{12, 34}}, "z")), 0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member");
}
////////////////////////////////////////////////
// Short-Circuit Nested
////////////////////////////////////////////////
#if TINT_BUILD_WGSL_READER
using ResolverConstEvalTestShortCircuit = ResolverTestWithParam<std::tuple<const char*, bool>>;
TEST_P(ResolverConstEvalTestShortCircuit, Test) {
const char* expr = std::get<0>(GetParam());
bool should_pass = std::get<1>(GetParam());
auto src = std::string(R"(
const one = 1;
const result = )");
src = src + expr + ";";
auto file = std::make_unique<Source::File>("test", src);
auto program = reader::wgsl::Parse(file.get());
if (should_pass) {
diag::Formatter::Style style;
style.print_newline_at_end = false;
auto error = diag::Formatter(style).format(program.Diagnostics());
EXPECT_TRUE(program.IsValid()) << error;
} else {
EXPECT_FALSE(program.IsValid());
}
}
INSTANTIATE_TEST_SUITE_P(Nested,
ResolverConstEvalTestShortCircuit,
testing::ValuesIn(std::vector<std::tuple<const char*, bool>>{
// AND nested rhs
{"(one == 0) && ((one == 0) && ((2/0)==0))", true},
{"(one == 1) && ((one == 0) && ((2/0)==0))", true},
{"(one == 0) && ((one == 1) && ((2/0)==0))", true},
{"(one == 1) && ((one == 1) && ((2/0)==0))", false},
// AND nested lhs
{"((one == 0) && ((2/0)==0)) && (one == 0)", true},
{"((one == 0) && ((2/0)==0)) && (one == 1)", true},
{"((one == 1) && ((2/0)==0)) && (one == 0)", false},
{"((one == 1) && ((2/0)==0)) && (one == 1)", false},
// OR nested rhs
{"(one == 1) || ((one == 1) || ((2/0)==0))", true},
{"(one == 0) || ((one == 1) || ((2/0)==0))", true},
{"(one == 1) || ((one == 0) || ((2/0)==0))", true},
{"(one == 0) || ((one == 0) || ((2/0)==0))", false},
// OR nested lhs
{"((one == 1) || ((2/0)==0)) || (one == 1)", true},
{"((one == 1) || ((2/0)==0)) || (one == 0)", true},
{"((one == 0) || ((2/0)==0)) || (one == 1)", false},
{"((one == 0) || ((2/0)==0)) || (one == 0)", false},
// AND nested both sides
{"((one == 0) && ((2/0)==0)) && ((one == 0) && ((2/0)==0))", true},
{"((one == 0) && ((2/0)==0)) && ((one == 1) && ((2/0)==0))", true},
{"((one == 1) && ((2/0)==0)) && ((one == 0) && ((2/0)==0))", false},
{"((one == 1) && ((2/0)==0)) && ((one == 1) && ((2/0)==0))", false},
// OR nested both sides
{"((one == 1) || ((2/0)==0)) && ((one == 1) || ((2/0)==0))", true},
{"((one == 1) || ((2/0)==0)) && ((one == 0) || ((2/0)==0))", false},
{"((one == 0) || ((2/0)==0)) && ((one == 1) || ((2/0)==0))", false},
{"((one == 0) || ((2/0)==0)) && ((one == 0) || ((2/0)==0))", false},
// AND chained
{"(one == 0) && (one == 0) && ((2 / 0) == 0)", true},
{"(one == 1) && (one == 0) && ((2 / 0) == 0)", true},
{"(one == 0) && (one == 1) && ((2 / 0) == 0)", true},
{"(one == 1) && (one == 1) && ((2 / 0) == 0)", false},
// OR chained
{"(one == 1) || (one == 1) || ((2 / 0) == 0)", true},
{"(one == 0) || (one == 1) || ((2 / 0) == 0)", true},
{"(one == 1) || (one == 0) || ((2 / 0) == 0)", true},
{"(one == 0) || (one == 0) || ((2 / 0) == 0)", false},
}));
#endif // TINT_BUILD_WGSL_READER
} // namespace LogicalShortCircuit
} // namespace
} // namespace tint::resolver

View File

@ -293,5 +293,53 @@ TEST_F(ResolverEvaluationStageTest, MemberAccessor_Runtime) {
EXPECT_EQ(Sem().Get(expr)->Stage(), sem::EvaluationStage::kRuntime);
}
TEST_F(ResolverEvaluationStageTest, Binary_Runtime) {
// let one = 1;
// let result = (one == 1) && (one == 1);
auto* one = Let("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal("one", 1_a);
auto* binary = LogicalAnd(lhs, rhs);
auto* result = Let("result", binary);
WrapInFunction(one, result);
ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(Sem().Get(lhs)->Stage(), sem::EvaluationStage::kRuntime);
EXPECT_EQ(Sem().Get(rhs)->Stage(), sem::EvaluationStage::kRuntime);
EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
}
TEST_F(ResolverEvaluationStageTest, Binary_Const) {
// const one = 1;
// const result = (one == 1) && (one == 1);
auto* one = Const("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal("one", 1_a);
auto* binary = LogicalAnd(lhs, rhs);
auto* result = Const("result", binary);
WrapInFunction(one, result);
ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(Sem().Get(lhs)->Stage(), sem::EvaluationStage::kConstant);
EXPECT_EQ(Sem().Get(rhs)->Stage(), sem::EvaluationStage::kConstant);
EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kConstant);
}
TEST_F(ResolverEvaluationStageTest, Binary_NotEvaluated) {
// const one = 1;
// const result = (one == 0) && (one == 1);
auto* one = Const("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal("one", 1_a);
auto* binary = LogicalAnd(lhs, rhs);
auto* result = Const("result", binary);
WrapInFunction(one, result);
ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(Sem().Get(lhs)->Stage(), sem::EvaluationStage::kConstant);
EXPECT_EQ(Sem().Get(rhs)->Stage(), sem::EvaluationStage::kNotEvaluated);
EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kConstant);
}
} // namespace
} // namespace tint::resolver

View File

@ -1510,6 +1510,11 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
failed = true;
return ast::TraverseAction::Stop;
}
if (auto* binary = expr->As<ast::BinaryExpression>();
binary && binary->IsLogical()) {
// Store potential const-eval short-circuit pair
logical_binary_lhs_to_parent_.Add(binary->lhs, binary);
}
sorted.Push(expr);
return ast::TraverseAction::Descend;
})) {
@ -1568,6 +1573,26 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
if (expr == root) {
return sem_expr;
}
// If we just processed the lhs of a constexpr logical binary expression, mark the rhs for
// short-circuiting.
if (sem_expr->ConstantValue()) {
if (auto binary = logical_binary_lhs_to_parent_.Find(expr)) {
const bool lhs_is_true = sem_expr->ConstantValue()->As<bool>();
if (((*binary)->IsLogicalAnd() && !lhs_is_true) ||
((*binary)->IsLogicalOr() && lhs_is_true)) {
// Mark entire expression tree to not const-evaluate
auto r = ast::TraverseExpressions( //
(*binary)->rhs, diagnostics_, [&](const ast::Expression* e) {
skip_const_eval_.Add(e);
return ast::TraverseAction::Descend;
});
if (!r) {
return nullptr;
}
}
}
}
}
TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node";
@ -1779,27 +1804,32 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
return nullptr;
}
auto expr_val = expr->ConstantValue();
if (!expr_val) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "Materialize(" << decl->TypeInfo().name
<< ") called on expression with no constant value";
return nullptr;
const sem::Constant* materialized_val = nullptr;
if (!skip_const_eval_.Contains(decl)) {
auto expr_val = expr->ConstantValue();
if (!expr_val) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "Materialize(" << decl->TypeInfo().name
<< ") called on expression with no constant value";
return nullptr;
}
auto val = const_eval_.Convert(concrete_ty, expr_val, decl->source);
if (!val) {
// Convert() has already failed and raised an diagnostic error.
return nullptr;
}
materialized_val = val.Get();
if (!materialized_val) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "ConvertValue(" << builder_->FriendlyName(expr_val->Type())
<< " -> " << builder_->FriendlyName(concrete_ty) << ") returned invalid value";
return nullptr;
}
}
auto materialized_val = const_eval_.Convert(concrete_ty, expr_val, decl->source);
if (!materialized_val) {
// ConvertValue() has already failed and raised an diagnostic error.
return nullptr;
}
if (!materialized_val.Get()) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "ConvertValue(" << builder_->FriendlyName(expr_val->Type()) << " -> "
<< builder_->FriendlyName(concrete_ty) << ") returned invalid value";
return nullptr;
}
auto* m = builder_->create<sem::Materialize>(expr, current_statement_, materialized_val.Get());
auto* m =
builder_->create<sem::Materialize>(expr, current_statement_, concrete_ty, materialized_val);
m->Behaviors() = expr->Behaviors();
builder_->Sem().Replace(decl, m);
return m;
@ -1894,12 +1924,16 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
ty = builder_->create<type::Reference>(ty, ref->AddressSpace(), ref->Access());
}
auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
const sem::Constant* val = nullptr;
if (auto r = const_eval_.Index(obj, idx)) {
val = r.Get();
auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
} else {
return nullptr;
if (auto r = const_eval_.Index(obj, idx)) {
val = r.Get();
} else {
return nullptr;
}
}
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>(
@ -1922,6 +1956,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
RegisterLoadIfNeeded(inner);
const sem::Constant* val = nullptr;
// TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented.
if (auto r = const_eval_.Bitcast(ty, inner)) {
val = r.Get();
} else {
@ -1981,8 +2016,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaybeMaterializeArguments(args, ctor_or_conv.target)) {
return nullptr;
}
const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
}
if (stage == sem::EvaluationStage::kConstant) {
auto const_args = ConvertArguments(args, ctor_or_conv.target);
if (!const_args) {
@ -2302,13 +2341,17 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
// If the builtin is @const, and all arguments have constant values, evaluate the builtin
// now.
auto stage = sem::EarliestStage(arg_stage, builtin.sem->Stage());
const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(arg_stage, builtin.sem->Stage());
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
}
if (stage == sem::EvaluationStage::kConstant) {
auto const_args = ConvertArguments(args, builtin.sem);
if (!const_args) {
return nullptr;
}
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(),
const_args.Get(), expr->source)) {
value = r.Get();
@ -2787,19 +2830,25 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
const sem::Constant* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) {
auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
// Implicit conversion (e.g. AInt -> AFloat)
if (!Convert(const_args[0], op.lhs, lhs->Declaration()->source)) {
return nullptr;
}
if (!Convert(const_args[1], op.rhs, rhs->Declaration()->source)) {
return nullptr;
}
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
value = r.Get();
if (skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
} else if (skip_const_eval_.Contains(expr->rhs)) {
// Only the rhs should be short-circuited, use the lhs value
value = lhs->ConstantValue();
} else {
return nullptr;
auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
// Implicit conversion (e.g. AInt -> AFloat)
if (!Convert(const_args[0], op.lhs, lhs->Declaration()->source)) {
return nullptr;
}
if (!Convert(const_args[1], op.rhs, rhs->Declaration()->source)) {
return nullptr;
}
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
value = r.Get();
} else {
return nullptr;
}
}
} else {
stage = sem::EvaluationStage::kRuntime;

View File

@ -478,6 +478,9 @@ class Resolver {
uint32_t current_scoping_depth_ = 0;
utils::UniqueVector<const sem::GlobalVariable*, 4>* resolved_overrides_ = nullptr;
utils::Hashset<TypeAndAddressSpace, 8> valid_type_storage_layouts_;
utils::Hashmap<const ast::Expression*, const ast::BinaryExpression*, 8>
logical_binary_lhs_to_parent_;
utils::Hashset<const ast::Expression*, 8> skip_const_eval_;
};
} // namespace tint::resolver

View File

@ -1318,6 +1318,9 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
bool Validator::EvaluationStage(const sem::Expression* expr,
sem::EvaluationStage latest_stage,
std::string_view constraint) const {
if (expr->Stage() == sem::EvaluationStage::kNotEvaluated) {
return true;
}
if (expr->Stage() > latest_stage) {
auto stage_name = [](sem::EvaluationStage stage) -> std::string {
switch (stage) {
@ -1327,6 +1330,8 @@ bool Validator::EvaluationStage(const sem::Expression* expr,
return "an override-expression";
case sem::EvaluationStage::kConstant:
return "a const-expression";
case sem::EvaluationStage::kNotEvaluated:
return "an unevaluated expression";
}
return "<unknown>";
};

View File

@ -32,7 +32,8 @@ Call::Call(const ast::CallExpression* declaration,
target_(target),
arguments_(std::move(arguments)) {
// Check that the stage is no earlier than the target supports
TINT_ASSERT(Semantic, target->Stage() <= stage);
TINT_ASSERT(Semantic,
(target->Stage() <= stage) || (stage == sem::EvaluationStage::kNotEvaluated));
}
Call::~Call() = default;

View File

@ -22,6 +22,8 @@ namespace tint::sem {
/// The earliest point in time that an expression can be evaluated
enum class EvaluationStage {
/// Expression will not be evaluated
kNotEvaluated,
/// Expression can be evaluated at shader creation time
kConstant,
/// Expression can be evaluated at pipeline creation time
@ -43,7 +45,7 @@ inline bool operator>(EvaluationStage a, EvaluationStage b) {
/// @param stages a list of EvaluationStage.
/// @returns the earliest stage supported by all the provided stages
inline EvaluationStage EarliestStage(std::initializer_list<EvaluationStage> stages) {
auto earliest = EvaluationStage::kConstant;
auto earliest = EvaluationStage::kNotEvaluated;
for (auto stage : stages) {
earliest = std::max(stage, earliest);
}

View File

@ -47,7 +47,7 @@ TEST_F(ExpressionTest, UnwrapMaterialize) {
sem::EvaluationStage::kRuntime, /* statement */ nullptr,
/* constant_value */ nullptr,
/* has_side_effects */ false, /* root_ident */ nullptr);
auto* b = create<Materialize>(a, /* statement */ nullptr, &c);
auto* b = create<Materialize>(a, /* statement */ nullptr, c.Type(), &c);
EXPECT_EQ(a, a->UnwrapMaterialize());
EXPECT_EQ(a, b->UnwrapMaterialize());

View File

@ -19,10 +19,11 @@ TINT_INSTANTIATE_TYPEINFO(tint::sem::Materialize);
namespace tint::sem {
Materialize::Materialize(const Expression* expr,
const Statement* statement,
const type::Type* type,
const Constant* constant)
: Base(/* declaration */ expr->Declaration(),
/* type */ constant->Type(),
/* stage */ EvaluationStage::kConstant, // Abstract can only be const-expr
/* type */ type,
/* stage */ constant ? EvaluationStage::kConstant : EvaluationStage::kNotEvaluated,
/* statement */ statement,
/* constant */ constant,
/* has_side_effects */ false,

View File

@ -30,8 +30,12 @@ class Materialize final : public Castable<Materialize, Expression> {
/// Constructor
/// @param expr the inner expression, being materialized
/// @param statement the statement that owns this expression
/// @param constant the constant value of this expression
Materialize(const Expression* expr, const Statement* statement, const Constant* constant);
/// @param type concrete type to materialize to
/// @param constant the constant value of this expression or nullptr
Materialize(const Expression* expr,
const Statement* statement,
const type::Type* type,
const Constant* constant);
/// Destructor
~Materialize() override;

View File

@ -801,7 +801,8 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
auto* expr = src->Sem().Get<sem::Expression>(node);
if (!expr || expr->Stage() == sem::EvaluationStage::kConstant) {
if (!expr || expr->Stage() == sem::EvaluationStage::kConstant ||
expr->Stage() == sem::EvaluationStage::kNotEvaluated) {
continue; // Don't polyfill @const expressions
}