diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc index 8aca58c083..098921776d 100644 --- a/src/ast/case_statement.cc +++ b/src/ast/case_statement.cc @@ -19,15 +19,14 @@ namespace ast { CaseStatement::CaseStatement() : Statement() {} -CaseStatement::CaseStatement(std::unique_ptr condition, - StatementList body) - : Statement(), condition_(std::move(condition)), body_(std::move(body)) {} +CaseStatement::CaseStatement(CaseSelectorList conditions, StatementList body) + : Statement(), conditions_(std::move(conditions)), body_(std::move(body)) {} CaseStatement::CaseStatement(const Source& source, - std::unique_ptr condition, + CaseSelectorList conditions, StatementList body) : Statement(source), - condition_(std::move(condition)), + conditions_(std::move(conditions)), body_(std::move(body)) {} CaseStatement::CaseStatement(CaseStatement&&) = default; @@ -52,7 +51,16 @@ void CaseStatement::to_str(std::ostream& out, size_t indent) const { if (IsDefault()) { out << "Default{" << std::endl; } else { - out << "Case " << condition_->to_str() << "{" << std::endl; + out << "Case "; + bool first = true; + for (const auto& lit : conditions_) { + if (!first) + out << ", "; + + first = false; + out << lit->to_str(); + } + out << "{" << std::endl; } for (const auto& stmt : body_) diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h index 5d2c24ec3f..11dcd33989 100644 --- a/src/ast/case_statement.h +++ b/src/ast/case_statement.h @@ -27,35 +27,38 @@ namespace tint { namespace ast { +/// A list of case literals +using CaseSelectorList = std::vector>; + /// A case statement class CaseStatement : public Statement { public: /// Constructor CaseStatement(); /// Constructor - /// @param condition the case condition + /// @param conditions the case conditions /// @param body the case body - CaseStatement(std::unique_ptr condition, StatementList body); + CaseStatement(CaseSelectorList conditions, StatementList body); /// Constructor /// @param source the source information - /// @param condition the case condition + /// @param conditions the case conditions /// @param body the case body CaseStatement(const Source& source, - std::unique_ptr condition, + CaseSelectorList conditions, StatementList body); /// Move constructor CaseStatement(CaseStatement&&); ~CaseStatement() override; - /// Sets the condition for the case statement - /// @param condition the condition to set - void set_condition(std::unique_ptr condition) { - condition_ = std::move(condition); + /// Sets the conditions for the case statement + /// @param conditions the conditions to set + void set_conditions(CaseSelectorList conditions) { + conditions_ = std::move(conditions); } - /// @returns the case condition or nullptr if none set - Literal* condition() const { return condition_.get(); } + /// @returns the case condition, empty if none set + const CaseSelectorList& conditions() const { return conditions_; } /// @returns true if this is a default statement - bool IsDefault() const { return condition_ == nullptr; } + bool IsDefault() const { return conditions_.empty(); } /// Sets the case body /// @param body the case body @@ -77,7 +80,7 @@ class CaseStatement : public Statement { private: CaseStatement(const CaseStatement&) = delete; - std::unique_ptr condition_; + CaseSelectorList conditions_; StatementList body_; }; diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc index 4e841ace72..8653cd7077 100644 --- a/src/ast/case_statement_test.cc +++ b/src/ast/case_statement_test.cc @@ -17,8 +17,10 @@ #include "gtest/gtest.h" #include "src/ast/bool_literal.h" #include "src/ast/if_statement.h" +#include "src/ast/int_literal.h" #include "src/ast/kill_statement.h" #include "src/ast/type/bool_type.h" +#include "src/ast/type/i32_type.h" namespace tint { namespace ast { @@ -28,22 +30,28 @@ using CaseStatementTest = testing::Test; TEST_F(CaseStatementTest, Creation) { ast::type::BoolType bool_type; - auto b = std::make_unique(&bool_type, true); + + CaseSelectorList b; + b.push_back(std::make_unique(&bool_type, true)); + StatementList stmts; stmts.push_back(std::make_unique()); - auto* bool_ptr = b.get(); + auto* bool_ptr = b.back().get(); auto* kill_ptr = stmts[0].get(); CaseStatement c(std::move(b), std::move(stmts)); - EXPECT_EQ(c.condition(), bool_ptr); + ASSERT_EQ(c.conditions().size(), 1); + EXPECT_EQ(c.conditions()[0].get(), bool_ptr); ASSERT_EQ(c.body().size(), 1u); EXPECT_EQ(c.body()[0].get(), kill_ptr); } TEST_F(CaseStatementTest, Creation_WithSource) { ast::type::BoolType bool_type; - auto b = std::make_unique(&bool_type, true); + CaseSelectorList b; + b.push_back(std::make_unique(&bool_type, true)); + StatementList stmts; stmts.push_back(std::make_unique()); @@ -64,9 +72,11 @@ TEST_F(CaseStatementTest, IsDefault_WithoutCondition) { TEST_F(CaseStatementTest, IsDefault_WithCondition) { ast::type::BoolType bool_type; - auto b = std::make_unique(&bool_type, true); + CaseSelectorList b; + b.push_back(std::make_unique(&bool_type, true)); + CaseStatement c; - c.set_condition(std::move(b)); + c.set_conditions(std::move(b)); EXPECT_FALSE(c.IsDefault()); } @@ -82,7 +92,9 @@ TEST_F(CaseStatementTest, IsValid) { TEST_F(CaseStatementTest, IsValid_NullBodyStatement) { ast::type::BoolType bool_type; - auto b = std::make_unique(&bool_type, true); + CaseSelectorList b; + b.push_back(std::make_unique(&bool_type, true)); + StatementList stmts; stmts.push_back(std::make_unique()); stmts.push_back(nullptr); @@ -93,20 +105,24 @@ TEST_F(CaseStatementTest, IsValid_NullBodyStatement) { TEST_F(CaseStatementTest, IsValid_InvalidBodyStatement) { ast::type::BoolType bool_type; - auto b = std::make_unique(&bool_type, true); + CaseSelectorList b; + b.push_back(std::make_unique(&bool_type, true)); + StatementList stmts; stmts.push_back(std::make_unique()); - CaseStatement c(std::move(b), std::move(stmts)); + CaseStatement c({std::move(b)}, std::move(stmts)); EXPECT_FALSE(c.IsValid()); } TEST_F(CaseStatementTest, ToStr_WithCondition) { ast::type::BoolType bool_type; - auto b = std::make_unique(&bool_type, true); + CaseSelectorList b; + b.push_back(std::make_unique(&bool_type, true)); + StatementList stmts; stmts.push_back(std::make_unique()); - CaseStatement c(std::move(b), std::move(stmts)); + CaseStatement c({std::move(b)}, std::move(stmts)); std::ostringstream out; c.to_str(out, 2); @@ -116,10 +132,28 @@ TEST_F(CaseStatementTest, ToStr_WithCondition) { )"); } +TEST_F(CaseStatementTest, ToStr_WithMultipleConditions) { + ast::type::I32Type i32; + + CaseSelectorList b; + b.push_back(std::make_unique(&i32, 1)); + b.push_back(std::make_unique(&i32, 2)); + StatementList stmts; + stmts.push_back(std::make_unique()); + CaseStatement c(std::move(b), std::move(stmts)); + + std::ostringstream out; + c.to_str(out, 2); + EXPECT_EQ(out.str(), R"( Case 1, 2{ + Kill{} + } +)"); +} + TEST_F(CaseStatementTest, ToStr_WithoutCondition) { StatementList stmts; stmts.push_back(std::make_unique()); - CaseStatement c(nullptr, std::move(stmts)); + CaseStatement c(CaseSelectorList{}, std::move(stmts)); std::ostringstream out; c.to_str(out, 2); diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc index c4adb28bec..512df39b29 100644 --- a/src/ast/switch_statement_test.cc +++ b/src/ast/switch_statement_test.cc @@ -30,7 +30,8 @@ using SwitchStatementTest = testing::Test; TEST_F(SwitchStatementTest, Creation) { ast::type::BoolType bool_type; - auto lit = std::make_unique(&bool_type, true); + CaseSelectorList lit; + lit.push_back(std::make_unique(&bool_type, true)); auto ident = std::make_unique("ident"); CaseStatementList body; body.push_back( @@ -61,7 +62,8 @@ TEST_F(SwitchStatementTest, IsSwitch) { TEST_F(SwitchStatementTest, IsValid) { ast::type::BoolType bool_type; - auto lit = std::make_unique(&bool_type, true); + CaseSelectorList lit; + lit.push_back(std::make_unique(&bool_type, true)); auto ident = std::make_unique("ident"); CaseStatementList body; body.push_back( @@ -73,7 +75,8 @@ TEST_F(SwitchStatementTest, IsValid) { TEST_F(SwitchStatementTest, IsValid_Null_Condition) { ast::type::BoolType bool_type; - auto lit = std::make_unique(&bool_type, true); + CaseSelectorList lit; + lit.push_back(std::make_unique(&bool_type, true)); CaseStatementList body; body.push_back( std::make_unique(std::move(lit), StatementList())); @@ -85,7 +88,8 @@ TEST_F(SwitchStatementTest, IsValid_Null_Condition) { TEST_F(SwitchStatementTest, IsValid_Invalid_Condition) { ast::type::BoolType bool_type; - auto lit = std::make_unique(&bool_type, true); + CaseSelectorList lit; + lit.push_back(std::make_unique(&bool_type, true)); auto ident = std::make_unique(""); CaseStatementList body; body.push_back( @@ -97,7 +101,8 @@ TEST_F(SwitchStatementTest, IsValid_Invalid_Condition) { TEST_F(SwitchStatementTest, IsValid_Null_BodyStatement) { ast::type::BoolType bool_type; - auto lit = std::make_unique(&bool_type, true); + CaseSelectorList lit; + lit.push_back(std::make_unique(&bool_type, true)); auto ident = std::make_unique("ident"); CaseStatementList body; body.push_back( @@ -115,8 +120,8 @@ TEST_F(SwitchStatementTest, IsValid_Invalid_BodyStatement) { case_body.push_back(nullptr); CaseStatementList body; - body.push_back( - std::make_unique(nullptr, std::move(case_body))); + body.push_back(std::make_unique(CaseSelectorList{}, + std::move(case_body))); SwitchStatement stmt(std::move(ident), std::move(body)); EXPECT_FALSE(stmt.IsValid()); @@ -138,7 +143,8 @@ TEST_F(SwitchStatementTest, ToStr_Empty) { TEST_F(SwitchStatementTest, ToStr) { ast::type::BoolType bool_type; - auto lit = std::make_unique(&bool_type, true); + CaseSelectorList lit; + lit.push_back(std::make_unique(&bool_type, true)); auto ident = std::make_unique("ident"); CaseStatementList body; body.push_back( diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index c02a375a7b..e81ef0c37d 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -1800,7 +1800,7 @@ std::unique_ptr ParserImpl::switch_stmt() { } // switch_body -// : CASE const_literal COLON BRACKET_LEFT case_body BRACKET_RIGHT +// : CASE case_selectors COLON BRACKET_LEFT case_body BRACKET_RIGHT // | DEFAULT COLON BRACKET_LEFT case_body BRACKET_RIGHT std::unique_ptr ParserImpl::switch_body() { auto t = peek(); @@ -1813,14 +1813,14 @@ std::unique_ptr ParserImpl::switch_body() { auto stmt = std::make_unique(); stmt->set_source(source); if (t.IsCase()) { - auto cond = const_literal(); + auto cond = case_selectors(); if (has_error()) return nullptr; - if (cond == nullptr) { + if (cond.empty()) { set_error(peek(), "unable to parse case conditional"); return nullptr; } - stmt->set_condition(std::move(cond)); + stmt->set_conditions(std::move(cond)); } t = next(); @@ -1850,6 +1850,24 @@ std::unique_ptr ParserImpl::switch_body() { return stmt; } +// case_selectors +// : const_literal (COMMA const_literal)* +ast::CaseSelectorList ParserImpl::case_selectors() { + ast::CaseSelectorList selectors; + + for (;;) { + auto cond = const_literal(); + if (has_error()) + return {}; + if (cond == nullptr) + break; + + selectors.push_back(std::move(cond)); + } + + return selectors; +} + // case_body // : // | statement case_body diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h index 826ea8a927..718f3c20d6 100644 --- a/src/reader/wgsl/parser_impl.h +++ b/src/reader/wgsl/parser_impl.h @@ -23,6 +23,7 @@ #include "src/ast/assignment_statement.h" #include "src/ast/builtin.h" +#include "src/ast/case_statement.h" #include "src/ast/constructor_expression.h" #include "src/ast/else_statement.h" #include "src/ast/entry_point.h" @@ -216,6 +217,9 @@ class ParserImpl { /// Parses a `switch_body` grammar element /// @returns the parsed statement or nullptr std::unique_ptr switch_body(); + /// Parses a `case_selectors` grammar element + /// @returns the list of literals + ast::CaseSelectorList case_selectors(); /// Parses a `case_body` grammar element /// @returns the parsed statements ast::StatementList case_body(); diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 5104709932..bd4a04fc9b 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -165,8 +165,9 @@ TEST_F(TypeDeterminerTest, Stmt_Case) { body.push_back(std::make_unique(std::move(lhs), std::move(rhs))); - ast::CaseStatement cse(std::make_unique(&i32, 3), - std::move(body)); + ast::CaseSelectorList lit; + lit.push_back(std::make_unique(&i32, 3)); + ast::CaseStatement cse(std::move(lit), std::move(body)); EXPECT_TRUE(td()->DetermineResultType(&cse)); ASSERT_NE(lhs_ptr->result_type(), nullptr); @@ -355,9 +356,12 @@ TEST_F(TypeDeterminerTest, Stmt_Switch) { body.push_back(std::make_unique(std::move(lhs), std::move(rhs))); + ast::CaseSelectorList lit; + lit.push_back(std::make_unique(&i32, 3)); + ast::CaseStatementList cases; - cases.push_back(std::make_unique( - std::make_unique(&i32, 3), std::move(body))); + cases.push_back( + std::make_unique(std::move(lit), std::move(body))); ast::SwitchStatement stmt(std::make_unique( std::make_unique(&i32, 2)), diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index c83759e5c0..50dad688b8 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -724,9 +724,18 @@ bool GeneratorImpl::EmitCase(ast::CaseStatement* stmt) { } else { out_ << "case "; - if (!EmitLiteral(stmt->condition())) { - return false; + bool first = true; + for (const auto& lit : stmt->conditions()) { + if (!first) { + out_ << ", "; + } + + first = false; + if (!EmitLiteral(lit.get())) { + return false; + } } + out_ << ":"; } diff --git a/src/writer/wgsl/generator_impl_case_test.cc b/src/writer/wgsl/generator_impl_case_test.cc index d1c3fd197d..5a1a906036 100644 --- a/src/writer/wgsl/generator_impl_case_test.cc +++ b/src/writer/wgsl/generator_impl_case_test.cc @@ -31,12 +31,13 @@ using GeneratorImplTest = testing::Test; TEST_F(GeneratorImplTest, Emit_Case) { ast::type::I32Type i32; - auto cond = std::make_unique(&i32, 5); ast::StatementList body; body.push_back(std::make_unique()); - ast::CaseStatement c(std::move(cond), std::move(body)); + ast::CaseSelectorList lit; + lit.push_back(std::make_unique(&i32, 5)); + ast::CaseStatement c(std::move(lit), std::move(body)); GeneratorImpl g; g.increment_indent(); @@ -48,6 +49,27 @@ TEST_F(GeneratorImplTest, Emit_Case) { )"); } +TEST_F(GeneratorImplTest, Emit_Case_MultipleSelectors) { + ast::type::I32Type i32; + + ast::StatementList body; + body.push_back(std::make_unique()); + + ast::CaseSelectorList lit; + lit.push_back(std::make_unique(&i32, 5)); + lit.push_back(std::make_unique(&i32, 6)); + ast::CaseStatement c(std::move(lit), std::move(body)); + + GeneratorImpl g; + g.increment_indent(); + + ASSERT_TRUE(g.EmitCase(&c)) << g.error(); + EXPECT_EQ(g.result(), R"( case 5, 6: { + break; + } +)"); +} + TEST_F(GeneratorImplTest, Emit_Case_Default) { ast::CaseStatement c; diff --git a/src/writer/wgsl/generator_impl_switch_test.cc b/src/writer/wgsl/generator_impl_switch_test.cc index b784f21aef..b149541206 100644 --- a/src/writer/wgsl/generator_impl_switch_test.cc +++ b/src/writer/wgsl/generator_impl_switch_test.cc @@ -37,7 +37,9 @@ TEST_F(GeneratorImplTest, Emit_Switch) { def->set_body(std::move(def_body)); ast::type::I32Type i32; - auto case_val = std::make_unique(&i32, 5); + ast::CaseSelectorList case_val; + case_val.push_back(std::make_unique(&i32, 5)); + ast::StatementList case_body; case_body.push_back(std::make_unique());