tint: fix const eval short-circuiting with mixed runtime and constant expressions

For logical binary expressions that can be short-circuited, if the rhs
tree contained a mix of constant and runtime expressions, we would
erroneously mark the node as runtime, although some of its children were
resolved as kNotEvaluated. This would then fail during backend
generation.

This is a fork of 115820, addressing review comments, as amaiorano is OOO this week.

Bug: chromium:1403752
Bug: tint:1581
Change-Id: I18682c7fe1db092d280390881ff86b3c0db23e9b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/116020
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2023-01-04 12:30:47 +00:00 committed by Dawn LUCI CQ
parent be367b73ae
commit f3f813eb0c
12 changed files with 167 additions and 31 deletions

View File

@ -1697,8 +1697,8 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_ShiftRight) {
}
TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Compound) {
auto* expr = LogicalOr(LessThan(1_u, Add(Shr(3_u, 4_u), 9_u)),
GreaterThan(2.5_f, Div(6.7_f, Mul(2.3_f, 5.5_f))));
auto* expr = LogicalAnd(LessThan(1_u, Add(Shr(3_u, 4_u), 9_u)),
GreaterThan(2.5_f, Div(6.7_f, Mul(2.3_f, 5.5_f))));
WrapInFunction(expr);
auto& b = CreateBuilder();
@ -1714,7 +1714,7 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Compound) {
%4 (f32) = 2.3 * 5.5
%5 (f32) = 6.7 / %4 (f32)
%6 (bool) = 2.5 > %5 (f32)
%7 (bool) = %3 (bool) || %6 (bool)
%7 (bool) = %3 (bool) && %6 (bool)
)");
}

View File

@ -2208,6 +2208,32 @@ TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Swizzle) {
EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member");
}
////////////////////////////////////////////////
// Short-Circuit Mixed Constant and Runtime
////////////////////////////////////////////////
TEST_F(ResolverConstEvalTest, ShortCircuit_And_MixedConstantAndRuntime) {
// var j : i32;
// let result = false && j < (0 - 8);
auto* j = Decl(Var("j", ty.i32()));
auto* binary = LogicalAnd(Expr(false), LessThan("j", Sub(0_a, 8_a)));
auto* result = Let("result", binary);
WrapInFunction(j, result);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateAnd(Sem(), binary);
}
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_MixedConstantAndRuntime) {
// var j : i32;
// let result = true || j < (0 - 8);
auto* j = Decl(Var("j", ty.i32()));
auto* binary = LogicalOr(Expr(true), LessThan("j", Sub(0_a, 8_a)));
auto* result = Let("result", binary);
WrapInFunction(j, result);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ValidateOr(Sem(), binary);
}
////////////////////////////////////////////////
// Short-Circuit Nested
////////////////////////////////////////////////

View File

@ -356,7 +356,7 @@ const type::AbstractFloat* build_fa(MatchState& state) {
}
bool match_fa(MatchState& state, const type::Type* ty) {
return (state.earliest_eval_stage == sem::EvaluationStage::kConstant) &&
return (state.earliest_eval_stage <= sem::EvaluationStage::kConstant) &&
ty->IsAnyOf<Any, type::AbstractNumeric>();
}
@ -365,7 +365,7 @@ const type::AbstractInt* build_ia(MatchState& state) {
}
bool match_ia(MatchState& state, const type::Type* ty) {
return (state.earliest_eval_stage == sem::EvaluationStage::kConstant) &&
return (state.earliest_eval_stage <= sem::EvaluationStage::kConstant) &&
ty->IsAnyOf<Any, type::AbstractInt>();
}

View File

@ -2620,13 +2620,18 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
}
const constant::Value* val = nullptr;
if (auto r = const_eval_.Literal(ty, literal)) {
val = r.Get();
} else {
return nullptr;
auto stage = sem::EvaluationStage::kConstant;
if (skip_const_eval_.Contains(literal)) {
stage = sem::EvaluationStage::kNotEvaluated;
}
return builder_->create<sem::Expression>(literal, ty, sem::EvaluationStage::kConstant,
current_statement_, std::move(val),
if (stage == sem::EvaluationStage::kConstant) {
if (auto r = const_eval_.Literal(ty, literal)) {
val = r.Get();
} else {
return nullptr;
}
}
return builder_->create<sem::Expression>(literal, ty, stage, current_statement_, std::move(val),
/* has_side_effects */ false);
}
@ -2899,29 +2904,36 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
}
const constant::Value* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) {
if (skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
} else if (skip_const_eval_.Contains(expr->rhs)) {
// Only the rhs should be short-circuited, use the lhs value
value = lhs->ConstantValue();
if (skip_const_eval_.Contains(expr)) {
// This expression is short-circuited by an ancestor expression.
// Do not const-eval.
stage = sem::EvaluationStage::kNotEvaluated;
} else if (lhs->Stage() == sem::EvaluationStage::kConstant &&
rhs->Stage() == sem::EvaluationStage::kNotEvaluated) {
// Short-circuiting binary expression. Use the LHS value and stage.
value = lhs->ConstantValue();
stage = sem::EvaluationStage::kConstant;
} else if (stage == sem::EvaluationStage::kConstant) {
// Both LHS and RHS have expressions that are constant evaluation stage.
if (op.const_eval_fn) { // Do we have a @const operator?
// Yes. Perform any required abstract argument values implicit conversions to the
// overload parameter types, and const-eval.
utils::Vector const_args{lhs->ConstantValue(), rhs->ConstantValue()};
// Implicit conversion (e.g. AInt -> AFloat)
if (!Convert(const_args[0], op.lhs, lhs->Declaration()->source)) {
return nullptr;
}
if (!Convert(const_args[1], op.rhs, rhs->Declaration()->source)) {
return nullptr;
}
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
value = r.Get();
} else {
auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
// Implicit conversion (e.g. AInt -> AFloat)
if (!Convert(const_args[0], op.lhs, lhs->Declaration()->source)) {
return nullptr;
}
if (!Convert(const_args[1], op.rhs, rhs->Declaration()->source)) {
return nullptr;
}
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
value = r.Get();
} else {
return nullptr;
}
return nullptr;
}
} else {
// The arguments have constant values, but the operator cannot be const-evaluated. This
// can only be evaluated at runtime.
stage = sem::EvaluationStage::kRuntime;
}
}

View File

@ -39,6 +39,9 @@ Expression::Expression(const ast::Expression* declaration,
has_side_effects_(has_side_effects) {
TINT_ASSERT(Semantic, type_);
TINT_ASSERT(Semantic, (constant != nullptr) == (stage == EvaluationStage::kConstant));
if (constant != nullptr) {
TINT_ASSERT(Semantic, type_ == constant->Type());
}
}
Expression::~Expression() = default;

View File

@ -0,0 +1 @@
fn d(){var j:i32;for(;0<0&&j<0-8;){}}

View File

@ -0,0 +1,12 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void d() {
int j = 0;
{
for(; false; ) {
}
}
}

View File

@ -0,0 +1,12 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void d() {
int j = 0;
{
for(; false; ) {
}
}
}

View File

@ -0,0 +1,14 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
void d() {
int j = 0;
{
for(; false; ) {
}
}
}

View File

@ -0,0 +1,9 @@
#include <metal_stdlib>
using namespace metal;
void d() {
int j = 0;
for(; false; ) {
}
}

View File

@ -0,0 +1,42 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 19
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %unused_entry_point "unused_entry_point"
OpName %d "d"
OpName %j "j"
%void = OpTypeVoid
%1 = OpTypeFunction %void
%int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
%10 = OpConstantNull %int
%bool = OpTypeBool
%true = OpConstantTrue %bool
%unused_entry_point = OpFunction %void None %1
%4 = OpLabel
OpReturn
OpFunctionEnd
%d = OpFunction %void None %1
%6 = OpLabel
%j = OpVariable %_ptr_Function_int Function %10
OpBranch %11
%11 = OpLabel
OpLoopMerge %12 %13 None
OpBranch %14
%14 = OpLabel
OpSelectionMerge %17 None
OpBranchConditional %true %18 %17
%18 = OpLabel
OpBranch %12
%17 = OpLabel
OpBranch %13
%13 = OpLabel
OpBranch %11
%12 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,5 @@
fn d() {
var j : i32;
for(; ((0 < 0) && (j < (0 - 8))); ) {
}
}