diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 5ef8252622..4af06d25df 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -215,6 +215,8 @@ libtint_source_set("libtint_core_all_src") { "ast/call_expression.h", "ast/call_statement.cc", "ast/call_statement.h", + "ast/case_selector.cc", + "ast/case_selector.h", "ast/case_statement.cc", "ast/case_statement.h", "ast/compound_assignment_statement.cc", @@ -1021,6 +1023,7 @@ if (tint_build_unittests) { "ast/builtin_value_test.cc", "ast/call_expression_test.cc", "ast/call_statement_test.cc", + "ast/case_selector_test.cc", "ast/case_statement_test.cc", "ast/compound_assignment_statement_test.cc", "ast/continue_statement_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 0811598de0..512c90642b 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -83,6 +83,8 @@ set(TINT_LIB_SRCS ast/call_expression.h ast/call_statement.cc ast/call_statement.h + ast/case_selector.cc + ast/case_selector.h ast/case_statement.cc ast/case_statement.h ast/compound_assignment_statement.cc @@ -713,6 +715,7 @@ if(TINT_BUILD_TESTS) ast/builtin_value_test.cc ast/call_expression_test.cc ast/call_statement_test.cc + ast/case_selector_test.cc ast/case_statement_test.cc ast/compound_assignment_statement_test.cc ast/continue_statement_test.cc diff --git a/src/tint/ast/case_selector.cc b/src/tint/ast/case_selector.cc new file mode 100644 index 0000000000..8622d3a105 --- /dev/null +++ b/src/tint/ast/case_selector.cc @@ -0,0 +1,39 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/case_selector.h" + +#include + +#include "src/tint/program_builder.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ast::CaseSelector); + +namespace tint::ast { + +CaseSelector::CaseSelector(ProgramID pid, NodeID nid, const Source& src, const ast::Expression* e) + : Base(pid, nid, src), expr(e) {} + +CaseSelector::CaseSelector(CaseSelector&&) = default; + +CaseSelector::~CaseSelector() = default; + +const CaseSelector* CaseSelector::Clone(CloneContext* ctx) const { + // Clone arguments outside of create() call to have deterministic ordering + auto src = ctx->Clone(source); + auto ex = ctx->Clone(expr); + return ctx->dst->create(src, ex); +} + +} // namespace tint::ast diff --git a/src/tint/ast/case_selector.h b/src/tint/ast/case_selector.h new file mode 100644 index 0000000000..b4c3ca7f57 --- /dev/null +++ b/src/tint/ast/case_selector.h @@ -0,0 +1,52 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_AST_CASE_SELECTOR_H_ +#define SRC_TINT_AST_CASE_SELECTOR_H_ + +#include + +#include "src/tint/ast/block_statement.h" +#include "src/tint/ast/expression.h" + +namespace tint::ast { + +/// A case selector +class CaseSelector final : public Castable { + public: + /// Constructor + /// @param pid the identifier of the program that owns this node + /// @param nid the unique node identifier + /// @param src the source of this node + /// @param expr the selector expression, |nullptr| for a `default` selector + CaseSelector(ProgramID pid, NodeID nid, const Source& src, const Expression* expr = nullptr); + /// Move constructor + CaseSelector(CaseSelector&&); + ~CaseSelector() override; + + /// @returns true if this is a default statement + bool IsDefault() const { return expr == nullptr; } + + /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned node + const CaseSelector* Clone(CloneContext* ctx) const override; + + /// The selector, nullptr for a default selector + const Expression* const expr = nullptr; +}; + +} // namespace tint::ast + +#endif // SRC_TINT_AST_CASE_SELECTOR_H_ diff --git a/src/tint/ast/case_selector_test.cc b/src/tint/ast/case_selector_test.cc new file mode 100644 index 0000000000..16e74cc4d7 --- /dev/null +++ b/src/tint/ast/case_selector_test.cc @@ -0,0 +1,40 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/case_selector.h" + +#include "gtest/gtest-spi.h" +#include "src/tint/ast/test_helper.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::ast { +namespace { + +using CaseSelectorTest = TestHelper; + +TEST_F(CaseSelectorTest, NonDefault) { + auto* e = Expr(2_i); + auto* c = CaseSelector(e); + EXPECT_FALSE(c->IsDefault()); + EXPECT_EQ(e, c->expr); +} + +TEST_F(CaseSelectorTest, Default) { + auto* c = DefaultCaseSelector(); + EXPECT_TRUE(c->IsDefault()); +} + +} // namespace +} // namespace tint::ast diff --git a/src/tint/ast/case_statement.cc b/src/tint/ast/case_statement.cc index 3125f0516f..7b2e798ecb 100644 --- a/src/tint/ast/case_statement.cc +++ b/src/tint/ast/case_statement.cc @@ -25,10 +25,11 @@ namespace tint::ast { CaseStatement::CaseStatement(ProgramID pid, NodeID nid, const Source& src, - utils::VectorRef s, + utils::VectorRef s, const BlockStatement* b) : Base(pid, nid, src), selectors(std::move(s)), body(b) { TINT_ASSERT(AST, body); + TINT_ASSERT(AST, !selectors.IsEmpty()); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id); for (auto* selector : selectors) { TINT_ASSERT(AST, selector); @@ -40,6 +41,15 @@ CaseStatement::CaseStatement(CaseStatement&&) = default; CaseStatement::~CaseStatement() = default; +bool CaseStatement::ContainsDefault() const { + for (const auto* sel : selectors) { + if (sel->IsDefault()) { + return true; + } + } + return false; +} + const CaseStatement* CaseStatement::Clone(CloneContext* ctx) const { // Clone arguments outside of create() call to have deterministic ordering auto src = ctx->Clone(source); diff --git a/src/tint/ast/case_statement.h b/src/tint/ast/case_statement.h index eda9c010a2..acd502afcf 100644 --- a/src/tint/ast/case_statement.h +++ b/src/tint/ast/case_statement.h @@ -18,7 +18,7 @@ #include #include "src/tint/ast/block_statement.h" -#include "src/tint/ast/expression.h" +#include "src/tint/ast/case_selector.h" namespace tint::ast { @@ -34,23 +34,23 @@ class CaseStatement final : public Castable { CaseStatement(ProgramID pid, NodeID nid, const Source& src, - utils::VectorRef selectors, + utils::VectorRef selectors, const BlockStatement* body); /// Move constructor CaseStatement(CaseStatement&&); ~CaseStatement() override; - /// @returns true if this is a default statement - bool IsDefault() const { return selectors.IsEmpty(); } - /// Clones this node and all transitive child nodes using the `CloneContext` /// `ctx`. /// @param ctx the clone context /// @return the newly cloned node const CaseStatement* Clone(CloneContext* ctx) const override; + /// @returns true if this item contains a default selector + bool ContainsDefault() const; + /// The case selectors, empty if none set - const utils::Vector selectors; + const utils::Vector selectors; /// The case body const BlockStatement* const body; diff --git a/src/tint/ast/case_statement_test.cc b/src/tint/ast/case_statement_test.cc index dc3c88ad03..04b887bda2 100644 --- a/src/tint/ast/case_statement_test.cc +++ b/src/tint/ast/case_statement_test.cc @@ -27,7 +27,7 @@ namespace { using CaseStatementTest = TestHelper; TEST_F(CaseStatementTest, Creation_i32) { - auto* selector = Expr(2_i); + auto* selector = CaseSelector(2_i); utils::Vector b{selector}; auto* discard = create(); @@ -41,7 +41,7 @@ TEST_F(CaseStatementTest, Creation_i32) { } TEST_F(CaseStatementTest, Creation_u32) { - auto* selector = Expr(2_u); + auto* selector = CaseSelector(2_u); utils::Vector b{selector}; auto* discard = create(); @@ -54,8 +54,20 @@ TEST_F(CaseStatementTest, Creation_u32) { EXPECT_EQ(c->body->statements[0], discard); } +TEST_F(CaseStatementTest, ContainsDefault_WithDefault) { + utils::Vector b{CaseSelector(2_u), DefaultCaseSelector()}; + auto* c = create(b, create(utils::Empty)); + EXPECT_TRUE(c->ContainsDefault()); +} + +TEST_F(CaseStatementTest, ContainsDefault_WithOutDefault) { + utils::Vector b{CaseSelector(2_u), CaseSelector(3_u)}; + auto* c = create(b, create(utils::Empty)); + EXPECT_FALSE(c->ContainsDefault()); +} + TEST_F(CaseStatementTest, Creation_WithSource) { - utils::Vector b{Expr(2_i)}; + utils::Vector b{CaseSelector(2_i)}; auto* body = create(utils::Vector{ create(), @@ -66,22 +78,9 @@ TEST_F(CaseStatementTest, Creation_WithSource) { EXPECT_EQ(src.range.begin.column, 2u); } -TEST_F(CaseStatementTest, IsDefault_WithoutSelectors) { - auto* body = create(utils::Vector{ - create(), - }); - auto* c = create(utils::Empty, body); - EXPECT_TRUE(c->IsDefault()); -} - -TEST_F(CaseStatementTest, IsDefault_WithSelectors) { - utils::Vector b{Expr(2_i)}; - auto* c = create(b, create(utils::Empty)); - EXPECT_FALSE(c->IsDefault()); -} - TEST_F(CaseStatementTest, IsCase) { - auto* c = create(utils::Empty, create(utils::Empty)); + auto* c = create(utils::Vector{DefaultCaseSelector()}, + create(utils::Empty)); EXPECT_TRUE(c->Is()); } @@ -89,7 +88,7 @@ TEST_F(CaseStatementTest, Assert_Null_Body) { EXPECT_FATAL_FAILURE( { ProgramBuilder b; - b.create(utils::Empty, nullptr); + b.create(utils::Vector{b.DefaultCaseSelector()}, nullptr); }, "internal compiler error"); } @@ -98,7 +97,7 @@ TEST_F(CaseStatementTest, Assert_Null_Selector) { EXPECT_FATAL_FAILURE( { ProgramBuilder b; - b.create(utils::Vector{nullptr}, + b.create(utils::Vector{nullptr}, b.create(utils::Empty)); }, "internal compiler error"); @@ -109,7 +108,8 @@ TEST_F(CaseStatementTest, Assert_DifferentProgramID_Call) { { ProgramBuilder b1; ProgramBuilder b2; - b1.create(utils::Empty, b2.create(utils::Empty)); + b1.create(utils::Vector{b1.DefaultCaseSelector()}, + b2.create(utils::Empty)); }, "internal compiler error"); } @@ -119,7 +119,7 @@ TEST_F(CaseStatementTest, Assert_DifferentProgramID_Selector) { { ProgramBuilder b1; ProgramBuilder b2; - b1.create(utils::Vector{b2.Expr(2_i)}, + b1.create(utils::Vector{b2.CaseSelector(b2.Expr(2_i))}, b1.create(utils::Empty)); }, "internal compiler error"); diff --git a/src/tint/ast/switch_statement_test.cc b/src/tint/ast/switch_statement_test.cc index 0f66c6113e..00c515ea3d 100644 --- a/src/tint/ast/switch_statement_test.cc +++ b/src/tint/ast/switch_statement_test.cc @@ -25,7 +25,7 @@ namespace { using SwitchStatementTest = TestHelper; TEST_F(SwitchStatementTest, Creation) { - auto* case_stmt = create(utils::Vector{Expr(1_u)}, Block()); + auto* case_stmt = create(utils::Vector{CaseSelector(1_u)}, Block()); auto* ident = Expr("ident"); utils::Vector body{case_stmt}; @@ -44,7 +44,7 @@ TEST_F(SwitchStatementTest, Creation_WithSource) { } TEST_F(SwitchStatementTest, IsSwitch) { - utils::Vector lit{Expr(2_i)}; + utils::Vector lit{CaseSelector(2_i)}; auto* ident = Expr("ident"); utils::Vector body{create(lit, Block())}; @@ -58,7 +58,8 @@ TEST_F(SwitchStatementTest, Assert_Null_Condition) { { ProgramBuilder b; CaseStatementList cases; - cases.Push(b.create(utils::Vector{b.Expr(1_i)}, b.Block())); + cases.Push( + b.create(utils::Vector{b.CaseSelector(b.Expr(1_i))}, b.Block())); b.create(nullptr, cases); }, "internal compiler error"); @@ -82,7 +83,7 @@ TEST_F(SwitchStatementTest, Assert_DifferentProgramID_Condition) { b1.create(b2.Expr(true), utils::Vector{ b1.create( utils::Vector{ - b1.Expr(1_i), + b1.CaseSelector(b1.Expr(1_i)), }, b1.Block()), }); @@ -98,7 +99,7 @@ TEST_F(SwitchStatementTest, Assert_DifferentProgramID_CaseStatement) { b1.create(b1.Expr(true), utils::Vector{ b2.create( utils::Vector{ - b2.Expr(1_i), + b2.CaseSelector(b2.Expr(1_i)), }, b2.Block()), }); diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 892cf27a8b..f07a1f84a4 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -2847,7 +2847,7 @@ class ProgramBuilder { /// @param body the case body /// @returns the case statement pointer const ast::CaseStatement* Case(const Source& source, - utils::VectorRef selectors, + utils::VectorRef selectors, const ast::BlockStatement* body = nullptr) { return create(source, std::move(selectors), body ? body : Block()); } @@ -2856,7 +2856,7 @@ class ProgramBuilder { /// @param selectors list of selectors /// @param body the case body /// @returns the case statement pointer - const ast::CaseStatement* Case(utils::VectorRef selectors, + const ast::CaseStatement* Case(utils::VectorRef selectors, const ast::BlockStatement* body = nullptr) { return create(std::move(selectors), body ? body : Block()); } @@ -2865,7 +2865,7 @@ class ProgramBuilder { /// @param selector a single case selector /// @param body the case body /// @returns the case statement pointer - const ast::CaseStatement* Case(const ast::Expression* selector, + const ast::CaseStatement* Case(const ast::CaseSelector* selector, const ast::BlockStatement* body = nullptr) { return Case(utils::Vector{selector}, body); } @@ -2876,16 +2876,44 @@ class ProgramBuilder { /// @returns the case statement pointer const ast::CaseStatement* DefaultCase(const Source& source, const ast::BlockStatement* body = nullptr) { - return Case(source, utils::Empty, body); + return Case(source, utils::Vector{DefaultCaseSelector(source)}, body); } /// Convenience function that creates a 'default' ast::CaseStatement /// @param body the case body /// @returns the case statement pointer const ast::CaseStatement* DefaultCase(const ast::BlockStatement* body = nullptr) { - return Case(utils::Empty, body); + return Case(utils::Vector{DefaultCaseSelector()}, body); } + /// Convenience function that creates a case selector + /// @param source the source information + /// @param expr the selector expression + /// @returns the selector pointer + template + const ast::CaseSelector* CaseSelector(const Source& source, EXPR&& expr) { + return create(source, Expr(std::forward(expr))); + } + + /// Convenience function that creates a case selector + /// @param expr the selector expression + /// @returns the selector pointer + template + const ast::CaseSelector* CaseSelector(EXPR&& expr) { + return create(source_, Expr(std::forward(expr))); + } + + /// Convenience function that creates a default case selector + /// @param source the source information + /// @returns the selector pointer + const ast::CaseSelector* DefaultCaseSelector(const Source& source) { + return create(source, nullptr); + } + + /// Convenience function that creates a default case selector + /// @returns the selector pointer + const ast::CaseSelector* DefaultCaseSelector() { return create(nullptr); } + /// Creates an ast::FallthroughStatement /// @param source the source information /// @returns the fallthrough statement pointer diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index 2cc35d81ee..0856efff77 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -3024,7 +3024,7 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) { for (size_t i = last_clause_index;; --i) { // Create a list of integer literals for the selector values leading to // this case clause. - utils::Vector selectors; + utils::Vector selectors; const bool has_selectors = clause_heads[i]->case_values.has_value(); if (has_selectors) { auto values = clause_heads[i]->case_values.value(); @@ -3034,15 +3034,26 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) { // The Tint AST handles 32-bit values. const uint32_t value32 = uint32_t(value & 0xFFFFFFFF); if (selector.type->IsUnsignedScalarOrVector()) { - selectors.Push(create( - Source{}, value32, ast::IntLiteralExpression::Suffix::kU)); + selectors.Push(create( + Source{}, create( + Source{}, value32, ast::IntLiteralExpression::Suffix::kU))); } else { - selectors.Push( + selectors.Push(create( + Source{}, create(Source{}, static_cast(value32), - ast::IntLiteralExpression::Suffix::kI)); + ast::IntLiteralExpression::Suffix::kI))); } } + + if ((default_info == clause_heads[i]) && construct->ContainsPos(default_info->pos)) { + // Generate a default selector + selectors.Push(create(Source{})); + } + } else { + // Generate a default selector + selectors.Push(create(Source{})); } + TINT_ASSERT(Reader, !selectors.IsEmpty()); // Where does this clause end? const auto end_id = @@ -3057,17 +3068,6 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) { swch->cases[case_idx] = create(Source{}, selectors, body); }); - if ((default_info == clause_heads[i]) && has_selectors && - construct->ContainsPos(default_info->pos)) { - // Generate a default clause with a just fallthrough. - auto* stmts = create( - Source{}, StatementList{ - create(Source{}), - }); - auto* case_stmt = create(Source{}, utils::Empty, stmts); - swch->cases.Push(case_stmt); - } - if (i == 0) { break; } diff --git a/src/tint/reader/spirv/function_cfg_test.cc b/src/tint/reader/spirv/function_cfg_test.cc index 8356e37710..538657b495 100644 --- a/src/tint/reader/spirv/function_cfg_test.cc +++ b/src/tint/reader/spirv/function_cfg_test.cc @@ -9349,10 +9349,7 @@ switch(42u) { case 20u: { var_1 = 20u; } - default: { - fallthrough; - } - case 30u: { + case 30u, default: { var_1 = 30u; } } diff --git a/src/tint/reader/spirv/function_var_test.cc b/src/tint/reader/spirv/function_var_test.cc index 5d156eb697..e3d2acfb75 100644 --- a/src/tint/reader/spirv/function_var_test.cc +++ b/src/tint/reader/spirv/function_var_test.cc @@ -1393,10 +1393,7 @@ TEST_F(SpvParserFunctionVarTest, EmitStatement_Phi_InMerge_PredecessorsDominatdB auto got = test::ToString(p->program(), ast_body); auto* expect = R"(var x_41 : u32; switch(1u) { - default: { - fallthrough; - } - case 0u: { + case 0u, default: { fallthrough; } case 1u: { diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index cb8e217465..f962482675 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc @@ -2129,6 +2129,9 @@ Maybe ParserImpl::switch_body() { } selector_list = std::move(selectors.value); + } else { + // Push the default case selector + selector_list.Push(create(t.source())); } // Consume the optional colon if present. @@ -2148,12 +2151,12 @@ Maybe ParserImpl::switch_body() { } // case_selectors -// : expression (COMMA expression)* COMMA? +// : case_selector (COMMA case_selector)* COMMA? Expect ParserImpl::expect_case_selectors() { CaseSelectorList selectors; while (continue_parsing()) { - auto expr = expression(); + auto expr = case_selector(); if (expr.errored) { return Failure::kErrored; } @@ -2168,12 +2171,32 @@ Expect ParserImpl::expect_case_selectors() { } if (selectors.IsEmpty()) { - return add_error(peek(), "unable to parse case selectors"); + return add_error(peek(), "expected case selector expression or `default`"); } return selectors; } +// case_selector +// : DEFAULT +// | expression +Maybe ParserImpl::case_selector() { + auto& p = peek(); + + if (match(Token::Type::kDefault)) { + return create(p.source()); + } + + auto expr = expression(); + if (expr.errored) { + return Failure::kErrored; + } + if (!expr.matched) { + return Failure::kNoMatch; + } + return create(p.source(), expr.value); +} + // case_body // : // | statement case_body diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h index 3dbd7805ae..691cc5e196 100644 --- a/src/tint/reader/wgsl/parser_impl.h +++ b/src/tint/reader/wgsl/parser_impl.h @@ -74,7 +74,7 @@ class ParserImpl { /// Pre-determined small vector sizes for AST pointers //! @cond Doxygen_Suppress using AttributeList = utils::Vector; - using CaseSelectorList = utils::Vector; + using CaseSelectorList = utils::Vector; using CaseStatementList = utils::Vector; using ExpressionList = utils::Vector; using ParameterList = utils::Vector; @@ -573,6 +573,9 @@ class ParserImpl { /// Parses a `case_selectors` grammar element /// @returns the list of literals Expect expect_case_selectors(); + /// Parses a `case_selector` grammar element + /// @returns the selector + Maybe case_selector(); /// Parses a `case_body` grammar element /// @returns the parsed statements Maybe case_body(); diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc index 657b522d8c..1a8885660f 100644 --- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc +++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc @@ -1333,7 +1333,7 @@ fn f() { switch(1) { TEST_F(ParserImplErrorTest, SwitchStmtInvalidCase) { EXPECT("fn f() { switch(1) { case ^: } }", - R"(test.wgsl:1:27 error: unable to parse case selectors + R"(test.wgsl:1:27 error: expected case selector expression or `default` fn f() { switch(1) { case ^: } } ^ )"); diff --git a/src/tint/reader/wgsl/parser_impl_statement_test.cc b/src/tint/reader/wgsl/parser_impl_statement_test.cc index 7bb51d3d37..f78e94400f 100644 --- a/src/tint/reader/wgsl/parser_impl_statement_test.cc +++ b/src/tint/reader/wgsl/parser_impl_statement_test.cc @@ -143,7 +143,7 @@ TEST_F(ParserImplTest, Statement_Switch_Invalid) { EXPECT_TRUE(e.errored); EXPECT_FALSE(e.matched); EXPECT_EQ(e.value, nullptr); - EXPECT_EQ(p->error(), "1:18: unable to parse case selectors"); + EXPECT_EQ(p->error(), "1:18: expected case selector expression or `default`"); } TEST_F(ParserImplTest, Statement_Loop) { diff --git a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc index 2b8b0bd70d..61ec524089 100644 --- a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc +++ b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc @@ -25,13 +25,16 @@ TEST_F(ParserImplTest, SwitchBody_Case) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 1u); - ASSERT_TRUE(stmt->selectors[0]->Is()); - auto* expr = stmt->selectors[0]->As(); + auto* sel = stmt->selectors[0]; + EXPECT_FALSE(sel->IsDefault()); + ASSERT_TRUE(sel->expr->Is()); + + auto* expr = sel->expr->As(); EXPECT_EQ(expr->value, 1); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); ASSERT_EQ(e->body->statements.Length(), 1u); @@ -46,12 +49,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_Expression) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 1u); - ASSERT_TRUE(stmt->selectors[0]->Is()); - auto* expr = stmt->selectors[0]->As(); + + auto* sel = stmt->selectors[0]; + EXPECT_FALSE(sel->IsDefault()); + + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); EXPECT_EQ(ast::BinaryOp::kAdd, expr->op); auto* v = expr->lhs->As(); @@ -74,13 +81,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_WithColon) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 1u); - ASSERT_TRUE(stmt->selectors[0]->Is()); + auto* sel = stmt->selectors[0]; + EXPECT_FALSE(sel->IsDefault()); + + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); - auto* expr = stmt->selectors[0]->As(); EXPECT_EQ(expr->value, 1); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); ASSERT_EQ(e->body->statements.Length(), 1u); @@ -95,17 +105,20 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); + auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 2u); - ASSERT_TRUE(stmt->selectors[0]->Is()); + auto* sel = stmt->selectors[0]; - auto* expr = stmt->selectors[0]->As(); + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); EXPECT_EQ(expr->value, 1); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); - ASSERT_TRUE(stmt->selectors[1]->Is()); - expr = stmt->selectors[1]->As(); + sel = stmt->selectors[1]; + ASSERT_TRUE(sel->expr->Is()); + expr = sel->expr->As(); EXPECT_EQ(expr->value, 2); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } @@ -118,18 +131,20 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 2u); - ASSERT_TRUE(stmt->selectors[0]->Is()); + auto* sel = stmt->selectors[0]; - auto* expr = stmt->selectors[0]->As(); + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); EXPECT_EQ(expr->value, 1); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); - ASSERT_TRUE(stmt->selectors[1]->Is()); - expr = stmt->selectors[1]->As(); + sel = stmt->selectors[1]; + ASSERT_TRUE(sel->expr->Is()); + expr = sel->expr->As(); EXPECT_EQ(expr->value, 2); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } @@ -141,7 +156,7 @@ TEST_F(ParserImplTest, SwitchBody_Case_Invalid) { EXPECT_TRUE(e.errored); EXPECT_FALSE(e.matched); EXPECT_EQ(e.value, nullptr); - EXPECT_EQ(p->error(), "1:6: unable to parse case selectors"); + EXPECT_EQ(p->error(), "1:6: expected case selector expression or `default`"); } TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) { @@ -151,7 +166,7 @@ TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) { EXPECT_TRUE(e.errored); EXPECT_FALSE(e.matched); EXPECT_EQ(e.value, nullptr); - EXPECT_EQ(p->error(), "1:5: unable to parse case selectors"); + EXPECT_EQ(p->error(), "1:5: expected case selector expression or `default`"); } TEST_F(ParserImplTest, SwitchBody_Case_MissingBracketLeft) { @@ -202,17 +217,46 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); ASSERT_EQ(e->body->statements.Length(), 0u); ASSERT_EQ(e->selectors.Length(), 2u); - ASSERT_TRUE(e->selectors[0]->Is()); - auto* expr = e->selectors[0]->As(); + auto* sel = e->selectors[0]; + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); ASSERT_EQ(expr->value, 1); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); - ASSERT_TRUE(e->selectors[1]->Is()); - expr = e->selectors[1]->As(); + sel = e->selectors[1]; + ASSERT_TRUE(sel->expr->Is()); + expr = sel->expr->As(); + ASSERT_EQ(expr->value, 2); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); +} + +TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_with_default) { + auto p = parser("case 1, default, 2 { }"); + auto e = p->switch_body(); + EXPECT_FALSE(p->has_error()) << p->error(); + EXPECT_TRUE(e.matched); + EXPECT_FALSE(e.errored); + ASSERT_NE(e.value, nullptr); + ASSERT_TRUE(e->Is()); + EXPECT_TRUE(e->ContainsDefault()); + ASSERT_EQ(e->body->statements.Length(), 0u); + ASSERT_EQ(e->selectors.Length(), 3u); + + auto* sel = e->selectors[0]; + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); + ASSERT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); + + EXPECT_TRUE(e->selectors[1]->IsDefault()); + + sel = e->selectors[2]; + ASSERT_TRUE(sel->expr->Is()); + expr = sel->expr->As(); ASSERT_EQ(expr->value, 2); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } @@ -225,17 +269,19 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_FALSE(e->IsDefault()); + EXPECT_FALSE(e->ContainsDefault()); ASSERT_EQ(e->body->statements.Length(), 0u); ASSERT_EQ(e->selectors.Length(), 2u); - ASSERT_TRUE(e->selectors[0]->Is()); - auto* expr = e->selectors[0]->As(); + auto* sel = e->selectors[0]; + ASSERT_TRUE(sel->expr->Is()); + auto* expr = sel->expr->As(); ASSERT_EQ(expr->value, 1); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); - ASSERT_TRUE(e->selectors[1]->Is()); - expr = e->selectors[1]->As(); + sel = e->selectors[1]; + ASSERT_TRUE(sel->expr->Is()); + expr = sel->expr->As(); ASSERT_EQ(expr->value, 2); EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } @@ -257,7 +303,7 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectorsStartsWithComma) { EXPECT_TRUE(e.errored); EXPECT_FALSE(e.matched); EXPECT_EQ(e.value, nullptr); - EXPECT_EQ(p->error(), "1:6: unable to parse case selectors"); + EXPECT_EQ(p->error(), "1:6: expected case selector expression or `default`"); } TEST_F(ParserImplTest, SwitchBody_Default) { @@ -268,7 +314,7 @@ TEST_F(ParserImplTest, SwitchBody_Default) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_TRUE(e->IsDefault()); + EXPECT_TRUE(e->ContainsDefault()); ASSERT_EQ(e->body->statements.Length(), 1u); EXPECT_TRUE(e->body->statements[0]->Is()); } @@ -281,7 +327,7 @@ TEST_F(ParserImplTest, SwitchBody_Default_WithColon) { EXPECT_FALSE(e.errored); ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); - EXPECT_TRUE(e->IsDefault()); + EXPECT_TRUE(e->ContainsDefault()); ASSERT_EQ(e->body->statements.Length(), 1u); EXPECT_TRUE(e->body->statements[0]->Is()); } diff --git a/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc index 014d850b36..a5e3dd37aa 100644 --- a/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc +++ b/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc @@ -29,8 +29,8 @@ TEST_F(ParserImplTest, SwitchStmt_WithoutDefault) { ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); ASSERT_EQ(e->body.Length(), 2u); - EXPECT_FALSE(e->body[0]->IsDefault()); - EXPECT_FALSE(e->body[1]->IsDefault()); + EXPECT_FALSE(e->body[0]->ContainsDefault()); + EXPECT_FALSE(e->body[1]->ContainsDefault()); } TEST_F(ParserImplTest, SwitchStmt_Empty) { @@ -58,9 +58,24 @@ TEST_F(ParserImplTest, SwitchStmt_DefaultInMiddle) { ASSERT_TRUE(e->Is()); ASSERT_EQ(e->body.Length(), 3u); - ASSERT_FALSE(e->body[0]->IsDefault()); - ASSERT_TRUE(e->body[1]->IsDefault()); - ASSERT_FALSE(e->body[2]->IsDefault()); + ASSERT_FALSE(e->body[0]->ContainsDefault()); + ASSERT_TRUE(e->body[1]->ContainsDefault()); + ASSERT_FALSE(e->body[2]->ContainsDefault()); +} + +TEST_F(ParserImplTest, SwitchStmt_Default_Mixed) { + auto p = parser(R"(switch a { + case 1, default, 2: {} +})"); + auto e = p->switch_statement(); + EXPECT_TRUE(e.matched); + EXPECT_FALSE(e.errored); + EXPECT_FALSE(p->has_error()) << p->error(); + ASSERT_NE(e.value, nullptr); + ASSERT_TRUE(e->Is()); + + ASSERT_EQ(e->body.Length(), 1u); + ASSERT_TRUE(e->body[0]->ContainsDefault()); } TEST_F(ParserImplTest, SwitchStmt_WithParens) { @@ -123,7 +138,7 @@ TEST_F(ParserImplTest, SwitchStmt_InvalidBody) { EXPECT_TRUE(e.errored); EXPECT_EQ(e.value, nullptr); EXPECT_TRUE(p->has_error()); - EXPECT_EQ(p->error(), "2:7: unable to parse case selectors"); + EXPECT_EQ(p->error(), "2:7: expected case selector expression or `default`"); } } // namespace diff --git a/src/tint/resolver/compound_statement_test.cc b/src/tint/resolver/compound_statement_test.cc index 0a96cedd8a..b5a5ba9fb9 100644 --- a/src/tint/resolver/compound_statement_test.cc +++ b/src/tint/resolver/compound_statement_test.cc @@ -390,8 +390,8 @@ TEST_F(ResolverCompoundStatementTest, Switch) { auto* stmt_a = Ignore(1_i); auto* stmt_b = Ignore(1_i); auto* stmt_c = Ignore(1_i); - auto* swi = Switch(expr, Case(Expr(1_i), Block(stmt_a)), Case(Expr(2_i), Block(stmt_b)), - DefaultCase(Block(stmt_c))); + auto* swi = Switch(expr, Case(CaseSelector(1_i), Block(stmt_a)), + Case(CaseSelector(2_i), Block(stmt_b)), DefaultCase(Block(stmt_c))); WrapInFunction(swi); ASSERT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/tint/resolver/control_block_validation_test.cc b/src/tint/resolver/control_block_validation_test.cc index 403d0bc4ea..8cbf506969 100644 --- a/src/tint/resolver/control_block_validation_test.cc +++ b/src/tint/resolver/control_block_validation_test.cc @@ -70,7 +70,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchWithoutDefault_Fail) { auto* block = Block(Decl(var), // Switch(Source{{12, 34}}, "a", // - Case(Expr(1_i)))); + Case(CaseSelector(1_i)))); WrapInFunction(block); @@ -87,16 +87,79 @@ TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_Fail) { // } auto* var = Var("a", ty.i32(), Expr(2_i)); - auto* block = Block(Decl(var), // - Switch("a", // - DefaultCase(), // - Case(Expr(1_i)), // + auto* block = Block(Decl(var), // + Switch("a", // + DefaultCase(Source{{9, 2}}), // + Case(CaseSelector(1_i)), // DefaultCase(Source{{12, 34}}))); WrapInFunction(block); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: switch statement must have exactly one default clause"); + EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause +9:2 note: previous default case)"); +} + +TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_OneInCase_Fail) { + // var a : i32 = 2; + // switch (a) { + // case 1, default: {} + // default: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = Block( + Decl(var), // + Switch("a", // + Case(utils::Vector{CaseSelector(1_i), DefaultCaseSelector(Source{{9, 2}})}), // + DefaultCase(Source{{12, 34}}))); + + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause +9:2 note: previous default case)"); +} + +TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_SameCase) { + // var a : i32 = 2; + // switch (a) { + // case default, 1, default: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = + Block(Decl(var), // + Switch("a", // + Case(utils::Vector{DefaultCaseSelector(Source{{9, 2}}), CaseSelector(1_i), + DefaultCaseSelector(Source{{12, 34}})}))); + + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause +9:2 note: previous default case)"); +} + +TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_DifferentMultiCase) { + // var a : i32 = 2; + // switch (a) { + // case 1, default: {} + // case default, 2: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = Block( + Decl(var), // + Switch("a", // + Case(utils::Vector{CaseSelector(1_i), DefaultCaseSelector(Source{{9, 2}})}), + Case(utils::Vector{DefaultCaseSelector(Source{{12, 34}}), CaseSelector(2_i)}))); + + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause +9:2 note: previous default case)"); } TEST_F(ResolverControlBlockValidationTest, UnreachableCode_Loop_continue) { @@ -187,9 +250,9 @@ TEST_F(ResolverControlBlockValidationTest, UnreachableCode_break) { auto* decl_z = Decl(Var("z", ty.i32())); auto* brk = Break(); auto* assign_z = Assign(Source{{12, 34}}, "z", 1_i); - WrapInFunction( // - Block(Switch(1_i, // - Case(Expr(1_i), Block(decl_z, brk, assign_z)), // + WrapInFunction( // + Block(Switch(1_i, // + Case(CaseSelector(1_i), Block(decl_z, brk, assign_z)), // DefaultCase()))); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -210,11 +273,11 @@ TEST_F(ResolverControlBlockValidationTest, UnreachableCode_break_InBlocks) { auto* decl_z = Decl(Var("z", ty.i32())); auto* brk = Break(); auto* assign_z = Assign(Source{{12, 34}}, "z", 1_i); - WrapInFunction( - Loop(Block(Switch(1_i, // - Case(Expr(1_i), Block(decl_z, Block(Block(Block(brk))), assign_z)), - DefaultCase()), // - Break()))); + WrapInFunction(Loop( + Block(Switch(1_i, // + Case(CaseSelector(1_i), Block(decl_z, Block(Block(Block(brk))), assign_z)), + DefaultCase()), // + Break()))); ASSERT_TRUE(r()->Resolve()) << r()->error(); EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable"); @@ -231,8 +294,8 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT // } auto* var = Var("a", ty.i32(), Expr(2_i)); - auto* block = Block(Decl(var), Switch("a", // - Case(Expr(Source{{12, 34}}, 1_u)), // + auto* block = Block(Decl(var), Switch("a", // + Case(CaseSelector(Source{{12, 34}}, 1_u)), // DefaultCase())); WrapInFunction(block); @@ -250,9 +313,9 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT // } auto* var = Var("a", ty.u32(), Expr(2_u)); - auto* block = Block(Decl(var), // - Switch("a", // - Case(utils::Vector{Expr(Source{{12, 34}}, -1_i)}), // + auto* block = Block(Decl(var), // + Switch("a", // + Case(CaseSelector(Source{{12, 34}}, -1_i)), // DefaultCase())); WrapInFunction(block); @@ -273,11 +336,11 @@ TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorValueUint_Fail) auto* block = Block(Decl(var), // Switch("a", // - Case(Expr(0_u)), + Case(CaseSelector(0_u)), Case(utils::Vector{ - Expr(Source{{12, 34}}, 2_u), - Expr(3_u), - Expr(Source{{56, 78}}, 2_u), + CaseSelector(Source{{12, 34}}, 2_u), + CaseSelector(3_u), + CaseSelector(Source{{56, 78}}, 2_u), }), DefaultCase())); WrapInFunction(block); @@ -299,12 +362,12 @@ TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorValueSint_Fail) auto* block = Block(Decl(var), // Switch("a", // - Case(Expr(Source{{12, 34}}, -10_i)), + Case(CaseSelector(Source{{12, 34}}, -10_i)), Case(utils::Vector{ - Expr(0_i), - Expr(1_i), - Expr(2_i), - Expr(Source{{56, 78}}, -10_i), + CaseSelector(0_i), + CaseSelector(1_i), + CaseSelector(2_i), + CaseSelector(Source{{56, 78}}, -10_i), }), DefaultCase())); WrapInFunction(block); @@ -344,7 +407,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) { auto* block = Block(Decl(var), // Switch("a", // DefaultCase(Source{{12, 34}}), // - Case(Expr(5_i)))); + Case(CaseSelector(5_i)))); WrapInFunction(block); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -361,7 +424,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Pass) { auto* block = Block(Decl(var), // Switch("a", // DefaultCase(Source{{12, 34}}), // - Case(Add(5_i, 6_i)))); + Case(CaseSelector(Add(5_i, 6_i))))); WrapInFunction(block); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -378,7 +441,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixI32_Abstract auto* block = Block(Decl(var), // Switch("a", // DefaultCase(Source{{12, 34}}), // - Case(Add(5_i, 6_i)))); + Case(CaseSelector(Add(5_i, 6_i))))); WrapInFunction(block); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -395,7 +458,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixU32_Abstract auto* block = Block(Decl(var), // Switch("a", // DefaultCase(Source{{12, 34}}), // - Case(Add(5_a, 6_a)))); + Case(CaseSelector(Add(5_a, 6_a))))); WrapInFunction(block); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -409,10 +472,12 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Multiple) { // } auto* var = Var("a", Expr(2_u)); - auto* block = Block(Decl(var), // - Switch("a", // - DefaultCase(Source{{12, 34}}), // - Case(utils::Vector{Add(5_u, 6_u), Add(7_u, 9_u), Mul(2_u, 4_u)}))); + auto* block = + Block(Decl(var), // + Switch("a", // + DefaultCase(Source{{12, 34}}), // + Case(utils::Vector{CaseSelector(Add(5_u, 6_u)), CaseSelector(Add(7_u, 9_u)), + CaseSelector(Mul(2_u, 4_u))}))); WrapInFunction(block); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -446,8 +511,8 @@ TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelector_Expression_Fail auto* block = Block(Decl(var), // Switch("a", // - Case(Expr(Source{{12, 34}}, 10_i)), - Case(Add(Source{{56, 78}}, 5_i, 5_i)), DefaultCase())); + Case(CaseSelector(Source{{12, 34}}, 10_i)), + Case(CaseSelector(Source{{56, 78}}, Add(5_i, 5_i))), DefaultCase())); WrapInFunction(block); EXPECT_FALSE(r()->Resolve()); @@ -466,8 +531,8 @@ TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSameCase_BothExp auto* block = Block(Decl(var), // Switch("a", // - Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), - Add(Source{{12, 34}}, 6_i, 4_i)}), + Case(utils::Vector{CaseSelector(Source{{56, 78}}, Add(5_i, 5_i)), + CaseSelector(Source{{12, 34}}, Add(6_i, 4_i))}), DefaultCase())); WrapInFunction(block); @@ -485,11 +550,11 @@ TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSame_Case_Expres // } auto* var = Var("a", ty.i32(), Expr(2_i)); - auto* block = Block( - Decl(var), // - Switch("a", // - Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), Expr(Source{{12, 34}}, 10_i)}), - DefaultCase())); + auto* block = Block(Decl(var), // + Switch("a", // + Case(utils::Vector{CaseSelector(Source{{56, 78}}, Add(5_i, 5_i)), + CaseSelector(Source{{12, 34}}, 10_i)}), + DefaultCase())); WrapInFunction(block); EXPECT_FALSE(r()->Resolve()); @@ -508,7 +573,7 @@ TEST_F(ResolverControlBlockValidationTest, Switch_OverrideCondition_Fail) { auto* block = Block(Decl(var), // Switch("a", // - Case(Expr(Source{{12, 34}}, "b")), DefaultCase())); + Case(CaseSelector(Source{{12, 34}}, "b")), DefaultCase())); WrapInFunction(block); EXPECT_FALSE(r()->Resolve()); diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index e84eec5679..edc111c420 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -299,7 +299,7 @@ class DependencyScanner { TraverseExpression(s->condition); for (auto* c : s->body) { for (auto* sel : c->selectors) { - TraverseExpression(sel); + TraverseExpression(sel->expr); } TraverseStatement(c->body); } diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc index 724a527457..a984c20a62 100644 --- a/src/tint/resolver/dependency_graph_test.cc +++ b/src/tint/resolver/dependency_graph_test.cc @@ -1262,9 +1262,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { Loop(Block(Assign(V, V)), // Block(Assign(V, V))), // Switch(V, // - Case(Expr(1_i), // + Case(CaseSelector(1_i), // Block(Assign(V, V))), // - Case(Expr(2_i), // + Case(CaseSelector(2_i), // Block(Fallthrough())), // DefaultCase(Block(Assign(V, V)))), // Return(V), // diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index 251fe82fbe..1b96a277be 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -356,26 +356,30 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) { WrapInFunction(Add(target_expr(), abstract_expr)); break; case Method::kSwitchCond: - WrapInFunction(Switch(abstract_expr, // - Case(target_expr()->As()), // - DefaultCase())); + WrapInFunction( + Switch(abstract_expr, // + Case(CaseSelector(target_expr()->As())), // + DefaultCase())); break; case Method::kSwitchCase: - WrapInFunction(Switch(target_expr(), // - Case(abstract_expr->As()), // - DefaultCase())); + WrapInFunction( + Switch(target_expr(), // + Case(CaseSelector(abstract_expr->As())), // + DefaultCase())); break; case Method::kSwitchCondWithAbstractCase: - WrapInFunction(Switch(abstract_expr, // - Case(Expr(123_a)), // - Case(target_expr()->As()), // - DefaultCase())); + WrapInFunction( + Switch(abstract_expr, // + Case(CaseSelector(123_a)), // + Case(CaseSelector(target_expr()->As())), // + DefaultCase())); break; case Method::kSwitchCaseWithAbstractCase: - WrapInFunction(Switch(target_expr(), // - Case(Expr(123_a)), // - Case(abstract_expr->As()), // - DefaultCase())); + WrapInFunction( + Switch(target_expr(), // + Case(CaseSelector(123_a)), // + Case(CaseSelector(abstract_expr->As())), // + DefaultCase())); break; case Method::kWorkgroupSize: Func("f", utils::Empty, ty.void_(), utils::Empty, @@ -903,9 +907,10 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) { break; } case Method::kSwitch: { - WrapInFunction(Switch(abstract_expr(), - Case(abstract_expr()->As()), - DefaultCase())); + WrapInFunction( + Switch(abstract_expr(), + Case(CaseSelector(abstract_expr()->As())), + DefaultCase())); break; } case Method::kWorkgroupSize: { diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index b2961bdfb4..21501c58f9 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1240,26 +1240,31 @@ sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt, cons return StatementScope(stmt, sem, [&] { sem->Selectors().reserve(stmt->selectors.Length()); for (auto* sel : stmt->selectors) { + Mark(sel); + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"}; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - // The sem statement is created in the switch when attempting to determine the common - // type. - auto* materialized = Materialize(sem_.Get(sel), ty); - if (!materialized) { - return false; - } - if (!materialized->Type()->IsAnyOf()) { - AddError("case selector must be an i32 or u32 value", sel->source); - return false; - } - auto const_value = materialized->ConstantValue(); - if (!const_value) { - AddError("case selector must be a constant expression", sel->source); - return false; + const sem::Constant* const_value = nullptr; + if (!sel->IsDefault()) { + // The sem statement was created in the switch when attempting to determine the + // common type. + auto* materialized = Materialize(sem_.Get(sel->expr), ty); + if (!materialized) { + return false; + } + if (!materialized->Type()->IsAnyOf()) { + AddError("case selector must be an i32 or u32 value", sel->source); + return false; + } + const_value = materialized->ConstantValue(); + if (!const_value) { + AddError("case selector must be a constant expression", sel->source); + return false; + } } - sem->Selectors().emplace_back(const_value); + sem->Selectors().emplace_back(builder_->create(sel, const_value)); } Mark(stmt->body); @@ -3103,8 +3108,11 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt utils::Vector types; types.Push(cond_ty); for (auto* case_stmt : stmt->body) { - for (auto* expr : case_stmt->selectors) { - auto* sem_expr = Expression(expr); + for (auto* sel : case_stmt->selectors) { + if (sel->IsDefault()) { + continue; + } + auto* sem_expr = Expression(sel->expr); types.Push(sem_expr->Type()->UnwrapRef()); } } @@ -3127,9 +3135,6 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt if (!c) { return false; } - for (auto* expr : c->Selectors()) { - types.Push(expr->Type()->UnwrapRef()); - } cases.Push(c); behaviors.Add(c->Behaviors()); sem->Cases().emplace_back(c); diff --git a/src/tint/resolver/resolver_behavior_test.cc b/src/tint/resolver/resolver_behavior_test.cc index 5002857045..cf915016c3 100644 --- a/src/tint/resolver/resolver_behavior_test.cc +++ b/src/tint/resolver/resolver_behavior_test.cc @@ -633,7 +633,7 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultReturn) { } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) { - auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block())); + auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block()), DefaultCase(Block())); WrapInFunction(stmt); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -643,7 +643,7 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) { } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) { - auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Discard()))); + auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block()), DefaultCase(Block(Discard()))); Func("F", utils::Empty, ty.void_(), utils::Vector{stmt}, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -655,7 +655,7 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) { } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) { - auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Return()))); + auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block()), DefaultCase(Block(Return()))); WrapInFunction(stmt); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -665,7 +665,7 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) { } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) { - auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block())); + auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block(Discard())), DefaultCase(Block())); Func("F", utils::Empty, ty.void_(), utils::Vector{stmt}, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -677,7 +677,8 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) { } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) { - auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Discard()))); + auto* stmt = + Switch(1_i, Case(CaseSelector(0_i), Block(Discard())), DefaultCase(Block(Discard()))); Func("F", utils::Empty, ty.void_(), utils::Vector{stmt}, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -689,7 +690,8 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) { - auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Return()))); + auto* stmt = + Switch(1_i, Case(CaseSelector(0_i), Block(Discard())), DefaultCase(Block(Return()))); Func("F", utils::Empty, ty.void_(), utils::Vector{stmt}, utils::Vector{Stage(ast::PipelineStage::kFragment)}); @@ -701,9 +703,9 @@ TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) } TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) { - auto* stmt = Switch(1_i, // - Case(Expr(0_i), Block(Discard())), // - Case(Expr(1_i), Block(Return())), // + auto* stmt = Switch(1_i, // + Case(CaseSelector(0_i), Block(Discard())), // + Case(CaseSelector(1_i), Block(Return())), // DefaultCase(Block())); Func("F", utils::Empty, ty.void_(), utils::Vector{stmt}, diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index 4538bf4894..0af3894b0e 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -112,7 +112,7 @@ TEST_F(ResolverTest, Stmt_Case) { auto* assign = Assign(lhs, rhs); auto* block = Block(assign); - auto* sel = Expr(3_i); + auto* sel = CaseSelector(3_i); auto* cse = Case(sel, block); auto* def = DefaultCase(); auto* cond_var = Var("c", ty.i32()); @@ -132,7 +132,7 @@ TEST_F(ResolverTest, Stmt_Case) { ASSERT_EQ(sem->Cases().size(), 2u); EXPECT_EQ(sem->Cases()[0]->Declaration(), cse); ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u); - EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 0u); + EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 1u); } TEST_F(ResolverTest, Stmt_Block) { @@ -251,7 +251,7 @@ TEST_F(ResolverTest, Stmt_Switch) { auto* lhs = Expr("v"); auto* rhs = Expr(2.3_f); auto* case_block = Block(Assign(lhs, rhs)); - auto* stmt = Switch(Expr(2_i), Case(Expr(3_i), case_block), DefaultCase()); + auto* stmt = Switch(Expr(2_i), Case(CaseSelector(3_i), case_block), DefaultCase()); WrapInFunction(v, stmt); EXPECT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc index fb210c0b3f..afd9e9154b 100644 --- a/src/tint/resolver/validation_test.cc +++ b/src/tint/resolver/validation_test.cc @@ -1039,11 +1039,11 @@ TEST_F(ResolverValidationTest, Stmt_BreakInLoop) { } TEST_F(ResolverValidationTest, Stmt_BreakInSwitch) { - WrapInFunction(Loop(Block(Switch(Expr(1_i), // - Case(Expr(1_i), // - Block(Break())), // - DefaultCase()), // - Break()))); // + WrapInFunction(Loop(Block(Switch(Expr(1_i), // + Case(CaseSelector(1_i), // + Block(Break())), // + DefaultCase()), // + Break()))); // EXPECT_TRUE(r()->Resolve()) << r()->error(); } diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index f7382743bd..4489543f46 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -2324,54 +2324,50 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) { return false; } - bool has_default = false; + const sem::CaseSelector* default_selector = nullptr; std::unordered_map selectors; for (auto* case_stmt : s->body) { - if (case_stmt->IsDefault()) { - if (has_default) { - // More than one default clause - AddError("switch statement must have exactly one default clause", - case_stmt->source); - return false; - } - has_default = true; - } - auto* case_sem = sem_.Get(case_stmt); + for (auto* selector : case_sem->Selectors()) { + if (selector->IsDefault()) { + if (default_selector != nullptr) { + // More than one default clause + AddError("switch statement must have exactly one default clause", + selector->Declaration()->source); - auto& case_selectors = case_stmt->selectors; - auto& selector_values = case_sem->Selectors(); - TINT_ASSERT(Resolver, case_selectors.Length() == selector_values.size()); - for (size_t i = 0; i < case_sem->Selectors().size(); ++i) { - auto* selector = selector_values[i]; - if (cond_ty != selector->Type()) { + AddNote("previous default case", default_selector->Declaration()->source); + return false; + } + default_selector = selector; + continue; + } + + auto* decl_ty = selector->Value()->Type(); + if (cond_ty != decl_ty) { AddError( "the case selector values must have the same type as the selector expression.", - case_selectors[i]->source); + selector->Declaration()->source); return false; } - auto value = selector->As(); + auto value = selector->Value()->As(); auto it = selectors.find(value); if (it != selectors.end()) { - std::string err = "duplicate switch case '"; - if (selector->Type()->Is()) { - err += std::to_string(selector->As()); - } else { - err += std::to_string(value); - } - err += "'"; - - AddError(err, case_selectors[i]->source); + AddError("duplicate switch case '" + + (decl_ty->IsAnyOf() + ? std::to_string(i32(value)) + : std::to_string(value)) + + "'", + selector->Declaration()->source); AddNote("previous case declared here", it->second); return false; } - selectors.emplace(value, case_selectors[i]->source); + selectors.emplace(value, selector->Declaration()->source); } } - if (!has_default) { + if (default_selector == nullptr) { // No default clause AddError("switch statement must have a default clause", s->source); return false; diff --git a/src/tint/sem/switch_statement.cc b/src/tint/sem/switch_statement.cc index ed3942dce4..5eb2f09102 100644 --- a/src/tint/sem/switch_statement.cc +++ b/src/tint/sem/switch_statement.cc @@ -17,6 +17,7 @@ #include "src/tint/program_builder.h" TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseStatement); +TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseSelector); TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement); namespace tint::sem { @@ -48,4 +49,13 @@ const ast::CaseStatement* CaseStatement::Declaration() const { return static_cast(Base::Declaration()); } +CaseSelector::CaseSelector(const ast::CaseSelector* decl, const Constant* val) + : Base(), decl_(decl), val_(val) {} + +CaseSelector::~CaseSelector() = default; + +const ast::CaseSelector* CaseSelector::Declaration() const { + return decl_; +} + } // namespace tint::sem diff --git a/src/tint/sem/switch_statement.h b/src/tint/sem/switch_statement.h index 7028c052e6..929f8cfacc 100644 --- a/src/tint/sem/switch_statement.h +++ b/src/tint/sem/switch_statement.h @@ -22,10 +22,12 @@ // Forward declarations namespace tint::ast { class CaseStatement; +class CaseSelector; class SwitchStatement; } // namespace tint::ast namespace tint::sem { class CaseStatement; +class CaseSelector; class Constant; class Expression; } // namespace tint::sem @@ -83,14 +85,39 @@ class CaseStatement final : public Castable { const BlockStatement* Body() const { return body_; } /// @returns the selectors for the case - std::vector& Selectors() { return selectors_; } + std::vector& Selectors() { return selectors_; } /// @returns the selectors for the case - const std::vector& Selectors() const { return selectors_; } + const std::vector& Selectors() const { return selectors_; } private: const BlockStatement* body_ = nullptr; - std::vector selectors_; + std::vector selectors_; +}; + +/// Holds semantic information about a switch case selector +class CaseSelector final : public Castable { + public: + /// Constructor + /// @param decl the selector declaration + /// @param val the case selector value, nullptr for a default selector + explicit CaseSelector(const ast::CaseSelector* decl, const Constant* val = nullptr); + + /// Destructor + ~CaseSelector() override; + + /// @returns true if this is a default selector + bool IsDefault() const { return val_ == nullptr; } + + /// @returns the case selector declaration + const ast::CaseSelector* Declaration() const; + + /// @returns the selector constant value, or nullptr if this is the default selector + const Constant* Value() const { return val_; } + + private: + const ast::CaseSelector* const decl_; + const Constant* const val_; }; } // namespace tint::sem diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc index 17f5a5bfe6..8edf8322fc 100644 --- a/src/tint/transform/std140.cc +++ b/src/tint/transform/std140.cc @@ -944,7 +944,7 @@ struct Std140::State { ret_ty = ty; } - auto* case_sel = b.Expr(u32(column_idx)); + auto* case_sel = b.CaseSelector(b.Expr(u32(column_idx))); auto* case_body = b.Block(utils::Vector{b.Return(expr)}); cases.Push(b.Case(case_sel, case_body)); } diff --git a/src/tint/transform/test_helper.h b/src/tint/transform/test_helper.h index 42218a72e6..bc82fe5a4f 100644 --- a/src/tint/transform/test_helper.h +++ b/src/tint/transform/test_helper.h @@ -115,7 +115,12 @@ class TransformTestBase : public BASE { /// @return true if the transform should be run for the given input. template bool ShouldRun(Program&& program, const DataMap& data = {}) { - EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str(); + if (!program.IsValid()) { + ADD_FAILURE() << "ShouldRun() called with invalid program: " + << program.Diagnostics().str(); + return false; + } + const Transform& t = TRANSFORM(); return t.ShouldRun(&program, data); } diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index cf7567818a..e2ac5fe315 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -1687,20 +1687,21 @@ std::string GeneratorImpl::generate_builtin_name(const sem::Builtin* builtin) { } bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) { - if (stmt->IsDefault()) { - line() << "default: {"; - } else { - auto* sem = builder_.Sem().Get(stmt); - for (auto* selector : sem->Selectors()) { - auto out = line(); + auto* sem = builder_.Sem().Get(stmt); + for (auto* selector : sem->Selectors()) { + auto out = line(); + + if (selector->IsDefault()) { + out << "default"; + } else { out << "case "; - if (!EmitConstant(out, selector)) { + if (!EmitConstant(out, selector->Value())) { return false; } - out << ":"; - if (selector == sem->Selectors().back()) { - out << " {"; - } + } + out << ":"; + if (selector == sem->Selectors().back()) { + out << " {"; } } diff --git a/src/tint/writer/glsl/generator_impl_case_test.cc b/src/tint/writer/glsl/generator_impl_case_test.cc index 1f60781086..4f1c6f310b 100644 --- a/src/tint/writer/glsl/generator_impl_case_test.cc +++ b/src/tint/writer/glsl/generator_impl_case_test.cc @@ -23,7 +23,8 @@ namespace { using GlslGeneratorImplTest_Case = TestHelper; TEST_F(GlslGeneratorImplTest_Case, Emit_Case) { - auto* s = Switch(1_i, Case(Expr(5_i), Block(create())), DefaultCase()); + auto* s = + Switch(1_i, Case(CaseSelector(5_i), Block(create())), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -38,7 +39,7 @@ TEST_F(GlslGeneratorImplTest_Case, Emit_Case) { } TEST_F(GlslGeneratorImplTest_Case, Emit_Case_BreaksByDefault) { - auto* s = Switch(1_i, Case(Expr(5_i), Block()), DefaultCase()); + auto* s = Switch(1_i, Case(CaseSelector(5_i), Block()), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -53,8 +54,8 @@ TEST_F(GlslGeneratorImplTest_Case, Emit_Case_BreaksByDefault) { } TEST_F(GlslGeneratorImplTest_Case, Emit_Case_WithFallthrough) { - auto* s = - Switch(1_i, Case(Expr(5_i), Block(create())), DefaultCase()); + auto* s = Switch(1_i, Case(CaseSelector(5_i), Block(create())), + DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -72,8 +73,8 @@ TEST_F(GlslGeneratorImplTest_Case, Emit_Case_MultipleSelectors) { auto* s = Switch(1_i, Case( utils::Vector{ - Expr(5_i), - Expr(6_i), + CaseSelector(5_i), + CaseSelector(6_i), }, Block(create())), DefaultCase()); diff --git a/src/tint/writer/glsl/generator_impl_switch_test.cc b/src/tint/writer/glsl/generator_impl_switch_test.cc index 7a2c7509bc..ada3c0ba55 100644 --- a/src/tint/writer/glsl/generator_impl_switch_test.cc +++ b/src/tint/writer/glsl/generator_impl_switch_test.cc @@ -25,21 +25,13 @@ TEST_F(GlslGeneratorImplTest_Switch, Emit_Switch) { GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); auto* def_body = Block(create()); - auto* def = create(utils::Empty, def_body); - - utils::Vector case_val{Expr(5_i)}; + auto* def = create(utils::Vector{DefaultCaseSelector()}, def_body); auto* case_body = Block(create()); - - auto* case_stmt = create(case_val, case_body); - - utils::Vector body{ - case_stmt, - def, - }; + auto* case_stmt = create(utils::Vector{CaseSelector(5_i)}, case_body); auto* cond = Expr("cond"); - auto* s = create(cond, body); + auto* s = create(cond, utils::Vector{case_stmt, def}); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -58,5 +50,30 @@ TEST_F(GlslGeneratorImplTest_Switch, Emit_Switch) { )"); } +TEST_F(GlslGeneratorImplTest_Switch, Emit_Switch_MixedDefault) { + GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); + + auto* def_body = Block(create()); + auto* def = create(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, + def_body); + + auto* cond = Expr("cond"); + auto* s = create(cond, utils::Vector{def}); + WrapInFunction(s); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(s)) << gen.error(); + EXPECT_EQ(gen.result(), R"( switch(cond) { + case 5: + default: { + break; + } + } +)"); +} + } // namespace } // namespace tint::writer::glsl diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 7bdb8e95f3..6a3c5c4195 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -2562,20 +2562,20 @@ std::string GeneratorImpl::generate_builtin_name(const sem::Builtin* builtin) { bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) { auto* stmt = s->body[case_idx]; - if (stmt->IsDefault()) { - line() << "default: {"; - } else { - auto* sem = builder_.Sem().Get(stmt); - for (auto* selector : sem->Selectors()) { - auto out = line(); + auto* sem = builder_.Sem().Get(stmt); + for (auto* selector : sem->Selectors()) { + auto out = line(); + if (selector->IsDefault()) { + out << "default"; + } else { out << "case "; - if (!EmitConstant(out, selector)) { + if (!EmitConstant(out, selector->Value())) { return false; } - out << ":"; - if (selector == sem->Selectors().back()) { - out << " {"; - } + } + out << ":"; + if (selector == sem->Selectors().back()) { + out << " {"; } } @@ -3652,7 +3652,7 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { } bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) { - TINT_ASSERT(Writer, stmt->body.Length() == 1 && stmt->body[0]->IsDefault()); + TINT_ASSERT(Writer, stmt->body.Length() == 1 && stmt->body[0]->ContainsDefault()); // FXC fails to compile a switch with just a default case, ignoring the // default case body. We work around this here by emitting the default case @@ -3685,7 +3685,8 @@ bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) { bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) { // BUG(crbug.com/tint/1188): work around default-only switches - if (stmt->body.Length() == 1 && stmt->body[0]->IsDefault()) { + if (stmt->body.Length() == 1 && stmt->body[0]->selectors.Length() == 1 && + stmt->body[0]->ContainsDefault()) { return EmitDefaultOnlySwitch(stmt); } diff --git a/src/tint/writer/hlsl/generator_impl_case_test.cc b/src/tint/writer/hlsl/generator_impl_case_test.cc index c55f14b6a1..bfcbcafdb8 100644 --- a/src/tint/writer/hlsl/generator_impl_case_test.cc +++ b/src/tint/writer/hlsl/generator_impl_case_test.cc @@ -23,7 +23,8 @@ namespace { using HlslGeneratorImplTest_Case = TestHelper; TEST_F(HlslGeneratorImplTest_Case, Emit_Case) { - auto* s = Switch(1_i, Case(Expr(5_i), Block(create())), DefaultCase()); + auto* s = + Switch(1_i, Case(CaseSelector(5_i), Block(create())), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -38,7 +39,7 @@ TEST_F(HlslGeneratorImplTest_Case, Emit_Case) { } TEST_F(HlslGeneratorImplTest_Case, Emit_Case_BreaksByDefault) { - auto* s = Switch(1_i, Case(Expr(5_i), Block()), DefaultCase()); + auto* s = Switch(1_i, Case(CaseSelector(5_i), Block()), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -53,9 +54,9 @@ TEST_F(HlslGeneratorImplTest_Case, Emit_Case_BreaksByDefault) { } TEST_F(HlslGeneratorImplTest_Case, Emit_Case_WithFallthrough) { - auto* s = Switch(1_i, // - Case(Expr(4_i), Block(create())), // - Case(Expr(5_i), Block(create())), // + auto* s = Switch(1_i, // + Case(CaseSelector(4_i), Block(create())), // + Case(CaseSelector(5_i), Block(create())), // DefaultCase()); WrapInFunction(s); @@ -75,9 +76,10 @@ TEST_F(HlslGeneratorImplTest_Case, Emit_Case_WithFallthrough) { } TEST_F(HlslGeneratorImplTest_Case, Emit_Case_MultipleSelectors) { - auto* s = - Switch(1_i, Case(utils::Vector{Expr(5_i), Expr(6_i)}, Block(create())), - DefaultCase()); + auto* s = Switch(1_i, + Case(utils::Vector{CaseSelector(5_i), CaseSelector(6_i)}, + Block(create())), + DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); diff --git a/src/tint/writer/hlsl/generator_impl_switch_test.cc b/src/tint/writer/hlsl/generator_impl_switch_test.cc index 24c17a692e..a84e632ca1 100644 --- a/src/tint/writer/hlsl/generator_impl_switch_test.cc +++ b/src/tint/writer/hlsl/generator_impl_switch_test.cc @@ -23,9 +23,9 @@ using HlslGeneratorImplTest_Switch = TestHelper; TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch) { GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); - auto* s = Switch( // - Expr("cond"), // - Case(Expr(5_i), Block(Break())), // + auto* s = Switch( // + Expr("cond"), // + Case(CaseSelector(5_i), Block(Break())), // DefaultCase()); WrapInFunction(s); @@ -45,6 +45,27 @@ TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch) { )"); } +TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_MixedDefault) { + GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); + auto* s = Switch( // + Expr("cond"), // + Case(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, Block(Break()))); + WrapInFunction(s); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(s)) << gen.error(); + EXPECT_EQ(gen.result(), R"( switch(cond) { + case 5: + default: { + break; + } + } +)"); +} + TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase) { GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); GlobalVar("a", ty.i32(), ast::AddressSpace::kPrivate); diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index f765d93960..976734fc98 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -1589,20 +1589,21 @@ std::string GeneratorImpl::generate_builtin_name(const sem::Builtin* builtin) { } bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) { - if (stmt->IsDefault()) { - line() << "default: {"; - } else { - auto* sem = builder_.Sem().Get(stmt); - for (auto* selector : sem->Selectors()) { - auto out = line(); + auto* sem = builder_.Sem().Get(stmt); + for (auto* selector : sem->Selectors()) { + auto out = line(); + + if (selector->IsDefault()) { + out << "default"; + } else { out << "case "; - if (!EmitConstant(out, selector)) { + if (!EmitConstant(out, selector->Value())) { return false; } - out << ":"; - if (selector == sem->Selectors().back()) { - out << " {"; - } + } + out << ":"; + if (selector == sem->Selectors().back()) { + out << " {"; } } diff --git a/src/tint/writer/msl/generator_impl_case_test.cc b/src/tint/writer/msl/generator_impl_case_test.cc index 250d67d9c7..8aae4fe2c1 100644 --- a/src/tint/writer/msl/generator_impl_case_test.cc +++ b/src/tint/writer/msl/generator_impl_case_test.cc @@ -23,7 +23,8 @@ namespace { using MslGeneratorImplTest = TestHelper; TEST_F(MslGeneratorImplTest, Emit_Case) { - auto* s = Switch(1_i, Case(Expr(5_i), Block(create())), DefaultCase()); + auto* s = + Switch(1_i, Case(CaseSelector(5_i), Block(create())), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -38,7 +39,7 @@ TEST_F(MslGeneratorImplTest, Emit_Case) { } TEST_F(MslGeneratorImplTest, Emit_Case_BreaksByDefault) { - auto* s = Switch(1_i, Case(Expr(5_i), Block()), DefaultCase()); + auto* s = Switch(1_i, Case(CaseSelector(5_i), Block()), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -53,8 +54,8 @@ TEST_F(MslGeneratorImplTest, Emit_Case_BreaksByDefault) { } TEST_F(MslGeneratorImplTest, Emit_Case_WithFallthrough) { - auto* s = - Switch(1_i, Case(Expr(5_i), Block(create())), DefaultCase()); + auto* s = Switch(1_i, Case(CaseSelector(5_i), Block(create())), + DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -72,8 +73,8 @@ TEST_F(MslGeneratorImplTest, Emit_Case_MultipleSelectors) { auto* s = Switch(1_i, Case( utils::Vector{ - Expr(5_i), - Expr(6_i), + CaseSelector(5_i), + CaseSelector(6_i), }, Block(create())), DefaultCase()); diff --git a/src/tint/writer/msl/generator_impl_switch_test.cc b/src/tint/writer/msl/generator_impl_switch_test.cc index b327f34f36..6b47d09a84 100644 --- a/src/tint/writer/msl/generator_impl_switch_test.cc +++ b/src/tint/writer/msl/generator_impl_switch_test.cc @@ -25,16 +25,13 @@ TEST_F(MslGeneratorImplTest, Emit_Switch) { auto* cond = Var("cond", ty.i32()); auto* def_body = Block(create()); - auto* def = create(utils::Empty, def_body); - - utils::Vector case_val{Expr(5_i)}; + auto* def = Case(DefaultCaseSelector(), def_body); auto* case_body = Block(create()); - - auto* case_stmt = create(case_val, case_body); + auto* case_stmt = Case(CaseSelector(5_i), case_body); utils::Vector body{case_stmt, def}; - auto* s = create(Expr(cond), body); + auto* s = Switch(cond, body); WrapInFunction(cond, s); GeneratorImpl& gen = Build(); @@ -52,5 +49,27 @@ TEST_F(MslGeneratorImplTest, Emit_Switch) { )"); } +TEST_F(MslGeneratorImplTest, Emit_Switch_MixedDefault) { + auto* cond = Var("cond", ty.i32()); + + auto* def_body = Block(create()); + auto* def = Case(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, def_body); + + auto* s = Switch(cond, def); + WrapInFunction(cond, s); + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(s)) << gen.error(); + EXPECT_EQ(gen.result(), R"( switch(cond) { + case 5: + default: { + break; + } + } +)"); +} + } // namespace } // namespace tint::writer::msl diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index b72de6782f..790ebe6f7f 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -3456,19 +3456,26 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { std::vector case_ids; for (const auto* item : stmt->body) { - if (item->IsDefault()) { - case_ids.push_back(default_block_id); + auto block_id = default_block_id; + if (!item->ContainsDefault()) { + auto block = result_op(); + block_id = std::get(block); + } + case_ids.push_back(block_id); + + // If this case statement is only a default selector skip adding the block + // as it will be done below. + if (item->selectors.Length() == 1 && item->ContainsDefault()) { continue; } - auto block = result_op(); - auto block_id = std::get(block); - - case_ids.push_back(block_id); - auto* sem = builder_.Sem().Get(item); for (auto* selector : sem->Selectors()) { - params.push_back(Operand(selector->As())); + if (selector->IsDefault()) { + continue; + } + + params.push_back(Operand(selector->Value()->As())); params.push_back(Operand(block_id)); } } @@ -3490,7 +3497,7 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { for (uint32_t i = 0; i < body.Length(); i++) { auto* item = body[i]; - if (item->IsDefault()) { + if (item->ContainsDefault()) { generated_default = true; } diff --git a/src/tint/writer/spirv/builder_switch_test.cc b/src/tint/writer/spirv/builder_switch_test.cc index 4ee8f39f2f..c59c640316 100644 --- a/src/tint/writer/spirv/builder_switch_test.cc +++ b/src/tint/writer/spirv/builder_switch_test.cc @@ -62,9 +62,9 @@ TEST_F(BuilderTest, Switch_WithCase) { auto* func = Func("a_func", utils::Empty, ty.void_(), utils::Vector{ - Switch("a", // - Case(Expr(1_i), Block(Assign("v", 1_i))), // - Case(Expr(2_i), Block(Assign("v", 2_i))), // + Switch("a", // + Case(CaseSelector(1_i), Block(Assign("v", 1_i))), // + Case(CaseSelector(2_i), Block(Assign("v", 2_i))), // DefaultCase()), }); @@ -119,9 +119,9 @@ TEST_F(BuilderTest, Switch_WithCase_Unsigned) { auto* func = Func("a_func", utils::Empty, ty.void_(), utils::Vector{ - Switch("a", // - Case(Expr(1_u), Block(Assign("v", 1_i))), // - Case(Expr(2_u), Block(Assign("v", 2_i))), // + Switch("a", // + Case(CaseSelector(1_u), Block(Assign("v", 1_i))), // + Case(CaseSelector(2_u), Block(Assign("v", 2_i))), // DefaultCase()), }); @@ -226,11 +226,11 @@ TEST_F(BuilderTest, Switch_WithCaseAndDefault) { auto* func = Func("a_func", utils::Empty, ty.void_(), utils::Vector{ - Switch(Expr("a"), // - Case(Expr(1_i), // - Block(Assign("v", 1_i))), // - Case(utils::Vector{Expr(2_i), Expr(3_i)}, // - Block(Assign("v", 2_i))), // + Switch(Expr("a"), // + Case(CaseSelector(1_i), // + Block(Assign("v", 1_i))), // + Case(utils::Vector{CaseSelector(2_i), CaseSelector(3_i)}, // + Block(Assign("v", 2_i))), // DefaultCase(Block(Assign("v", 3_i)))), }); @@ -273,6 +273,61 @@ OpFunctionEnd )"); } +TEST_F(BuilderTest, Switch_WithCaseAndMixedDefault) { + // switch(a) { + // case 1i: + // v = 1i; + // case 2i, 3i, default: + // v = 2i; + // } + + auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate); + auto* a = GlobalVar("a", ty.i32(), ast::AddressSpace::kPrivate); + + auto* func = Func("a_func", utils::Empty, ty.void_(), + utils::Vector{Switch(Expr("a"), // + Case(CaseSelector(1_i), // + Block(Assign("v", 1_i))), // + Case(utils::Vector{CaseSelector(2_i), CaseSelector(3_i), + DefaultCaseSelector()}, // + Block(Assign("v", 2_i))) // + )}); + + spirv::Builder& b = Build(); + + ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error(); + ASSERT_TRUE(b.GenerateFunction(func)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v" +OpName %5 "a" +OpName %8 "a_func" +%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%5 = OpVariable %2 Private %4 +%7 = OpTypeVoid +%6 = OpTypeFunction %7 +%14 = OpConstant %3 1 +%15 = OpConstant %3 2 +%8 = OpFunction %7 None %6 +%9 = OpLabel +%11 = OpLoad %3 %5 +OpSelectionMerge %10 None +OpSwitch %11 %12 1 %13 2 %12 3 %12 +%13 = OpLabel +OpStore %1 %14 +OpBranch %10 +%12 = OpLabel +OpStore %1 %15 +OpBranch %10 +%10 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + TEST_F(BuilderTest, Switch_CaseWithFallthrough) { // switch(a) { // case 1i: @@ -290,9 +345,9 @@ TEST_F(BuilderTest, Switch_CaseWithFallthrough) { auto* func = Func("a_func", utils::Empty, ty.void_(), utils::Vector{ Switch(Expr("a"), // - Case(Expr(1_i), // + Case(CaseSelector(1_i), // Block(Assign("v", 1_i), Fallthrough())), // - Case(Expr(2_i), // + Case(CaseSelector(2_i), // Block(Assign("v", 2_i))), // DefaultCase(Block(Assign("v", 3_i)))), }); @@ -351,9 +406,9 @@ TEST_F(BuilderTest, Switch_WithNestedBreak) { auto* func = Func("a_func", utils::Empty, ty.void_(), utils::Vector{ - Switch("a", // - Case(Expr(1_i), // - Block( // + Switch("a", // + Case(CaseSelector(1_i), // + Block( // If(Expr(true), Block(create())), Assign("v", 1_i))), DefaultCase()), @@ -414,9 +469,9 @@ TEST_F(BuilderTest, Switch_AllReturn) { auto* fn = Func("f", utils::Empty, ty.i32(), utils::Vector{ - Switch(1_i, // - Case(Expr(1_i), Block(Return(1_i))), // - Case(Expr(2_i), Block(Fallthrough())), // + Switch(1_i, // + Case(CaseSelector(1_i), Block(Return(1_i))), // + Case(CaseSelector(2_i), Block(Fallthrough())), // DefaultCase(Block(Return(3_i)))), }); diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index 7183c74dc3..118bd4c68f 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc @@ -1024,26 +1024,28 @@ bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) { } bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) { - if (stmt->IsDefault()) { + if (stmt->selectors.Length() == 1 && stmt->ContainsDefault()) { line() << "default: {"; } else { auto out = line(); out << "case "; bool first = true; - for (auto* expr : stmt->selectors) { + for (auto* sel : stmt->selectors) { if (!first) { out << ", "; } first = false; - if (!EmitExpression(out, expr)) { + + if (sel->IsDefault()) { + out << "default"; + } else if (!EmitExpression(out, sel->expr)) { return false; } } out << ": {"; } - if (!EmitStatementsWithIndent(stmt->body->statements)) { return false; } diff --git a/src/tint/writer/wgsl/generator_impl_case_test.cc b/src/tint/writer/wgsl/generator_impl_case_test.cc index 6d59970ff7..39c28cb2d8 100644 --- a/src/tint/writer/wgsl/generator_impl_case_test.cc +++ b/src/tint/writer/wgsl/generator_impl_case_test.cc @@ -22,7 +22,8 @@ namespace { using WgslGeneratorImplTest = TestHelper; TEST_F(WgslGeneratorImplTest, Emit_Case) { - auto* s = Switch(1_i, Case(Expr(5_i), Block(create())), DefaultCase()); + auto* s = + Switch(1_i, Case(CaseSelector(5_i), Block(create())), DefaultCase()); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -40,8 +41,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Case_MultipleSelectors) { auto* s = Switch(1_i, Case( utils::Vector{ - Expr(5_i), - Expr(6_i), + CaseSelector(5_i), + CaseSelector(6_i), }, Block(create())), DefaultCase()); diff --git a/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc b/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc index 2b120514b2..093fba114a 100644 --- a/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc +++ b/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc @@ -23,8 +23,8 @@ using WgslGeneratorImplTest = TestHelper; TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) { auto* f = create(); - WrapInFunction(Switch(1_i, // - Case(Expr(1_i), Block(f)), // + WrapInFunction(Switch(1_i, // + Case(CaseSelector(1_i), Block(f)), // DefaultCase())); GeneratorImpl& gen = Build(); diff --git a/src/tint/writer/wgsl/generator_impl_switch_test.cc b/src/tint/writer/wgsl/generator_impl_switch_test.cc index 5cf0ed2714..7a39079ba3 100644 --- a/src/tint/writer/wgsl/generator_impl_switch_test.cc +++ b/src/tint/writer/wgsl/generator_impl_switch_test.cc @@ -25,13 +25,10 @@ TEST_F(WgslGeneratorImplTest, Emit_Switch) { GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); auto* def_body = Block(create()); - auto* def = create(utils::Empty, def_body); - - utils::Vector case_val{Expr(5_i)}; + auto* def = Case(DefaultCaseSelector(), def_body); auto* case_body = Block(create()); - - auto* case_stmt = create(case_val, case_body); + auto* case_stmt = Case(utils::Vector{CaseSelector(5_i)}, case_body); utils::Vector body{ case_stmt, @@ -39,7 +36,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Switch) { }; auto* cond = Expr("cond"); - auto* s = create(cond, body); + auto* s = Switch(cond, body); WrapInFunction(s); GeneratorImpl& gen = Build(); @@ -58,5 +55,28 @@ TEST_F(WgslGeneratorImplTest, Emit_Switch) { )"); } +TEST_F(WgslGeneratorImplTest, Emit_Switch_MixedDefault) { + GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate); + + auto* def_body = Block(create()); + auto* def = Case(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, def_body); + + auto* cond = Expr("cond"); + auto* s = Switch(cond, utils::Vector{def}); + WrapInFunction(s); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(s)) << gen.error(); + EXPECT_EQ(gen.result(), R"( switch(cond) { + case 5i, default: { + break; + } + } +)"); +} + } // namespace } // namespace tint::writer::wgsl diff --git a/test/tint/statements/switch/case_default.wgsl b/test/tint/statements/switch/case_default.wgsl new file mode 100644 index 0000000000..c7ffc49fbe --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl @@ -0,0 +1,16 @@ +@compute @workgroup_size(1) +fn f() { + var i : i32; + var result : i32; + switch(i) { + case default: { + result = 10; + } + case 1: { + result = 22; + } + case 2: { + result = 33; + } + } +} diff --git a/test/tint/statements/switch/case_default.wgsl.expected.dxc.hlsl b/test/tint/statements/switch/case_default.wgsl.expected.dxc.hlsl new file mode 100644 index 0000000000..059ec9847c --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl.expected.dxc.hlsl @@ -0,0 +1,20 @@ +[numthreads(1, 1, 1)] +void f() { + int i = 0; + int result = 0; + switch(i) { + default: { + result = 10; + break; + } + case 1: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } + return; +} diff --git a/test/tint/statements/switch/case_default.wgsl.expected.fxc.hlsl b/test/tint/statements/switch/case_default.wgsl.expected.fxc.hlsl new file mode 100644 index 0000000000..059ec9847c --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl.expected.fxc.hlsl @@ -0,0 +1,20 @@ +[numthreads(1, 1, 1)] +void f() { + int i = 0; + int result = 0; + switch(i) { + default: { + result = 10; + break; + } + case 1: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } + return; +} diff --git a/test/tint/statements/switch/case_default.wgsl.expected.glsl b/test/tint/statements/switch/case_default.wgsl.expected.glsl new file mode 100644 index 0000000000..3da9112503 --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl.expected.glsl @@ -0,0 +1,26 @@ +#version 310 es + +void f() { + int i = 0; + int result = 0; + switch(i) { + default: { + result = 10; + break; + } + case 1: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } +} + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +void main() { + f(); + return; +} diff --git a/test/tint/statements/switch/case_default.wgsl.expected.msl b/test/tint/statements/switch/case_default.wgsl.expected.msl new file mode 100644 index 0000000000..a76392fad7 --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl.expected.msl @@ -0,0 +1,23 @@ +#include + +using namespace metal; +kernel void f() { + int i = 0; + int result = 0; + switch(i) { + default: { + result = 10; + break; + } + case 1: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } + return; +} + diff --git a/test/tint/statements/switch/case_default.wgsl.expected.spvasm b/test/tint/statements/switch/case_default.wgsl.expected.spvasm new file mode 100644 index 0000000000..6402c562f8 --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl.expected.spvasm @@ -0,0 +1,39 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 18 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %f "f" + OpExecutionMode %f LocalSize 1 1 1 + OpName %f "f" + OpName %i "i" + OpName %result "result" + %void = OpTypeVoid + %1 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %8 = OpConstantNull %int + %int_10 = OpConstant %int 10 + %int_22 = OpConstant %int 22 + %int_33 = OpConstant %int 33 + %f = OpFunction %void None %1 + %4 = OpLabel + %i = OpVariable %_ptr_Function_int Function %8 + %result = OpVariable %_ptr_Function_int Function %8 + %11 = OpLoad %int %i + OpSelectionMerge %10 None + OpSwitch %11 %12 1 %13 2 %14 + %12 = OpLabel + OpStore %result %int_10 + OpBranch %10 + %13 = OpLabel + OpStore %result %int_22 + OpBranch %10 + %14 = OpLabel + OpStore %result %int_33 + OpBranch %10 + %10 = OpLabel + OpReturn + OpFunctionEnd diff --git a/test/tint/statements/switch/case_default.wgsl.expected.wgsl b/test/tint/statements/switch/case_default.wgsl.expected.wgsl new file mode 100644 index 0000000000..792708aaeb --- /dev/null +++ b/test/tint/statements/switch/case_default.wgsl.expected.wgsl @@ -0,0 +1,16 @@ +@compute @workgroup_size(1) +fn f() { + var i : i32; + var result : i32; + switch(i) { + default: { + result = 10; + } + case 1: { + result = 22; + } + case 2: { + result = 33; + } + } +} diff --git a/test/tint/statements/switch/case_default_mixed.wgsl b/test/tint/statements/switch/case_default_mixed.wgsl new file mode 100644 index 0000000000..e28b24acc6 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl @@ -0,0 +1,16 @@ +@compute @workgroup_size(1) +fn f() { + var i : i32; + var result : i32; + switch(i) { + case 0: { + result = 10; + } + case 1, default: { + result = 22; + } + case 2: { + result = 33; + } + } +} diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.dxc.hlsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.dxc.hlsl new file mode 100644 index 0000000000..b15ec5bb77 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.dxc.hlsl @@ -0,0 +1,21 @@ +[numthreads(1, 1, 1)] +void f() { + int i = 0; + int result = 0; + switch(i) { + case 0: { + result = 10; + break; + } + case 1: + default: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } + return; +} diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.fxc.hlsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.fxc.hlsl new file mode 100644 index 0000000000..b15ec5bb77 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.fxc.hlsl @@ -0,0 +1,21 @@ +[numthreads(1, 1, 1)] +void f() { + int i = 0; + int result = 0; + switch(i) { + case 0: { + result = 10; + break; + } + case 1: + default: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } + return; +} diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.glsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.glsl new file mode 100644 index 0000000000..cbd24c0c69 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.glsl @@ -0,0 +1,27 @@ +#version 310 es + +void f() { + int i = 0; + int result = 0; + switch(i) { + case 0: { + result = 10; + break; + } + case 1: + default: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } +} + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +void main() { + f(); + return; +} diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.msl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.msl new file mode 100644 index 0000000000..f36caad647 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.msl @@ -0,0 +1,24 @@ +#include + +using namespace metal; +kernel void f() { + int i = 0; + int result = 0; + switch(i) { + case 0: { + result = 10; + break; + } + case 1: + default: { + result = 22; + break; + } + case 2: { + result = 33; + break; + } + } + return; +} + diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.spvasm b/test/tint/statements/switch/case_default_mixed.wgsl.expected.spvasm new file mode 100644 index 0000000000..a212876225 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.spvasm @@ -0,0 +1,39 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 18 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %f "f" + OpExecutionMode %f LocalSize 1 1 1 + OpName %f "f" + OpName %i "i" + OpName %result "result" + %void = OpTypeVoid + %1 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %8 = OpConstantNull %int + %int_10 = OpConstant %int 10 + %int_22 = OpConstant %int 22 + %int_33 = OpConstant %int 33 + %f = OpFunction %void None %1 + %4 = OpLabel + %i = OpVariable %_ptr_Function_int Function %8 + %result = OpVariable %_ptr_Function_int Function %8 + %11 = OpLoad %int %i + OpSelectionMerge %10 None + OpSwitch %11 %12 0 %13 1 %12 2 %14 + %13 = OpLabel + OpStore %result %int_10 + OpBranch %10 + %12 = OpLabel + OpStore %result %int_22 + OpBranch %10 + %14 = OpLabel + OpStore %result %int_33 + OpBranch %10 + %10 = OpLabel + OpReturn + OpFunctionEnd diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.wgsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.wgsl new file mode 100644 index 0000000000..920b432b89 --- /dev/null +++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.wgsl @@ -0,0 +1,16 @@ +@compute @workgroup_size(1) +fn f() { + var i : i32; + var result : i32; + switch(i) { + case 0: { + result = 10; + } + case 1, default: { + result = 22; + } + case 2: { + result = 33; + } + } +}