diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 5b125b4e5a..a0eba5a161 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1183,6 +1183,13 @@ uint32_t Builder::GenerateIdentifierExpression( return val; } +uint32_t Builder::GenerateNonReferenceExpression(const ast::Expression* expr) { + if (const auto id = GenerateExpression(expr)) { + return GenerateLoadIfNeeded(TypeOf(expr), id); + } + return 0; +} + uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) { if (auto* ref = type->As()) { type = ref->StoreType(); @@ -3578,7 +3585,10 @@ bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) { if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) { // It's a break-if. TINT_ASSERT(Writer, !backedge_stack_.empty()); - const auto cond_id = GenerateExpression(stmt->condition); + const auto cond_id = GenerateNonReferenceExpression(stmt->condition); + if (!cond_id) { + return false; + } backedge_stack_.back() = Backedge(spv::Op::OpBranchConditional, {Operand::Int(cond_id), Operand::Int(ci.break_target_id), @@ -3590,7 +3600,10 @@ bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) { is_just_a_break(es.back()->body)) { // It's a break-unless. TINT_ASSERT(Writer, !backedge_stack_.empty()); - const auto cond_id = GenerateExpression(stmt->condition); + const auto cond_id = GenerateNonReferenceExpression(stmt->condition); + if (!cond_id) { + return false; + } backedge_stack_.back() = Backedge(spv::Op::OpBranchConditional, {Operand::Int(cond_id), Operand::Int(ci.loop_header_id), diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index d85e988c30..1b43ee418c 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -454,9 +454,17 @@ class Builder { /// @param stmt the statement to generate /// @returns true if the statement was generated bool GenerateStatement(const ast::Statement* stmt); - /// Geneates an OpLoad - /// @param type the type to load - /// @param id the variable id to load + /// Generates an expression. If the WGSL expression does not have reference + /// type, then return the SPIR-V ID for the expression. Otherwise implement + /// the WGSL Load Rule: generate an OpLoad and return the ID of the result. + /// Returns 0 if the expression could not be generated. + /// @param expr the expression to be generate + /// @returns the the ID of the expression, or loaded expression + uint32_t GenerateNonReferenceExpression(const ast::Expression* expr); + /// Generates an OpLoad on the given ID if it has reference type in WGSL, + /// othewrise return the ID itself. + /// @param type the type of the expression + /// @param id the SPIR-V id of the experssion /// @returns the ID of the loaded value or `id` if type is not a reference uint32_t GenerateLoadIfNeeded(const sem::Type* type, uint32_t id); /// Generates an OpStore. Emits an error and returns false if we're diff --git a/src/writer/spirv/builder_loop_test.cc b/src/writer/spirv/builder_loop_test.cc index 56d0aacd48..a074461202 100644 --- a/src/writer/spirv/builder_loop_test.cc +++ b/src/writer/spirv/builder_loop_test.cc @@ -287,6 +287,84 @@ OpBranchConditional %6 %1 %2 )"); } +TEST_F(BuilderTest, Loop_WithContinuing_BreakIf_ConditionIsVar) { + // loop { + // continuing { + // var cond = true; + // if (cond) { break; } + // } + // } + + auto* cond_var = Decl(Var("cond", nullptr, Expr(true))); + auto* if_stmt = If(Expr("cond"), Block(Break()), ast::ElseStatementList{}); + auto* continuing = Block(cond_var, if_stmt); + auto* loop = Loop(Block(), continuing); + WrapInFunction(loop); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + + EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool +%6 = OpConstantTrue %5 +%8 = OpTypePointer Function %5 +%9 = OpConstantNull %5 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpBranch %1 +%1 = OpLabel +OpLoopMerge %2 %3 None +OpBranch %4 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +OpStore %7 %6 +%10 = OpLoad %5 %7 +OpBranchConditional %10 %2 %1 +%2 = OpLabel +)"); +} + +TEST_F(BuilderTest, Loop_WithContinuing_BreakUnless_ConditionIsVar) { + // loop { + // continuing { + // var cond = true; + // if (cond) {} else { break; } + // } + // } + auto* cond_var = Decl(Var("cond", nullptr, Expr(true))); + auto* if_stmt = If(Expr("cond"), Block(), + ast::ElseStatementList{Else(nullptr, Block(Break()))}); + auto* continuing = Block(cond_var, if_stmt); + auto* loop = Loop(Block(), continuing); + WrapInFunction(loop); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + + EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool +%6 = OpConstantTrue %5 +%8 = OpTypePointer Function %5 +%9 = OpConstantNull %5 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpBranch %1 +%1 = OpLabel +OpLoopMerge %2 %3 None +OpBranch %4 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +OpStore %7 %6 +%10 = OpLoad %5 %7 +OpBranchConditional %10 %1 %2 +%2 = OpLabel +)"); +} + TEST_F(BuilderTest, Loop_WithContinuing_BreakIf_Nested) { // Make sure the right backedge and break target are used. // loop {