tint/resolver: Validate discard is only used by fragment shaders

Fixed: tint:1373
Change-Id: Ieb2a808982d8fa8b199e57d4df44f29390fa6e74
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105961
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-10-14 13:44:54 +00:00 committed by Dawn LUCI CQ
parent 1a567780d9
commit d9222f44c9
19 changed files with 279 additions and 109 deletions

View File

@ -143,7 +143,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_return) {
TEST_F(ResolverFunctionValidationTest, UnreachableCode_return_InBlocks) { TEST_F(ResolverFunctionValidationTest, UnreachableCode_return_InBlocks) {
// fn func() -> { // fn func() -> {
// var a : i32; // var a : i32;
// utils::Vector {{{return;}}} // {{{return;}}}
// a = 2i; // a = 2i;
//} //}
@ -184,7 +184,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard) {
TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard_InBlocks) { TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard_InBlocks) {
// fn func() -> { // fn func() -> {
// var a : i32; // var a : i32;
// utils::Vector {{{discard;}}} // {{{discard;}}}
// a = 2i; // a = 2i;
//} //}
@ -202,6 +202,59 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard_InBlocks) {
EXPECT_FALSE(Sem().Get(assign_a)->IsReachable()); EXPECT_FALSE(Sem().Get(assign_a)->IsReachable());
} }
TEST_F(ResolverFunctionValidationTest, DiscardCalledDirectlyFromVertexEntryPoint) {
// @vertex() fn func() -> @position(0) vec4<f32> { discard; }
Func(Source{{1, 2}}, "func", utils::Empty, ty.vec4<f32>(),
utils::Vector{
Discard(Source{{12, 34}}),
},
utils::Vector{Stage(ast::PipelineStage::kVertex)},
utils::Vector{Builtin(ast::BuiltinValue::kPosition)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: discard statement cannot be used in vertex pipeline stage");
}
TEST_F(ResolverFunctionValidationTest, DiscardCalledIndirectlyFromComputeEntryPoint) {
// fn f0 { discard; }
// fn f1 { f0(); }
// fn f2 { f1(); }
// @compute @workgroup_size(1) fn main { return f2(); }
Func(Source{{1, 2}}, "f0", utils::Empty, ty.void_(),
utils::Vector{
Discard(Source{{12, 34}}),
});
Func(Source{{3, 4}}, "f1", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(Call("f0")),
});
Func(Source{{5, 6}}, "f2", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(Call("f1")),
});
Func(Source{{7, 8}}, "main", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(Call("f2")),
},
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1_i),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: discard statement cannot be used in compute pipeline stage
1:2 note: called by function 'f0'
3:4 note: called by function 'f1'
5:6 note: called by function 'f2'
7:8 note: called by entry point 'main')");
}
TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) { TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) {
// fn func() -> int { var a:i32 = 2i; } // fn func() -> int { var a:i32 = 2i; }

View File

@ -3218,7 +3218,7 @@ sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
builder_->create<sem::Statement>(stmt, current_compound_statement_, current_function_); builder_->create<sem::Statement>(stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kDiscard; sem->Behaviors() = sem::Behavior::kDiscard;
current_function_->SetHasDiscard(); current_function_->SetDiscardStatement(sem);
return validator_.DiscardStatement(sem, current_statement_); return validator_.DiscardStatement(sem, current_statement_);
}); });

View File

@ -43,7 +43,9 @@ class ResolverBehaviorTest : public ResolverTest {
TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) { TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1_i))); auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1_i)));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -53,7 +55,9 @@ TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) { TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
auto* stmt = Decl(Var("lhs", ty.i32(), Add(1_i, Call("DiscardOrNext")))); auto* stmt = Decl(Var("lhs", ty.i32(), Add(1_i, Call("DiscardOrNext"))));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -63,7 +67,9 @@ TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
TEST_F(ResolverBehaviorTest, ExprBitcastOp) { TEST_F(ResolverBehaviorTest, ExprBitcastOp) {
auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext")))); auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext"))));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -79,7 +85,9 @@ TEST_F(ResolverBehaviorTest, ExprIndex_Arr) {
}); });
auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1_i))); auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1_i)));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -89,8 +97,13 @@ TEST_F(ResolverBehaviorTest, ExprIndex_Arr) {
TEST_F(ResolverBehaviorTest, ExprIndex_Idx) { TEST_F(ResolverBehaviorTest, ExprIndex_Idx) {
auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext")))); auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext"))));
WrapInFunction(Decl(Var("arr", ty.array<i32, 4>())), //
stmt); Func("F", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("arr", ty.array<i32, 4>())), //
stmt,
},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -102,7 +115,9 @@ TEST_F(ResolverBehaviorTest, ExprUnaryOp) {
auto* stmt = auto* stmt =
Decl(Var("lhs", ty.i32(), Decl(Var("lhs", ty.i32(),
create<ast::UnaryOpExpression>(ast::UnaryOp::kComplement, Call("DiscardOrNext")))); create<ast::UnaryOpExpression>(ast::UnaryOp::kComplement, Call("DiscardOrNext"))));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -124,8 +139,13 @@ TEST_F(ResolverBehaviorTest, StmtAssign) {
TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) { TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1_i); auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1_i);
WrapInFunction(Decl(Var("lhs", ty.array<i32, 4>())), //
stmt); Func("F", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("lhs", ty.array<i32, 4>())), //
stmt,
},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -135,8 +155,13 @@ TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) { TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) {
auto* stmt = Assign("lhs", Call("DiscardOrNext")); auto* stmt = Assign("lhs", Call("DiscardOrNext"));
WrapInFunction(Decl(Var("lhs", ty.i32())), //
stmt); Func("F", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("lhs", ty.i32())), //
stmt,
},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -156,7 +181,9 @@ TEST_F(ResolverBehaviorTest, StmtBlockEmpty) {
TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) { TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) {
auto* stmt = Block(Discard()); auto* stmt = Block(Discard());
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -178,7 +205,9 @@ TEST_F(ResolverBehaviorTest, StmtCallReturn) {
TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) { TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
Func("f", utils::Empty, ty.void_(), utils::Vector{Discard()}); Func("f", utils::Empty, ty.void_(), utils::Vector{Discard()});
auto* stmt = CallStmt(Call("f")); auto* stmt = CallStmt(Call("f"));
WrapInFunction(stmt);
Func("g", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -189,7 +218,9 @@ TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) { TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) {
auto* stmt = auto* stmt =
For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block(Break())); For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block(Break()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -220,7 +251,9 @@ TEST_F(ResolverBehaviorTest, StmtContinue) {
TEST_F(ResolverBehaviorTest, StmtDiscard) { TEST_F(ResolverBehaviorTest, StmtDiscard) {
auto* stmt = Discard(); auto* stmt = Discard();
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -256,7 +289,9 @@ TEST_F(ResolverBehaviorTest, StmtForLoopContinue_NoExit) {
TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) { TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) {
auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard())); auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -277,7 +312,9 @@ TEST_F(ResolverBehaviorTest, StmtForLoopReturn) {
TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) { TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
auto* stmt = auto* stmt =
For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block(Break())); For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block(Break()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -287,7 +324,9 @@ TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) { TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) {
auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block()); auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block());
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -307,7 +346,9 @@ TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondTrue) {
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) { TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1_i), nullptr, Block()); auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1_i), nullptr, Block());
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -327,7 +368,9 @@ TEST_F(ResolverBehaviorTest, StmtWhileBreak) {
TEST_F(ResolverBehaviorTest, StmtWhileDiscard) { TEST_F(ResolverBehaviorTest, StmtWhileDiscard) {
auto* stmt = While(Expr(true), Block(Discard())); auto* stmt = While(Expr(true), Block(Discard()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -357,7 +400,9 @@ TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondTrue) {
TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondCallFuncMayDiscard) { TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondCallFuncMayDiscard) {
auto* stmt = While(Equal(Call("DiscardOrNext"), 1_i), Block()); auto* stmt = While(Equal(Call("DiscardOrNext"), 1_i), Block());
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -377,7 +422,9 @@ TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) { TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
auto* stmt = If(true, Block(Discard())); auto* stmt = If(true, Block(Discard()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -387,7 +434,9 @@ TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) { TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
auto* stmt = If(true, Block(), Else(Block(Discard()))); auto* stmt = If(true, Block(), Else(Block(Discard())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -397,7 +446,9 @@ TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) { TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
auto* stmt = If(true, Block(Discard()), Else(Block(Discard()))); auto* stmt = If(true, Block(Discard()), Else(Block(Discard())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -407,7 +458,9 @@ TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) { TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
auto* stmt = If(Equal(Call("DiscardOrNext"), 1_i), Block()); auto* stmt = If(Equal(Call("DiscardOrNext"), 1_i), Block());
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -418,7 +471,9 @@ TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) { TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) {
auto* stmt = If(true, Block(), // auto* stmt = If(true, Block(), //
Else(If(Equal(Call("DiscardOrNext"), 1_i), Block()))); Else(If(Equal(Call("DiscardOrNext"), 1_i), Block())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -438,7 +493,9 @@ TEST_F(ResolverBehaviorTest, StmtLetDecl) {
TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) { TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) {
auto* stmt = Decl(Let("lhs", ty.i32(), Call("DiscardOrNext"))); auto* stmt = Decl(Let("lhs", ty.i32(), Call("DiscardOrNext")));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -474,7 +531,9 @@ TEST_F(ResolverBehaviorTest, StmtLoopContinue_NoExit) {
TEST_F(ResolverBehaviorTest, StmtLoopDiscard) { TEST_F(ResolverBehaviorTest, StmtLoopDiscard) {
auto* stmt = Loop(Block(Discard())); auto* stmt = Loop(Block(Discard()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -522,6 +581,7 @@ TEST_F(ResolverBehaviorTest, StmtReturn) {
TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) { TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) {
auto* stmt = Return(Call("DiscardOrNext")); auto* stmt = Return(Call("DiscardOrNext"));
Func("F", utils::Empty, ty.i32(), utils::Vector{stmt}); Func("F", utils::Empty, ty.i32(), utils::Vector{stmt});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -552,7 +612,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultEmpty) {
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) { TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) {
auto* stmt = Switch(1_i, DefaultCase(Block(Discard()))); auto* stmt = Switch(1_i, DefaultCase(Block(Discard())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -582,7 +644,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) {
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) { TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Discard()))); auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Discard())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -602,7 +666,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) {
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) { TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block())); auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -612,7 +678,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) { TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Discard()))); auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Discard())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -622,7 +690,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard)
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) { TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Return()))); auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Return())));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -635,7 +705,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_Case1Return_Def
Case(Expr(0_i), Block(Discard())), // Case(Expr(0_i), Block(Discard())), //
Case(Expr(1_i), Block(Return())), // Case(Expr(1_i), Block(Return())), //
DefaultCase(Block())); DefaultCase(Block()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -646,7 +718,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_Case1Return_Def
TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) { TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) {
auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block())); auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block()));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -666,7 +740,9 @@ TEST_F(ResolverBehaviorTest, StmtVarDecl) {
TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) { TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) {
auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext"))); auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext")));
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();

View File

@ -1942,6 +1942,18 @@ bool Validator::Matrix(const sem::Matrix* ty, const Source& source) const {
} }
bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points) const { bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points) const {
auto backtrace = [&](const sem::Function* func, const sem::Function* entry_point) {
if (func != entry_point) {
TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) {
AddNote("called by function '" + symbols_.NameFor(f->Declaration()->symbol) + "'",
f->Declaration()->source);
});
AddNote("called by entry point '" +
symbols_.NameFor(entry_point->Declaration()->symbol) + "'",
entry_point->Declaration()->source);
}
};
auto check_workgroup_storage = [&](const sem::Function* func, auto check_workgroup_storage = [&](const sem::Function* func,
const sem::Function* entry_point) { const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage(); auto stage = entry_point->Declaration()->PipelineStage();
@ -1959,17 +1971,7 @@ bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points)
} }
} }
AddNote("variable is declared here", var->Declaration()->source); AddNote("variable is declared here", var->Declaration()->source);
if (func != entry_point) { backtrace(func, entry_point);
TraverseCallChain(
diagnostics_, entry_point, func, [&](const sem::Function* f) {
AddNote("called by function '" +
symbols_.NameFor(f->Declaration()->symbol) + "'",
f->Declaration()->source);
});
AddNote("called by entry point '" +
symbols_.NameFor(entry_point->Declaration()->symbol) + "'",
entry_point->Declaration()->source);
}
return false; return false;
} }
} }
@ -1977,17 +1979,6 @@ bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points)
return true; return true;
}; };
for (auto* entry_point : entry_points) {
if (!check_workgroup_storage(entry_point, entry_point)) {
return false;
}
for (auto* func : entry_point->TransitivelyCalledFunctions()) {
if (!check_workgroup_storage(func, entry_point)) {
return false;
}
}
}
auto check_builtin_calls = [&](const sem::Function* func, const sem::Function* entry_point) { auto check_builtin_calls = [&](const sem::Function* func, const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage(); auto stage = entry_point->Declaration()->PipelineStage();
for (auto* builtin : func->DirectlyCalledBuiltins()) { for (auto* builtin : func->DirectlyCalledBuiltins()) {
@ -1997,16 +1988,34 @@ bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points)
err << "built-in cannot be used by " << stage << " pipeline stage"; err << "built-in cannot be used by " << stage << " pipeline stage";
AddError(err.str(), AddError(err.str(),
call ? call->Declaration()->source : func->Declaration()->source); call ? call->Declaration()->source : func->Declaration()->source);
if (func != entry_point) { backtrace(func, entry_point);
TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) { return false;
AddNote("called by function '" + }
symbols_.NameFor(f->Declaration()->symbol) + "'", }
f->Declaration()->source); return true;
}); };
AddNote("called by entry point '" +
symbols_.NameFor(entry_point->Declaration()->symbol) + "'", auto check_no_discards = [&](const sem::Function* func, const sem::Function* entry_point) {
entry_point->Declaration()->source); if (auto* discard = func->DiscardStatement()) {
} auto stage = entry_point->Declaration()->PipelineStage();
std::stringstream err;
err << "discard statement cannot be used in " << stage << " pipeline stage";
AddError(err.str(), discard->Declaration()->source);
backtrace(func, entry_point);
return false;
}
return true;
};
auto check_func = [&](const sem::Function* func, const sem::Function* entry_point) {
if (!check_workgroup_storage(func, entry_point)) {
return false;
}
if (!check_builtin_calls(func, entry_point)) {
return false;
}
if (entry_point->Declaration()->PipelineStage() != ast::PipelineStage::kFragment) {
if (!check_no_discards(func, entry_point)) {
return false; return false;
} }
} }
@ -2014,15 +2023,16 @@ bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points)
}; };
for (auto* entry_point : entry_points) { for (auto* entry_point : entry_points) {
if (!check_builtin_calls(entry_point, entry_point)) { if (!check_func(entry_point, entry_point)) {
return false; return false;
} }
for (auto* func : entry_point->TransitivelyCalledFunctions()) { for (auto* func : entry_point->TransitivelyCalledFunctions()) {
if (!check_builtin_calls(func, entry_point)) { if (!check_func(func, entry_point)) {
return false; return false;
} }
} }
} }
return true; return true;
} }

View File

@ -237,12 +237,17 @@ class Function final : public Castable<Function, CallTarget> {
/// @returns true if `sym` is an ancestor entry point of this function /// @returns true if `sym` is an ancestor entry point of this function
bool HasAncestorEntryPoint(Symbol sym) const; bool HasAncestorEntryPoint(Symbol sym) const;
/// Sets that this function has a discard statement /// Records the first discard statement in the function
void SetHasDiscard() { has_discard_ = true; } /// @param stmt the `discard` statement.
void SetDiscardStatement(const Statement* stmt) {
if (!discard_stmt_) {
discard_stmt_ = stmt;
}
}
/// Returns true if this function has a discard statement /// @returns the first discard statement for the function, or nullptr if the function does not
/// @returns true if this function has a discard statement /// use `discard`.
bool HasDiscard() const { return has_discard_; } const Statement* DiscardStatement() const { return discard_stmt_; }
/// @return the behaviors of this function /// @return the behaviors of this function
const sem::Behaviors& Behaviors() const { return behaviors_; } const sem::Behaviors& Behaviors() const { return behaviors_; }
@ -271,7 +276,7 @@ class Function final : public Castable<Function, CallTarget> {
std::vector<const Call*> direct_calls_; std::vector<const Call*> direct_calls_;
std::vector<const Call*> callsites_; std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_; std::vector<const Function*> ancestor_entry_points_;
bool has_discard_ = false; const Statement* discard_stmt_ = nullptr;
sem::Behaviors behaviors_{sem::Behavior::kNext}; sem::Behaviors behaviors_{sem::Behavior::kNext};
std::optional<uint32_t> return_location_; std::optional<uint32_t> return_location_;

View File

@ -20,7 +20,7 @@ namespace {
using GlslGeneratorImplTest_Block = TestHelper; using GlslGeneratorImplTest_Block = TestHelper;
TEST_F(GlslGeneratorImplTest_Block, Emit_Block) { TEST_F(GlslGeneratorImplTest_Block, Emit_Block) {
auto* b = Block(create<ast::DiscardStatement>()); auto* b = Block(Return());
WrapInFunction(b); WrapInFunction(b);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -29,7 +29,7 @@ TEST_F(GlslGeneratorImplTest_Block, Emit_Block) {
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
discard; return;
} }
)"); )");
} }

View File

@ -21,7 +21,9 @@ using GlslGeneratorImplTest_Discard = TestHelper;
TEST_F(GlslGeneratorImplTest_Discard, Emit_Discard) { TEST_F(GlslGeneratorImplTest_Discard, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>(); auto* stmt = create<ast::DiscardStatement>();
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -27,7 +27,8 @@ TEST_F(GlslGeneratorImplTest_Loop, Emit_Loop) {
auto* continuing = Block(); auto* continuing = Block();
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -47,7 +48,8 @@ TEST_F(GlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) {
auto* continuing = Block(CallStmt(Call("a_statement"))); auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -81,7 +83,9 @@ TEST_F(GlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
continuing = Block(Assign(lhs, rhs)); continuing = Block(Assign(lhs, rhs));
auto* outer = Loop(body, continuing); auto* outer = Loop(body, continuing);
WrapInFunction(outer);
Func("F", utils::Empty, ty.void_(), utils::Vector{outer},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -2767,7 +2767,7 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
out << ") {"; out << ") {";
} }
if (sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>()) { if (sem->DiscardStatement() && !sem->ReturnType()->Is<sem::Void>()) {
// BUG(crbug.com/tint/1081): work around non-void functions with discard // BUG(crbug.com/tint/1081): work around non-void functions with discard
// failing compilation sometimes // failing compilation sometimes
if (!EmitFunctionBodyWithDiscard(func)) { if (!EmitFunctionBodyWithDiscard(func)) {
@ -2791,7 +2791,7 @@ bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) {
// there is always an (unused) return statement. // there is always an (unused) return statement.
auto* sem = builder_.Sem().Get(func); auto* sem = builder_.Sem().Get(func);
TINT_ASSERT(Writer, sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>()); TINT_ASSERT(Writer, sem->DiscardStatement() && !sem->ReturnType()->Is<sem::Void>());
ScopedIndent si(this); ScopedIndent si(this);
line() << "if (true) {"; line() << "if (true) {";

View File

@ -20,7 +20,7 @@ namespace {
using HlslGeneratorImplTest_Block = TestHelper; using HlslGeneratorImplTest_Block = TestHelper;
TEST_F(HlslGeneratorImplTest_Block, Emit_Block) { TEST_F(HlslGeneratorImplTest_Block, Emit_Block) {
auto* b = Block(create<ast::DiscardStatement>()); auto* b = Block(Return());
WrapInFunction(b); WrapInFunction(b);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -29,7 +29,7 @@ TEST_F(HlslGeneratorImplTest_Block, Emit_Block) {
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
discard; return;
} }
)"); )");
} }

View File

@ -21,7 +21,9 @@ using HlslGeneratorImplTest_Discard = TestHelper;
TEST_F(HlslGeneratorImplTest_Discard, Emit_Discard) { TEST_F(HlslGeneratorImplTest_Discard, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>(); auto* stmt = create<ast::DiscardStatement>();
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -27,7 +27,8 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_Loop) {
auto* continuing = Block(); auto* continuing = Block();
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -47,7 +48,8 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) {
auto* continuing = Block(CallStmt(Call("a_statement"))); auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -81,7 +83,9 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
continuing = Block(Assign(lhs, rhs)); continuing = Block(Assign(lhs, rhs));
auto* outer = Loop(body, continuing); auto* outer = Loop(body, continuing);
WrapInFunction(outer);
Func("F", utils::Empty, ty.void_(), utils::Vector{outer},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -20,7 +20,7 @@ namespace {
using MslGeneratorImplTest = TestHelper; using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, Emit_Block) { TEST_F(MslGeneratorImplTest, Emit_Block) {
auto* b = Block(create<ast::DiscardStatement>()); auto* b = Block(Return());
WrapInFunction(b); WrapInFunction(b);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -29,13 +29,13 @@ TEST_F(MslGeneratorImplTest, Emit_Block) {
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
discard_fragment(); return;
} }
)"); )");
} }
TEST_F(MslGeneratorImplTest, Emit_Block_WithoutNewline) { TEST_F(MslGeneratorImplTest, Emit_Block_WithoutNewline) {
auto* b = Block(create<ast::DiscardStatement>()); auto* b = Block(Return());
WrapInFunction(b); WrapInFunction(b);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -44,7 +44,7 @@ TEST_F(MslGeneratorImplTest, Emit_Block_WithoutNewline) {
ASSERT_TRUE(gen.EmitBlock(b)) << gen.error(); ASSERT_TRUE(gen.EmitBlock(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
discard_fragment(); return;
} }
)"); )");
} }

View File

@ -21,7 +21,9 @@ using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, Emit_Discard) { TEST_F(MslGeneratorImplTest, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>(); auto* stmt = create<ast::DiscardStatement>();
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -26,7 +26,9 @@ TEST_F(MslGeneratorImplTest, Emit_Loop) {
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(); auto* continuing = Block();
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l);
Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -45,7 +47,9 @@ TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) {
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(CallStmt(Call("a_statement"))); auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l);
Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -76,7 +80,9 @@ TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
continuing = Block(Assign("lhs", "rhs")); continuing = Block(Assign("lhs", "rhs"));
auto* outer = Loop(body, continuing); auto* outer = Loop(body, continuing);
WrapInFunction(outer);
Func("F", utils::Empty, ty.void_(), utils::Vector{outer},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -21,13 +21,15 @@ namespace {
using BuilderTest = TestHelper; using BuilderTest = TestHelper;
TEST_F(BuilderTest, Discard) { TEST_F(BuilderTest, Discard) {
auto* expr = create<ast::DiscardStatement>(); auto* stmt = Discard();
WrapInFunction(expr);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build(); spirv::Builder& b = Build();
b.push_function(Function{}); b.push_function(Function{});
EXPECT_TRUE(b.GenerateStatement(expr)) << b.error(); EXPECT_TRUE(b.GenerateStatement(stmt)) << b.error();
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpKill EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpKill
)"); )");
} }

View File

@ -20,7 +20,7 @@ namespace {
using WgslGeneratorImplTest = TestHelper; using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_Block) { TEST_F(WgslGeneratorImplTest, Emit_Block) {
auto* b = Block(create<ast::DiscardStatement>()); auto* b = Block(Return());
WrapInFunction(b); WrapInFunction(b);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -29,7 +29,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Block) {
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
discard; return;
} }
)"); )");
} }

View File

@ -21,7 +21,9 @@ using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_Discard) { TEST_F(WgslGeneratorImplTest, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>(); auto* stmt = create<ast::DiscardStatement>();
WrapInFunction(stmt);
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -26,7 +26,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Loop) {
auto* continuing = Block(); auto* continuing = Block();
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -46,7 +47,8 @@ TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) {
auto* continuing = Block(CallStmt(Call("a_statement"))); auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); Func("F", utils::Empty, ty.void_(), utils::Vector{l},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();