diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e587fb551a..7a3bb13fb1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -689,6 +689,7 @@ if(${TINT_BUILD_TESTS}) resolver/pipeline_overridable_constant_test.cc resolver/ptr_ref_test.cc resolver/ptr_ref_validation_test.cc + resolver/resolver_behavior_test.cc resolver/resolver_constants_test.cc resolver/resolver_test_helper.cc resolver/resolver_test_helper.h diff --git a/src/program_builder.h b/src/program_builder.h index bc7a52ce38..43351261cd 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -32,6 +32,7 @@ #include "src/ast/call_expression.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" +#include "src/ast/continue_statement.h" #include "src/ast/depth_multisampled_texture.h" #include "src/ast/depth_texture.h" #include "src/ast/disable_validation_decoration.h" @@ -1864,6 +1865,19 @@ class ProgramBuilder { /// @returns the break statement pointer const ast::BreakStatement* Break() { return create(); } + /// Creates an ast::ContinueStatement + /// @param source the source information + /// @returns the continue statement pointer + const ast::ContinueStatement* Continue(const Source& source) { + return create(source); + } + + /// Creates an ast::ContinueStatement + /// @returns the continue statement pointer + const ast::ContinueStatement* Continue() { + return create(); + } + /// Creates an ast::ReturnStatement with no return value /// @param source the source information /// @returns the return statement pointer @@ -2041,6 +2055,13 @@ class ProgramBuilder { body); } + /// Creates a ast::ElseStatement with no condition and body + /// @param body the else body + /// @returns the else statement pointer + const ast::ElseStatement* Else(const ast::BlockStatement* body) { + return create(nullptr, body); + } + /// Creates a ast::IfStatement with input condition, body, and optional /// variadic else statements /// @param condition the if statement condition expression diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 36ca10785b..6ac7835da1 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -657,11 +657,21 @@ sem::Function* Resolver::Function(const ast::Function* decl) { << "Resolver::Function() called with a current compound statement"; return nullptr; } - if (!StatementScope(decl->body, - builder_->create(func), - [&] { return Statements(decl->body->statements); })) { + auto* body = StatementScope( + decl->body, builder_->create(func), + [&] { return Statements(decl->body->statements); }); + if (!body) { return nullptr; } + func->Behaviors() = body->Behaviors(); + if (func->Behaviors().Contains(sem::Behavior::kReturn)) { + // https://www.w3.org/TR/WGSL/#behaviors-rules + // We assign a behavior to each function: it is its body’s behavior + // (treating the body as a regular statement), with any "Return" replaced + // by "Next". + func->Behaviors().Remove(sem::Behavior::kReturn); + func->Behaviors().Add(sem::Behavior::kNext); + } } for (auto* deco : decl->decorations) { @@ -797,13 +807,22 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { } bool Resolver::Statements(const ast::StatementList& stmts) { + sem::Behaviors behaviors{sem::Behavior::kNext}; + for (auto* stmt : stmts) { Mark(stmt); auto* sem = Statement(stmt); if (!sem) { return false; } + // s1 s2:(B1∖{Next}) ∪ B2 + // ValidateStatements will ensure that statements can only follow a Next. + behaviors.Remove(sem::Behavior::kNext); + behaviors.Add(sem->Behaviors()); } + + current_statement_->Behaviors() = behaviors; + if (!ValidateStatements(stmts)) { return false; } @@ -887,6 +906,7 @@ sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) { return false; } sem->SetBlock(body); + sem->Behaviors() = body->Behaviors(); return true; }); } @@ -900,6 +920,8 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) { return false; } sem->SetCondition(cond); + sem->Behaviors() = cond->Behaviors(); + sem->Behaviors().Remove(sem::Behavior::kNext); Mark(stmt->body); auto* body = builder_->create( @@ -908,12 +930,23 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) { [&] { return Statements(stmt->body->statements); })) { return false; } + sem->Behaviors().Add(body->Behaviors()); for (auto* else_stmt : stmt->else_statements) { Mark(else_stmt); - if (!ElseStatement(else_stmt)) { + auto* else_sem = ElseStatement(else_stmt); + if (!else_sem) { return false; } + sem->Behaviors().Add(else_sem->Behaviors()); + } + + if (stmt->else_statements.empty() || + stmt->else_statements.back()->condition != nullptr) { + // https://www.w3.org/TR/WGSL/#behaviors-rules + // if statements without an else branch are treated as if they had an + // empty else branch (which adds Next to their behavior) + sem->Behaviors().Add(sem::Behavior::kNext); } return ValidateIfStatement(sem); @@ -930,7 +963,12 @@ sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) { return false; } sem->SetCondition(cond); + // https://www.w3.org/TR/WGSL/#behaviors-rules + // if statements with else if branches are treated as if they were nested + // simple if/else statements + sem->Behaviors() = cond->Behaviors(); } + sem->Behaviors().Remove(sem::Behavior::kNext); Mark(stmt->body); auto* body = builder_->create( @@ -939,6 +977,7 @@ sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) { [&] { return Statements(stmt->body->statements); })) { return false; } + sem->Behaviors().Add(body->Behaviors()); return ValidateElseStatement(sem); }); @@ -964,20 +1003,32 @@ sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) { if (!Statements(stmt->body->statements)) { return false; } + auto& behaviors = sem->Behaviors(); + behaviors = body->Behaviors(); if (stmt->continuing) { Mark(stmt->continuing); if (!stmt->continuing->Empty()) { - auto* continuing = + auto* continuing = StatementScope( + stmt->continuing, builder_->create( stmt->continuing, current_compound_statement_, - current_function_); - return StatementScope(stmt->continuing, continuing, [&] { - return Statements(stmt->continuing->statements); - }) != nullptr; + current_function_), + [&] { return Statements(stmt->continuing->statements); }); + if (!continuing) { + return false; + } + behaviors.Add(continuing->Behaviors()); } } + if (behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit? + behaviors.Add(sem::Behavior::kNext); + } else { + behaviors.Remove(sem::Behavior::kNext); + } + behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue); + return true; }); }); @@ -988,11 +1039,14 @@ sem::ForLoopStatement* Resolver::ForLoopStatement( auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { + auto& behaviors = sem->Behaviors(); if (auto* initializer = stmt->initializer) { Mark(initializer); - if (!Statement(initializer)) { + auto* init = Statement(initializer); + if (!init) { return false; } + behaviors.Add(init->Behaviors()); } if (auto* cond_expr = stmt->condition) { @@ -1001,13 +1055,16 @@ sem::ForLoopStatement* Resolver::ForLoopStatement( return false; } sem->SetCondition(cond); + behaviors.Add(cond->Behaviors()); } if (auto* continuing = stmt->continuing) { Mark(continuing); - if (!Statement(continuing)) { + auto* cont = Statement(continuing); + if (!cont) { return false; } + behaviors.Add(cont->Behaviors()); } Mark(stmt->body); @@ -1019,6 +1076,15 @@ sem::ForLoopStatement* Resolver::ForLoopStatement( return false; } + behaviors.Add(body->Behaviors()); + if (stmt->condition || + behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit? + behaviors.Add(sem::Behavior::kNext); + } else { + behaviors.Remove(sem::Behavior::kNext); + } + behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue); + return ValidateForLoopStatement(sem); }); } @@ -1072,6 +1138,19 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { if (!sem_expr) { return nullptr; } + + // https://www.w3.org/TR/WGSL/#behaviors-rules + // an expression behavior is always either {Next} or {Next, Discard} + if (sem_expr->Behaviors() != sem::Behavior::kNext && + sem_expr->Behaviors() != sem::Behaviors{sem::Behavior::kNext, // NOLINT + sem::Behavior::kDiscard} && + !IsCallStatement(expr)) { + TINT_ICE(Resolver, diagnostics_) + << expr->TypeInfo().name + << " behaviors are: " << sem_expr->Behaviors(); + return nullptr; + } + builder_->Sem().Add(expr, sem_expr); if (expr == root) { return sem_expr; @@ -1084,52 +1163,57 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { sem::Expression* Resolver::IndexAccessor( const ast::IndexAccessorExpression* expr) { - auto* idx = expr->index; - auto* parent_raw_ty = TypeOf(expr->object); - auto* parent_ty = parent_raw_ty->UnwrapRef(); + auto* idx = Sem(expr->index); + auto* obj = Sem(expr->object); + auto* obj_raw_ty = obj->Type(); + auto* obj_ty = obj_raw_ty->UnwrapRef(); const sem::Type* ty = nullptr; - if (auto* arr = parent_ty->As()) { + if (auto* arr = obj_ty->As()) { ty = arr->ElemType(); - } else if (auto* vec = parent_ty->As()) { + } else if (auto* vec = obj_ty->As()) { ty = vec->type(); - } else if (auto* mat = parent_ty->As()) { + } else if (auto* mat = obj_ty->As()) { ty = builder_->create(mat->type(), mat->rows()); } else { - AddError("cannot index type '" + TypeNameOf(parent_ty) + "'", expr->source); + AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source); return nullptr; } - auto* idx_ty = TypeOf(idx)->UnwrapRef(); + auto* idx_ty = idx->Type()->UnwrapRef(); if (!idx_ty->IsAnyOf()) { AddError("index must be of type 'i32' or 'u32', found: '" + TypeNameOf(idx_ty) + "'", - idx->source); + idx->Declaration()->source); return nullptr; } - if (parent_ty->IsAnyOf()) { - if (!parent_raw_ty->Is()) { + if (obj_ty->IsAnyOf()) { + if (!obj_raw_ty->Is()) { // TODO(bclayton): expand this to allow any const_expr expression // https://github.com/gpuweb/gpuweb/issues/1272 - if (!idx->As()) { + if (!idx->Declaration()->As()) { AddError("index must be signed or unsigned integer literal", - idx->source); + idx->Declaration()->source); return nullptr; } } } // If we're extracting from a reference, we return a reference. - if (auto* ref = parent_raw_ty->As()) { + if (auto* ref = obj_raw_ty->As()) { ty = builder_->create(ty, ref->StorageClass(), ref->Access()); } auto val = EvaluateConstantValue(expr, ty); - return builder_->create(expr, ty, current_statement_, val); + auto* sem = + builder_->create(expr, ty, current_statement_, val); + sem->Behaviors() = idx->Behaviors() + obj->Behaviors(); + return sem; } sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { + auto* inner = Sem(expr->expr); auto* ty = Type(expr->type); if (!ty) { return nullptr; @@ -1140,12 +1224,17 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { } auto val = EvaluateConstantValue(expr, ty); - return builder_->create(expr, ty, current_statement_, val); + auto* sem = + builder_->create(expr, ty, current_statement_, val); + sem->Behaviors() = inner->Behaviors(); + return sem; } sem::Call* Resolver::Call(const ast::CallExpression* expr) { std::vector args(expr->args.size()); std::vector arg_tys(args.size()); + sem::Behaviors arg_behaviors; + for (size_t i = 0; i < expr->args.size(); i++) { auto* arg = Sem(expr->args[i]); if (!arg) { @@ -1153,8 +1242,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { } args[i] = arg; arg_tys[i] = args[i]->Type(); + arg_behaviors.Add(arg->Behaviors()); } + arg_behaviors.Remove(sem::Behavior::kNext); + auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* { // The call has resolved to a type constructor or cast. if (args.size() == 1) { @@ -1192,7 +1284,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { } if (auto* fn = As(resolved)) { - return FunctionCall(expr, fn, std::move(args)); + return FunctionCall(expr, fn, std::move(args), arg_behaviors); } auto name = builder_->Symbols().NameFor(ident->symbol); @@ -1247,7 +1339,8 @@ sem::Call* Resolver::IntrinsicCall( sem::Call* Resolver::FunctionCall( const ast::CallExpression* expr, sem::Function* target, - const std::vector args) { + const std::vector args, + sem::Behaviors arg_behaviors) { auto sym = expr->target.name->symbol; auto name = builder_->Symbols().NameFor(sym); @@ -1272,6 +1365,8 @@ sem::Call* Resolver::FunctionCall( target->AddCallSite(call); + call->Behaviors() = arg_behaviors + target->Behaviors(); + if (!ValidateFunctionCall(call)) { return nullptr; } @@ -1285,14 +1380,9 @@ sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr, const sem::Type* source) { // It is not valid to have a type-cast call expression inside a call // statement. - if (current_statement_) { - if (auto* stmt = - current_statement_->Declaration()->As()) { - if (stmt->expr == expr) { - AddError("type cast evaluated but not used", expr->source); - return nullptr; - } - } + if (IsCallStatement(expr)) { + AddError("type cast evaluated but not used", expr->source); + return nullptr; } auto* call_target = utils::GetOrCreate( @@ -1349,14 +1439,9 @@ sem::Call* Resolver::TypeConstructor( const std::vector arg_tys) { // It is not valid to have a type-constructor call expression as a call // statement. - if (current_statement_) { - if (auto* stmt = - current_statement_->Declaration()->As()) { - if (stmt->expr == expr) { - AddError("type constructor evaluated but not used", expr->source); - return nullptr; - } - } + if (IsCallStatement(expr)) { + AddError("type constructor evaluated but not used", expr->source); + return nullptr; } auto* call_target = utils::GetOrCreate( @@ -1619,8 +1704,11 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { using Matrix = sem::Matrix; using Vector = sem::Vector; - auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef(); - auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef(); + auto* lhs = Sem(expr->lhs); + auto* rhs = Sem(expr->rhs); + + auto* lhs_ty = lhs->Type()->UnwrapRef(); + auto* rhs_ty = rhs->Type()->UnwrapRef(); auto* lhs_vec = lhs_ty->As(); auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; @@ -1636,7 +1724,10 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { auto build = [&](const sem::Type* ty) { auto val = EvaluateConstantValue(expr, ty); - return builder_->create(expr, ty, current_statement_, val); + auto* sem = + builder_->create(expr, ty, current_statement_, val); + sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); + return sem; }; // Binary logical expressions @@ -1798,7 +1889,8 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { } sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { - auto* expr_ty = TypeOf(unary->expr); + auto* expr = Sem(unary->expr); + auto* expr_ty = expr->Type(); if (!expr_ty) { return nullptr; } @@ -1880,7 +1972,10 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { } auto val = EvaluateConstantValue(unary, ty); - return builder_->create(unary, ty, current_statement_, val); + auto* sem = + builder_->create(unary, ty, current_statement_, val); + sem->Behaviors() = expr->Behaviors(); + return sem; } sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { @@ -2248,10 +2343,15 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { + auto& behaviors = current_statement_->Behaviors(); + behaviors = sem::Behavior::kReturn; + if (auto* value = stmt->value) { - if (!Expression(value)) { + auto* expr = Expression(value); + if (!expr) { return false; } + behaviors.Add(expr->Behaviors() - sem::Behavior::kNext); } // Validate after processing the return value expression so that its type is @@ -2265,17 +2365,28 @@ sem::SwitchStatement* Resolver::SwitchStatement( auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { - if (!Expression(stmt->condition)) { + auto& behaviors = sem->Behaviors(); + + auto* cond = Expression(stmt->condition); + if (!cond) { return false; } + behaviors = cond->Behaviors() - sem::Behavior::kNext; for (auto* case_stmt : stmt->body) { Mark(case_stmt); - if (!CaseStatement(case_stmt)) { + auto* c = CaseStatement(case_stmt); + if (!c) { return false; } + behaviors.Add(c->Behaviors()); } + if (behaviors.Contains(sem::Behavior::kBreak)) { + behaviors.Add(sem::Behavior::kNext); + } + behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough); + return ValidateSwitch(stmt); }); } @@ -2304,6 +2415,10 @@ sem::Statement* Resolver::VariableDeclStatement( current_block_->AddDecl(stmt->variable); } + if (auto* ctor = var->Constructor()) { + sem->Behaviors() = ctor->Behaviors(); + } + return ValidateVariable(var); }); } @@ -2313,10 +2428,22 @@ sem::Statement* Resolver::AssignmentStatement( auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { - if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) { + auto* lhs = Expression(stmt->lhs); + if (!lhs) { return false; } + auto* rhs = Expression(stmt->rhs); + if (!rhs) { + return false; + } + + auto& behaviors = sem->Behaviors(); + behaviors = rhs->Behaviors(); + if (!stmt->lhs->Is()) { + behaviors.Add(lhs->Behaviors()); + } + return ValidateAssignment(stmt); }); } @@ -2324,13 +2451,23 @@ sem::Statement* Resolver::AssignmentStatement( sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); }); + return StatementScope(stmt, sem, [&] { + sem->Behaviors() = sem::Behavior::kBreak; + + return ValidateBreakStatement(sem); + }); } sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); }); + return StatementScope(stmt, sem, [&] { + if (auto* expr = Expression(stmt->expr)) { + sem->Behaviors() = expr->Behaviors(); + return true; + } + return false; + }); } sem::Statement* Resolver::ContinueStatement( @@ -2338,6 +2475,8 @@ sem::Statement* Resolver::ContinueStatement( auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { + sem->Behaviors() = sem::Behavior::kContinue; + // Set if we've hit the first continue statement in our parent loop if (auto* block = sem->FindFirstParent()) { if (!block->FirstContinue()) { @@ -2354,6 +2493,7 @@ sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { + sem->Behaviors() = sem::Behavior::kDiscard; current_function_->SetHasDiscard(); return ValidateDiscardStatement(sem); @@ -2365,6 +2505,8 @@ sem::Statement* Resolver::FallthroughStatement( auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { + sem->Behaviors() = sem::Behavior::kFallthrough; + return ValidateFallthroughStatement(sem); }); } @@ -2512,6 +2654,12 @@ bool Resolver::IsIntrinsic(Symbol symbol) const { return sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone; } +bool Resolver::IsCallStatement(const ast::Expression* expr) const { + return current_statement_ && + Is(current_statement_->Declaration(), + [&](auto* stmt) { return stmt->expr == expr; }); +} + //////////////////////////////////////////////////////////////////////////////// // Resolver::TypeConversionSig //////////////////////////////////////////////////////////////////////////////// diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 84565eef87..336e183bda 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -185,7 +185,8 @@ class Resolver { sem::Function* Function(const ast::Function*); sem::Call* FunctionCall(const ast::CallExpression*, sem::Function* target, - const std::vector args); + const std::vector args, + sem::Behaviors arg_behaviors); sem::Expression* Identifier(const ast::IdentifierExpression*); sem::Call* IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType, @@ -460,6 +461,9 @@ class Resolver { /// function. bool IsIntrinsic(Symbol) const; + /// @returns true if `expr` is the current CallStatement's CallExpression + bool IsCallStatement(const ast::Expression* expr) const; + /// @returns the resolved symbol (function, type or variable) for the given /// ast::Identifier or ast::TypeName cast to the given semantic type. template diff --git a/src/resolver/resolver_behavior_test.cc b/src/resolver/resolver_behavior_test.cc new file mode 100644 index 0000000000..6ee656bf86 --- /dev/null +++ b/src/resolver/resolver_behavior_test.cc @@ -0,0 +1,687 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/resolver/resolver.h" + +#include "gtest/gtest.h" +#include "src/resolver/resolver_test_helper.h" +#include "src/sem/expression.h" + +namespace tint { +namespace resolver { +namespace { + +class ResolverBehaviorTest : public ResolverTest { + protected: + void SetUp() override { + // Create a function called 'DiscardOrNext' which returns an i32, and has + // the behavior of {Discard, Return}, which when called, will have the + // behavior {Discard, Next}. + Func("DiscardOrNext", {}, ty.i32(), + { + If(true, Block(Discard())), + Return(1), + }); + } +}; + +TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) { + auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) { + auto* stmt = Decl(Var("lhs", ty.i32(), Add(1, Call("DiscardOrNext")))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, ExprBitcastOp) { + auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast(Call("DiscardOrNext")))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, ExprIndex_Arr) { + Func("ArrayDiscardOrNext", {}, ty.array(), + { + If(true, Block(Discard())), + Return(Construct(ty.array())), + }); + + auto* stmt = + Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, ExprIndex_Idx) { + auto* stmt = + Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext")))); + WrapInFunction(Decl(Var("arr", ty.array())), // + stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, ExprUnaryOp) { + auto* stmt = Decl(Var("lhs", ty.i32(), + create( + ast::UnaryOp::kComplement, Call("DiscardOrNext")))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtAssign) { + auto* stmt = Assign("lhs", "rhs"); + WrapInFunction(Decl(Var("lhs", ty.i32())), // + Decl(Var("rhs", ty.i32())), // + stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) { + auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1); + WrapInFunction(Decl(Var("lhs", ty.array())), // + stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) { + auto* stmt = Assign("lhs", Call("DiscardOrNext")); + WrapInFunction(Decl(Var("lhs", ty.i32())), // + stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtBlockEmpty) { + auto* stmt = Block(); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) { + auto* stmt = Block(Discard()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtCallReturn) { + Func("f", {}, ty.void_(), {Return()}); + auto* stmt = CallStmt(Call("f")); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) { + Func("f", {}, ty.void_(), {Discard()}); + auto* stmt = CallStmt(Call("f")); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) { + auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, + nullptr, Block(Break())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtBreak) { + auto* stmt = Break(); + WrapInFunction(Loop(Block(stmt))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kBreak); +} + +TEST_F(ResolverBehaviorTest, StmtContinue) { + auto* stmt = Continue(); + WrapInFunction(Loop(Block(stmt))); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kContinue); +} + +TEST_F(ResolverBehaviorTest, StmtDiscard) { + auto* stmt = Discard(); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopEmpty) { + auto* stmt = For(nullptr, nullptr, nullptr, Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_TRUE(sem->Behaviors().Empty()); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopBreak) { + auto* stmt = For(nullptr, nullptr, nullptr, Block(Break())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopContinue) { + auto* stmt = For(nullptr, nullptr, nullptr, Block(Continue())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_TRUE(sem->Behaviors().Empty()); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) { + auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopReturn) { + auto* stmt = For(nullptr, nullptr, nullptr, Block(Return())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) { + auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, + nullptr, Block(Break())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) { + auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, + nullptr, Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondTrue) { + auto* stmt = For(nullptr, true, nullptr, Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) { + auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1), nullptr, Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopBreak_ContCallFuncMayDiscard) { + auto* stmt = + For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block(Break())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_ContCallFuncMayDiscard) { + auto* stmt = For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) { + auto* stmt = If(true, Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) { + auto* stmt = If(true, Block(Discard())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) { + auto* stmt = If(true, Block(), Else(Block(Discard()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) { + auto* stmt = If(true, Block(Discard()), Else(Block(Discard()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) { + auto* stmt = If(Equal(Call("DiscardOrNext"), 1), Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) { + auto* stmt = If(true, Block(), // + Else(Equal(Call("DiscardOrNext"), 1), Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtLetDecl) { + auto* stmt = Decl(Const("v", ty.i32(), Expr(1))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) { + auto* stmt = Decl(Const("lhs", ty.i32(), Call("DiscardOrNext"))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtLoopEmpty) { + auto* stmt = Loop(Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_TRUE(sem->Behaviors().Empty()); +} + +TEST_F(ResolverBehaviorTest, StmtLoopBreak) { + auto* stmt = Loop(Block(Break())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtLoopContinue) { + auto* stmt = Loop(Block(Continue())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_TRUE(sem->Behaviors().Empty()); +} + +TEST_F(ResolverBehaviorTest, StmtLoopDiscard) { + auto* stmt = Loop(Block(Discard())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtLoopReturn) { + auto* stmt = Loop(Block(Return())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn); +} + +TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContEmpty) { + auto* stmt = Loop(Block(), Block()); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_TRUE(sem->Behaviors().Empty()); +} + +TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContBreak) { + auto* stmt = Loop(Block(), Block(Break())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtReturn) { + auto* stmt = Return(); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn); +} + +TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) { + auto* stmt = Return(Call("DiscardOrNext")); + Func("F", {}, ty.i32(), {stmt}); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kDiscard)); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondTrue_DefaultEmpty) { + auto* stmt = Switch(1, DefaultCase(Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultEmpty) { + auto* stmt = Switch(1, DefaultCase(Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) { + auto* stmt = Switch(1, DefaultCase(Block(Discard()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultReturn) { + auto* stmt = Switch(1, DefaultCase(Block(Return()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) { + auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) { + auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Discard()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kDiscard)); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) { + auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Return()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kReturn)); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) { + auto* stmt = Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, + StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) { + auto* stmt = + Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Discard()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard); +} + +TEST_F(ResolverBehaviorTest, + StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) { + auto* stmt = + Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Return()))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kReturn)); +} + +TEST_F(ResolverBehaviorTest, + StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) { + auto* stmt = Switch(1, // + Case(Expr(0), Block(Discard())), // + Case(Expr(1), Block(Return())), // + DefaultCase(Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext, + sem::Behavior::kReturn)); +} + +TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) { + auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +TEST_F(ResolverBehaviorTest, StmtVarDecl) { + auto* stmt = Decl(Var("v", ty.i32())); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext); +} + +TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) { + auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext"))); + WrapInFunction(stmt); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(stmt); + EXPECT_EQ(sem->Behaviors(), + sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext)); +} + +} // namespace +} // namespace resolver +} // namespace tint diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc index eb2be86c6b..ef365ce593 100644 --- a/src/resolver/resolver_validation.cc +++ b/src/resolver/resolver_validation.cc @@ -1042,6 +1042,18 @@ bool Resolver::ValidateFunction(const sem::Function* func) { } } + // https://www.w3.org/TR/WGSL/#behaviors-rules + // a function behavior is always one of {}, {Next}, {Discard}, or + // {Next, Discard}. + if (func->Behaviors() != sem::Behaviors{} && // NOLINT: bad warning + func->Behaviors() != sem::Behavior::kNext && + func->Behaviors() != sem::Behavior::kDiscard && + func->Behaviors() != sem::Behaviors{sem::Behavior::kNext, // + sem::Behavior::kDiscard}) { + TINT_ICE(Resolver, diagnostics_) + << "function '" << name << "' behaviors are: " << func->Behaviors(); + } + return true; } diff --git a/src/sem/expression.h b/src/sem/expression.h index 06ae10b87c..b2ff4ac899 100644 --- a/src/sem/expression.h +++ b/src/sem/expression.h @@ -68,7 +68,7 @@ class Expression : public Castable { const sem::Type* const type_; const Statement* const statement_; const Constant constant_; - sem::Behaviors behaviors_; + sem::Behaviors behaviors_{sem::Behavior::kNext}; }; } // namespace sem diff --git a/src/sem/function.h b/src/sem/function.h index ea834a7c54..6d980c5f49 100644 --- a/src/sem/function.h +++ b/src/sem/function.h @@ -240,6 +240,12 @@ class Function : public Castable { /// @returns true if this function has a discard statement bool HasDiscard() const { return has_discard_; } + /// @return the behaviors of this function + const sem::Behaviors& Behaviors() const { return behaviors_; } + + /// @return the behaviors of this function + sem::Behaviors& Behaviors() { return behaviors_; } + private: VariableBindings TransitivelyReferencedSamplerVariablesImpl( ast::SamplerKind kind) const; @@ -257,6 +263,7 @@ class Function : public Castable { std::vector callsites_; std::vector ancestor_entry_points_; bool has_discard_ = false; + sem::Behaviors behaviors_{sem::Behavior::kNext}; }; } // namespace sem diff --git a/src/sem/statement.h b/src/sem/statement.h index e1c5160602..1468da919e 100644 --- a/src/sem/statement.h +++ b/src/sem/statement.h @@ -110,8 +110,7 @@ class Statement : public Castable { const ast::Statement* const declaration_; const CompoundStatement* const parent_; const sem::Function* const function_; - - sem::Behaviors behaviors_; + sem::Behaviors behaviors_{sem::Behavior::kNext}; }; /// CompoundStatement is the base class of statements that can hold other diff --git a/test/BUILD.gn b/test/BUILD.gn index bfaa498f0e..c2cc4826f9 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -253,6 +253,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") { "../src/resolver/pipeline_overridable_constant_test.cc", "../src/resolver/ptr_ref_test.cc", "../src/resolver/ptr_ref_validation_test.cc", + "../src/resolver/resolver_behavior_test.cc", "../src/resolver/resolver_constants_test.cc", "../src/resolver/resolver_test.cc", "../src/resolver/resolver_test_helper.cc",