resolver: Implement Behavior Analysis

This change implements the behavior analysis for expressions and
statements as described in:
https://www.w3.org/TR/WGSL/#behaviors

This CL makes no changes to the validation rules. This will be done as a
followup change.

Bug: tint:1302
Change-Id: If0a251a7982ea15ff5d93b54a5cc5ed03ba60608
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68408
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton 2021-12-03 15:49:34 +00:00 committed by Tint LUCI CQ
parent bf39c8fb19
commit 3298625760
10 changed files with 940 additions and 60 deletions

View File

@ -689,6 +689,7 @@ if(${TINT_BUILD_TESTS})
resolver/pipeline_overridable_constant_test.cc
resolver/ptr_ref_test.cc
resolver/ptr_ref_validation_test.cc
resolver/resolver_behavior_test.cc
resolver/resolver_constants_test.cc
resolver/resolver_test_helper.cc
resolver/resolver_test_helper.h

View File

@ -32,6 +32,7 @@
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/depth_multisampled_texture.h"
#include "src/ast/depth_texture.h"
#include "src/ast/disable_validation_decoration.h"
@ -1864,6 +1865,19 @@ class ProgramBuilder {
/// @returns the break statement pointer
const ast::BreakStatement* Break() { return create<ast::BreakStatement>(); }
/// Creates an ast::ContinueStatement
/// @param source the source information
/// @returns the continue statement pointer
const ast::ContinueStatement* Continue(const Source& source) {
return create<ast::ContinueStatement>(source);
}
/// Creates an ast::ContinueStatement
/// @returns the continue statement pointer
const ast::ContinueStatement* Continue() {
return create<ast::ContinueStatement>();
}
/// Creates an ast::ReturnStatement with no return value
/// @param source the source information
/// @returns the return statement pointer
@ -2041,6 +2055,13 @@ class ProgramBuilder {
body);
}
/// Creates a ast::ElseStatement with no condition and body
/// @param body the else body
/// @returns the else statement pointer
const ast::ElseStatement* Else(const ast::BlockStatement* body) {
return create<ast::ElseStatement>(nullptr, body);
}
/// Creates a ast::IfStatement with input condition, body, and optional
/// variadic else statements
/// @param condition the if statement condition expression

View File

@ -657,11 +657,21 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
<< "Resolver::Function() called with a current compound statement";
return nullptr;
}
if (!StatementScope(decl->body,
builder_->create<sem::FunctionBlockStatement>(func),
[&] { return Statements(decl->body->statements); })) {
auto* body = StatementScope(
decl->body, builder_->create<sem::FunctionBlockStatement>(func),
[&] { return Statements(decl->body->statements); });
if (!body) {
return nullptr;
}
func->Behaviors() = body->Behaviors();
if (func->Behaviors().Contains(sem::Behavior::kReturn)) {
// https://www.w3.org/TR/WGSL/#behaviors-rules
// We assign a behavior to each function: it is its bodys behavior
// (treating the body as a regular statement), with any "Return" replaced
// by "Next".
func->Behaviors().Remove(sem::Behavior::kReturn);
func->Behaviors().Add(sem::Behavior::kNext);
}
}
for (auto* deco : decl->decorations) {
@ -797,13 +807,22 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
}
bool Resolver::Statements(const ast::StatementList& stmts) {
sem::Behaviors behaviors{sem::Behavior::kNext};
for (auto* stmt : stmts) {
Mark(stmt);
auto* sem = Statement(stmt);
if (!sem) {
return false;
}
// s1 s2:(B1{Next}) B2
// ValidateStatements will ensure that statements can only follow a Next.
behaviors.Remove(sem::Behavior::kNext);
behaviors.Add(sem->Behaviors());
}
current_statement_->Behaviors() = behaviors;
if (!ValidateStatements(stmts)) {
return false;
}
@ -887,6 +906,7 @@ sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
return false;
}
sem->SetBlock(body);
sem->Behaviors() = body->Behaviors();
return true;
});
}
@ -900,6 +920,8 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
return false;
}
sem->SetCondition(cond);
sem->Behaviors() = cond->Behaviors();
sem->Behaviors().Remove(sem::Behavior::kNext);
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
@ -908,12 +930,23 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
[&] { return Statements(stmt->body->statements); })) {
return false;
}
sem->Behaviors().Add(body->Behaviors());
for (auto* else_stmt : stmt->else_statements) {
Mark(else_stmt);
if (!ElseStatement(else_stmt)) {
auto* else_sem = ElseStatement(else_stmt);
if (!else_sem) {
return false;
}
sem->Behaviors().Add(else_sem->Behaviors());
}
if (stmt->else_statements.empty() ||
stmt->else_statements.back()->condition != nullptr) {
// https://www.w3.org/TR/WGSL/#behaviors-rules
// if statements without an else branch are treated as if they had an
// empty else branch (which adds Next to their behavior)
sem->Behaviors().Add(sem::Behavior::kNext);
}
return ValidateIfStatement(sem);
@ -930,7 +963,12 @@ sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
return false;
}
sem->SetCondition(cond);
// https://www.w3.org/TR/WGSL/#behaviors-rules
// if statements with else if branches are treated as if they were nested
// simple if/else statements
sem->Behaviors() = cond->Behaviors();
}
sem->Behaviors().Remove(sem::Behavior::kNext);
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
@ -939,6 +977,7 @@ sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
[&] { return Statements(stmt->body->statements); })) {
return false;
}
sem->Behaviors().Add(body->Behaviors());
return ValidateElseStatement(sem);
});
@ -964,20 +1003,32 @@ sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
if (!Statements(stmt->body->statements)) {
return false;
}
auto& behaviors = sem->Behaviors();
behaviors = body->Behaviors();
if (stmt->continuing) {
Mark(stmt->continuing);
if (!stmt->continuing->Empty()) {
auto* continuing =
auto* continuing = StatementScope(
stmt->continuing,
builder_->create<sem::LoopContinuingBlockStatement>(
stmt->continuing, current_compound_statement_,
current_function_);
return StatementScope(stmt->continuing, continuing, [&] {
return Statements(stmt->continuing->statements);
}) != nullptr;
current_function_),
[&] { return Statements(stmt->continuing->statements); });
if (!continuing) {
return false;
}
behaviors.Add(continuing->Behaviors());
}
}
if (behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
behaviors.Add(sem::Behavior::kNext);
} else {
behaviors.Remove(sem::Behavior::kNext);
}
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
return true;
});
});
@ -988,11 +1039,14 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(
auto* sem = builder_->create<sem::ForLoopStatement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
auto& behaviors = sem->Behaviors();
if (auto* initializer = stmt->initializer) {
Mark(initializer);
if (!Statement(initializer)) {
auto* init = Statement(initializer);
if (!init) {
return false;
}
behaviors.Add(init->Behaviors());
}
if (auto* cond_expr = stmt->condition) {
@ -1001,13 +1055,16 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(
return false;
}
sem->SetCondition(cond);
behaviors.Add(cond->Behaviors());
}
if (auto* continuing = stmt->continuing) {
Mark(continuing);
if (!Statement(continuing)) {
auto* cont = Statement(continuing);
if (!cont) {
return false;
}
behaviors.Add(cont->Behaviors());
}
Mark(stmt->body);
@ -1019,6 +1076,15 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(
return false;
}
behaviors.Add(body->Behaviors());
if (stmt->condition ||
behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
behaviors.Add(sem::Behavior::kNext);
} else {
behaviors.Remove(sem::Behavior::kNext);
}
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
return ValidateForLoopStatement(sem);
});
}
@ -1072,6 +1138,19 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
if (!sem_expr) {
return nullptr;
}
// https://www.w3.org/TR/WGSL/#behaviors-rules
// an expression behavior is always either {Next} or {Next, Discard}
if (sem_expr->Behaviors() != sem::Behavior::kNext &&
sem_expr->Behaviors() != sem::Behaviors{sem::Behavior::kNext, // NOLINT
sem::Behavior::kDiscard} &&
!IsCallStatement(expr)) {
TINT_ICE(Resolver, diagnostics_)
<< expr->TypeInfo().name
<< " behaviors are: " << sem_expr->Behaviors();
return nullptr;
}
builder_->Sem().Add(expr, sem_expr);
if (expr == root) {
return sem_expr;
@ -1084,52 +1163,57 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
sem::Expression* Resolver::IndexAccessor(
const ast::IndexAccessorExpression* expr) {
auto* idx = expr->index;
auto* parent_raw_ty = TypeOf(expr->object);
auto* parent_ty = parent_raw_ty->UnwrapRef();
auto* idx = Sem(expr->index);
auto* obj = Sem(expr->object);
auto* obj_raw_ty = obj->Type();
auto* obj_ty = obj_raw_ty->UnwrapRef();
const sem::Type* ty = nullptr;
if (auto* arr = parent_ty->As<sem::Array>()) {
if (auto* arr = obj_ty->As<sem::Array>()) {
ty = arr->ElemType();
} else if (auto* vec = parent_ty->As<sem::Vector>()) {
} else if (auto* vec = obj_ty->As<sem::Vector>()) {
ty = vec->type();
} else if (auto* mat = parent_ty->As<sem::Matrix>()) {
} else if (auto* mat = obj_ty->As<sem::Matrix>()) {
ty = builder_->create<sem::Vector>(mat->type(), mat->rows());
} else {
AddError("cannot index type '" + TypeNameOf(parent_ty) + "'", expr->source);
AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source);
return nullptr;
}
auto* idx_ty = TypeOf(idx)->UnwrapRef();
auto* idx_ty = idx->Type()->UnwrapRef();
if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
AddError("index must be of type 'i32' or 'u32', found: '" +
TypeNameOf(idx_ty) + "'",
idx->source);
idx->Declaration()->source);
return nullptr;
}
if (parent_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
if (!parent_raw_ty->Is<sem::Reference>()) {
if (obj_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
if (!obj_raw_ty->Is<sem::Reference>()) {
// TODO(bclayton): expand this to allow any const_expr expression
// https://github.com/gpuweb/gpuweb/issues/1272
if (!idx->As<ast::IntLiteralExpression>()) {
if (!idx->Declaration()->As<ast::IntLiteralExpression>()) {
AddError("index must be signed or unsigned integer literal",
idx->source);
idx->Declaration()->source);
return nullptr;
}
}
}
// If we're extracting from a reference, we return a reference.
if (auto* ref = parent_raw_ty->As<sem::Reference>()) {
if (auto* ref = obj_raw_ty->As<sem::Reference>()) {
ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
ref->Access());
}
auto val = EvaluateConstantValue(expr, ty);
return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
auto* sem =
builder_->create<sem::Expression>(expr, ty, current_statement_, val);
sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
return sem;
}
sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
auto* inner = Sem(expr->expr);
auto* ty = Type(expr->type);
if (!ty) {
return nullptr;
@ -1140,12 +1224,17 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
}
auto val = EvaluateConstantValue(expr, ty);
return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
auto* sem =
builder_->create<sem::Expression>(expr, ty, current_statement_, val);
sem->Behaviors() = inner->Behaviors();
return sem;
}
sem::Call* Resolver::Call(const ast::CallExpression* expr) {
std::vector<const sem::Expression*> args(expr->args.size());
std::vector<const sem::Type*> arg_tys(args.size());
sem::Behaviors arg_behaviors;
for (size_t i = 0; i < expr->args.size(); i++) {
auto* arg = Sem(expr->args[i]);
if (!arg) {
@ -1153,8 +1242,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
}
args[i] = arg;
arg_tys[i] = args[i]->Type();
arg_behaviors.Add(arg->Behaviors());
}
arg_behaviors.Remove(sem::Behavior::kNext);
auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
// The call has resolved to a type constructor or cast.
if (args.size() == 1) {
@ -1192,7 +1284,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
}
if (auto* fn = As<sem::Function>(resolved)) {
return FunctionCall(expr, fn, std::move(args));
return FunctionCall(expr, fn, std::move(args), arg_behaviors);
}
auto name = builder_->Symbols().NameFor(ident->symbol);
@ -1247,7 +1339,8 @@ sem::Call* Resolver::IntrinsicCall(
sem::Call* Resolver::FunctionCall(
const ast::CallExpression* expr,
sem::Function* target,
const std::vector<const sem::Expression*> args) {
const std::vector<const sem::Expression*> args,
sem::Behaviors arg_behaviors) {
auto sym = expr->target.name->symbol;
auto name = builder_->Symbols().NameFor(sym);
@ -1272,6 +1365,8 @@ sem::Call* Resolver::FunctionCall(
target->AddCallSite(call);
call->Behaviors() = arg_behaviors + target->Behaviors();
if (!ValidateFunctionCall(call)) {
return nullptr;
}
@ -1285,15 +1380,10 @@ sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
const sem::Type* source) {
// It is not valid to have a type-cast call expression inside a call
// statement.
if (current_statement_) {
if (auto* stmt =
current_statement_->Declaration()->As<ast::CallStatement>()) {
if (stmt->expr == expr) {
if (IsCallStatement(expr)) {
AddError("type cast evaluated but not used", expr->source);
return nullptr;
}
}
}
auto* call_target = utils::GetOrCreate(
type_conversions_, TypeConversionSig{target, source},
@ -1349,15 +1439,10 @@ sem::Call* Resolver::TypeConstructor(
const std::vector<const sem::Type*> arg_tys) {
// It is not valid to have a type-constructor call expression as a call
// statement.
if (current_statement_) {
if (auto* stmt =
current_statement_->Declaration()->As<ast::CallStatement>()) {
if (stmt->expr == expr) {
if (IsCallStatement(expr)) {
AddError("type constructor evaluated but not used", expr->source);
return nullptr;
}
}
}
auto* call_target = utils::GetOrCreate(
type_ctors_, TypeConstructorSig{ty, arg_tys},
@ -1619,8 +1704,11 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
using Matrix = sem::Matrix;
using Vector = sem::Vector;
auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef();
auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef();
auto* lhs = Sem(expr->lhs);
auto* rhs = Sem(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@ -1636,7 +1724,10 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
auto build = [&](const sem::Type* ty) {
auto val = EvaluateConstantValue(expr, ty);
return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
auto* sem =
builder_->create<sem::Expression>(expr, ty, current_statement_, val);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem;
};
// Binary logical expressions
@ -1798,7 +1889,8 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
}
sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
auto* expr_ty = TypeOf(unary->expr);
auto* expr = Sem(unary->expr);
auto* expr_ty = expr->Type();
if (!expr_ty) {
return nullptr;
}
@ -1880,7 +1972,10 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
}
auto val = EvaluateConstantValue(unary, ty);
return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
auto* sem =
builder_->create<sem::Expression>(unary, ty, current_statement_, val);
sem->Behaviors() = expr->Behaviors();
return sem;
}
sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
@ -2248,10 +2343,15 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
auto& behaviors = current_statement_->Behaviors();
behaviors = sem::Behavior::kReturn;
if (auto* value = stmt->value) {
if (!Expression(value)) {
auto* expr = Expression(value);
if (!expr) {
return false;
}
behaviors.Add(expr->Behaviors() - sem::Behavior::kNext);
}
// Validate after processing the return value expression so that its type is
@ -2265,17 +2365,28 @@ sem::SwitchStatement* Resolver::SwitchStatement(
auto* sem = builder_->create<sem::SwitchStatement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
if (!Expression(stmt->condition)) {
auto& behaviors = sem->Behaviors();
auto* cond = Expression(stmt->condition);
if (!cond) {
return false;
}
behaviors = cond->Behaviors() - sem::Behavior::kNext;
for (auto* case_stmt : stmt->body) {
Mark(case_stmt);
if (!CaseStatement(case_stmt)) {
auto* c = CaseStatement(case_stmt);
if (!c) {
return false;
}
behaviors.Add(c->Behaviors());
}
if (behaviors.Contains(sem::Behavior::kBreak)) {
behaviors.Add(sem::Behavior::kNext);
}
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
return ValidateSwitch(stmt);
});
}
@ -2304,6 +2415,10 @@ sem::Statement* Resolver::VariableDeclStatement(
current_block_->AddDecl(stmt->variable);
}
if (auto* ctor = var->Constructor()) {
sem->Behaviors() = ctor->Behaviors();
}
return ValidateVariable(var);
});
}
@ -2313,10 +2428,22 @@ sem::Statement* Resolver::AssignmentStatement(
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
auto* lhs = Expression(stmt->lhs);
if (!lhs) {
return false;
}
auto* rhs = Expression(stmt->rhs);
if (!rhs) {
return false;
}
auto& behaviors = sem->Behaviors();
behaviors = rhs->Behaviors();
if (!stmt->lhs->Is<ast::PhonyExpression>()) {
behaviors.Add(lhs->Behaviors());
}
return ValidateAssignment(stmt);
});
}
@ -2324,13 +2451,23 @@ sem::Statement* Resolver::AssignmentStatement(
sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kBreak;
return ValidateBreakStatement(sem);
});
}
sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
return StatementScope(stmt, sem, [&] {
if (auto* expr = Expression(stmt->expr)) {
sem->Behaviors() = expr->Behaviors();
return true;
}
return false;
});
}
sem::Statement* Resolver::ContinueStatement(
@ -2338,6 +2475,8 @@ sem::Statement* Resolver::ContinueStatement(
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kContinue;
// Set if we've hit the first continue statement in our parent loop
if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
if (!block->FirstContinue()) {
@ -2354,6 +2493,7 @@ sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kDiscard;
current_function_->SetHasDiscard();
return ValidateDiscardStatement(sem);
@ -2365,6 +2505,8 @@ sem::Statement* Resolver::FallthroughStatement(
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kFallthrough;
return ValidateFallthroughStatement(sem);
});
}
@ -2512,6 +2654,12 @@ bool Resolver::IsIntrinsic(Symbol symbol) const {
return sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone;
}
bool Resolver::IsCallStatement(const ast::Expression* expr) const {
return current_statement_ &&
Is<ast::CallStatement>(current_statement_->Declaration(),
[&](auto* stmt) { return stmt->expr == expr; });
}
////////////////////////////////////////////////////////////////////////////////
// Resolver::TypeConversionSig
////////////////////////////////////////////////////////////////////////////////

View File

@ -185,7 +185,8 @@ class Resolver {
sem::Function* Function(const ast::Function*);
sem::Call* FunctionCall(const ast::CallExpression*,
sem::Function* target,
const std::vector<const sem::Expression*> args);
const std::vector<const sem::Expression*> args,
sem::Behaviors arg_behaviors);
sem::Expression* Identifier(const ast::IdentifierExpression*);
sem::Call* IntrinsicCall(const ast::CallExpression*,
sem::IntrinsicType,
@ -460,6 +461,9 @@ class Resolver {
/// function.
bool IsIntrinsic(Symbol) const;
/// @returns true if `expr` is the current CallStatement's CallExpression
bool IsCallStatement(const ast::Expression* expr) const;
/// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type.
template <typename SEM = sem::Node>

View File

@ -0,0 +1,687 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/resolver/resolver.h"
#include "gtest/gtest.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/sem/expression.h"
namespace tint {
namespace resolver {
namespace {
class ResolverBehaviorTest : public ResolverTest {
protected:
void SetUp() override {
// Create a function called 'DiscardOrNext' which returns an i32, and has
// the behavior of {Discard, Return}, which when called, will have the
// behavior {Discard, Next}.
Func("DiscardOrNext", {}, ty.i32(),
{
If(true, Block(Discard())),
Return(1),
});
}
};
TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1)));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
auto* stmt = Decl(Var("lhs", ty.i32(), Add(1, Call("DiscardOrNext"))));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, ExprBitcastOp) {
auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext"))));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, ExprIndex_Arr) {
Func("ArrayDiscardOrNext", {}, ty.array<i32, 4>(),
{
If(true, Block(Discard())),
Return(Construct(ty.array<i32, 4>())),
});
auto* stmt =
Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1)));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, ExprIndex_Idx) {
auto* stmt =
Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext"))));
WrapInFunction(Decl(Var("arr", ty.array<i32, 4>())), //
stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, ExprUnaryOp) {
auto* stmt = Decl(Var("lhs", ty.i32(),
create<ast::UnaryOpExpression>(
ast::UnaryOp::kComplement, Call("DiscardOrNext"))));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtAssign) {
auto* stmt = Assign("lhs", "rhs");
WrapInFunction(Decl(Var("lhs", ty.i32())), //
Decl(Var("rhs", ty.i32())), //
stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1);
WrapInFunction(Decl(Var("lhs", ty.array<i32, 4>())), //
stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) {
auto* stmt = Assign("lhs", Call("DiscardOrNext"));
WrapInFunction(Decl(Var("lhs", ty.i32())), //
stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtBlockEmpty) {
auto* stmt = Block();
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) {
auto* stmt = Block(Discard());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtCallReturn) {
Func("f", {}, ty.void_(), {Return()});
auto* stmt = CallStmt(Call("f"));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
Func("f", {}, ty.void_(), {Discard()});
auto* stmt = CallStmt(Call("f"));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) {
auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
nullptr, Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtBreak) {
auto* stmt = Break();
WrapInFunction(Loop(Block(stmt)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kBreak);
}
TEST_F(ResolverBehaviorTest, StmtContinue) {
auto* stmt = Continue();
WrapInFunction(Loop(Block(stmt)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kContinue);
}
TEST_F(ResolverBehaviorTest, StmtDiscard) {
auto* stmt = Discard();
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty) {
auto* stmt = For(nullptr, nullptr, nullptr, Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_TRUE(sem->Behaviors().Empty());
}
TEST_F(ResolverBehaviorTest, StmtForLoopBreak) {
auto* stmt = For(nullptr, nullptr, nullptr, Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtForLoopContinue) {
auto* stmt = For(nullptr, nullptr, nullptr, Block(Continue()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_TRUE(sem->Behaviors().Empty());
}
TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) {
auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtForLoopReturn) {
auto* stmt = For(nullptr, nullptr, nullptr, Block(Return()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
}
TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
nullptr, Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) {
auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
nullptr, Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondTrue) {
auto* stmt = For(nullptr, true, nullptr, Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1), nullptr, Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtForLoopBreak_ContCallFuncMayDiscard) {
auto* stmt =
For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_ContCallFuncMayDiscard) {
auto* stmt = For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
auto* stmt = If(true, Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
auto* stmt = If(true, Block(Discard()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
auto* stmt = If(true, Block(), Else(Block(Discard())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
auto* stmt = If(true, Block(Discard()), Else(Block(Discard())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
auto* stmt = If(Equal(Call("DiscardOrNext"), 1), Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) {
auto* stmt = If(true, Block(), //
Else(Equal(Call("DiscardOrNext"), 1), Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtLetDecl) {
auto* stmt = Decl(Const("v", ty.i32(), Expr(1)));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) {
auto* stmt = Decl(Const("lhs", ty.i32(), Call("DiscardOrNext")));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtLoopEmpty) {
auto* stmt = Loop(Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_TRUE(sem->Behaviors().Empty());
}
TEST_F(ResolverBehaviorTest, StmtLoopBreak) {
auto* stmt = Loop(Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtLoopContinue) {
auto* stmt = Loop(Block(Continue()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_TRUE(sem->Behaviors().Empty());
}
TEST_F(ResolverBehaviorTest, StmtLoopDiscard) {
auto* stmt = Loop(Block(Discard()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtLoopReturn) {
auto* stmt = Loop(Block(Return()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
}
TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContEmpty) {
auto* stmt = Loop(Block(), Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_TRUE(sem->Behaviors().Empty());
}
TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContBreak) {
auto* stmt = Loop(Block(), Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtReturn) {
auto* stmt = Return();
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
}
TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) {
auto* stmt = Return(Call("DiscardOrNext"));
Func("F", {}, ty.i32(), {stmt});
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kDiscard));
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondTrue_DefaultEmpty) {
auto* stmt = Switch(1, DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultEmpty) {
auto* stmt = Switch(1, DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) {
auto* stmt = Switch(1, DefaultCase(Block(Discard())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultReturn) {
auto* stmt = Switch(1, DefaultCase(Block(Return())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) {
auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Discard())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kDiscard));
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) {
auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Return())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kReturn));
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
auto* stmt = Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest,
StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
auto* stmt =
Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Discard())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
}
TEST_F(ResolverBehaviorTest,
StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
auto* stmt =
Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Return())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kReturn));
}
TEST_F(ResolverBehaviorTest,
StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) {
auto* stmt = Switch(1, //
Case(Expr(0), Block(Discard())), //
Case(Expr(1), Block(Return())), //
DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext,
sem::Behavior::kReturn));
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) {
auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtVarDecl) {
auto* stmt = Decl(Var("v", ty.i32()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) {
auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext")));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(),
sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -1042,6 +1042,18 @@ bool Resolver::ValidateFunction(const sem::Function* func) {
}
}
// https://www.w3.org/TR/WGSL/#behaviors-rules
// a function behavior is always one of {}, {Next}, {Discard}, or
// {Next, Discard}.
if (func->Behaviors() != sem::Behaviors{} && // NOLINT: bad warning
func->Behaviors() != sem::Behavior::kNext &&
func->Behaviors() != sem::Behavior::kDiscard &&
func->Behaviors() != sem::Behaviors{sem::Behavior::kNext, //
sem::Behavior::kDiscard}) {
TINT_ICE(Resolver, diagnostics_)
<< "function '" << name << "' behaviors are: " << func->Behaviors();
}
return true;
}

View File

@ -68,7 +68,7 @@ class Expression : public Castable<Expression, Node> {
const sem::Type* const type_;
const Statement* const statement_;
const Constant constant_;
sem::Behaviors behaviors_;
sem::Behaviors behaviors_{sem::Behavior::kNext};
};
} // namespace sem

View File

@ -240,6 +240,12 @@ class Function : public Castable<Function, CallTarget> {
/// @returns true if this function has a discard statement
bool HasDiscard() const { return has_discard_; }
/// @return the behaviors of this function
const sem::Behaviors& Behaviors() const { return behaviors_; }
/// @return the behaviors of this function
sem::Behaviors& Behaviors() { return behaviors_; }
private:
VariableBindings TransitivelyReferencedSamplerVariablesImpl(
ast::SamplerKind kind) const;
@ -257,6 +263,7 @@ class Function : public Castable<Function, CallTarget> {
std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_;
bool has_discard_ = false;
sem::Behaviors behaviors_{sem::Behavior::kNext};
};
} // namespace sem

View File

@ -110,8 +110,7 @@ class Statement : public Castable<Statement, Node> {
const ast::Statement* const declaration_;
const CompoundStatement* const parent_;
const sem::Function* const function_;
sem::Behaviors behaviors_;
sem::Behaviors behaviors_{sem::Behavior::kNext};
};
/// CompoundStatement is the base class of statements that can hold other

View File

@ -253,6 +253,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") {
"../src/resolver/pipeline_overridable_constant_test.cc",
"../src/resolver/ptr_ref_test.cc",
"../src/resolver/ptr_ref_validation_test.cc",
"../src/resolver/resolver_behavior_test.cc",
"../src/resolver/resolver_constants_test.cc",
"../src/resolver/resolver_test.cc",
"../src/resolver/resolver_test_helper.cc",