diff --git a/src/tint/ast/case_statement.cc b/src/tint/ast/case_statement.cc index 9f1c20e8b7..3125f0516f 100644 --- a/src/tint/ast/case_statement.cc +++ b/src/tint/ast/case_statement.cc @@ -25,7 +25,7 @@ namespace tint::ast { CaseStatement::CaseStatement(ProgramID pid, NodeID nid, const Source& src, - utils::VectorRef s, + utils::VectorRef s, const BlockStatement* b) : Base(pid, nid, src), selectors(std::move(s)), body(b) { TINT_ASSERT(AST, body); diff --git a/src/tint/ast/case_statement.h b/src/tint/ast/case_statement.h index 47d2097e52..eda9c010a2 100644 --- a/src/tint/ast/case_statement.h +++ b/src/tint/ast/case_statement.h @@ -18,7 +18,7 @@ #include #include "src/tint/ast/block_statement.h" -#include "src/tint/ast/int_literal_expression.h" +#include "src/tint/ast/expression.h" namespace tint::ast { @@ -34,7 +34,7 @@ class CaseStatement final : public Castable { CaseStatement(ProgramID pid, NodeID nid, const Source& src, - utils::VectorRef selectors, + utils::VectorRef selectors, const BlockStatement* body); /// Move constructor CaseStatement(CaseStatement&&); @@ -50,7 +50,7 @@ class CaseStatement final : public Castable { const CaseStatement* Clone(CloneContext* ctx) const override; /// The case selectors, empty if none set - const utils::Vector selectors; + const utils::Vector selectors; /// The case body const BlockStatement* const body; diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index f300710f68..892cf27a8b 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -54,6 +54,7 @@ #include "src/tint/ast/if_statement.h" #include "src/tint/ast/increment_decrement_statement.h" #include "src/tint/ast/index_accessor_expression.h" +#include "src/tint/ast/int_literal_expression.h" #include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/invariant_attribute.h" #include "src/tint/ast/let.h" @@ -2846,7 +2847,7 @@ class ProgramBuilder { /// @param body the case body /// @returns the case statement pointer const ast::CaseStatement* Case(const Source& source, - utils::VectorRef selectors, + utils::VectorRef selectors, const ast::BlockStatement* body = nullptr) { return create(source, std::move(selectors), body ? body : Block()); } @@ -2855,7 +2856,7 @@ class ProgramBuilder { /// @param selectors list of selectors /// @param body the case body /// @returns the case statement pointer - const ast::CaseStatement* Case(utils::VectorRef selectors, + const ast::CaseStatement* Case(utils::VectorRef selectors, const ast::BlockStatement* body = nullptr) { return create(std::move(selectors), body ? body : Block()); } @@ -2864,7 +2865,7 @@ class ProgramBuilder { /// @param selector a single case selector /// @param body the case body /// @returns the case statement pointer - const ast::CaseStatement* Case(const ast::IntLiteralExpression* selector, + const ast::CaseStatement* Case(const ast::Expression* selector, const ast::BlockStatement* body = nullptr) { return Case(utils::Vector{selector}, body); } diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index 9c38818f2b..2cc35d81ee 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -3024,7 +3024,7 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) { for (size_t i = last_clause_index;; --i) { // Create a list of integer literals for the selector values leading to // this case clause. - utils::Vector selectors; + utils::Vector selectors; const bool has_selectors = clause_heads[i]->case_values.has_value(); if (has_selectors) { auto values = clause_heads[i]->case_values.value(); diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index d8d799415f..cb8e217465 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc @@ -2148,21 +2148,19 @@ Maybe ParserImpl::switch_body() { } // case_selectors -// : const_literal (COMMA const_literal)* COMMA? +// : expression (COMMA expression)* COMMA? Expect ParserImpl::expect_case_selectors() { CaseSelectorList selectors; while (continue_parsing()) { - auto cond = const_literal(); - if (cond.errored) { + auto expr = expression(); + if (expr.errored) { return Failure::kErrored; - } else if (!cond.matched) { - break; - } else if (!cond->Is()) { - return add_error(cond.value->source, "invalid case selector must be an integer value"); } - - selectors.Push(cond.value->As()); + if (!expr.matched) { + break; + } + selectors.Push(expr.value); if (!match(Token::Type::kComma)) { break; diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h index 60b94deb86..3dbd7805ae 100644 --- a/src/tint/reader/wgsl/parser_impl.h +++ b/src/tint/reader/wgsl/parser_impl.h @@ -74,7 +74,7 @@ class ParserImpl { /// Pre-determined small vector sizes for AST pointers //! @cond Doxygen_Suppress using AttributeList = utils::Vector; - using CaseSelectorList = utils::Vector; + using CaseSelectorList = utils::Vector; using CaseStatementList = utils::Vector; using ExpressionList = utils::Vector; using ParameterList = utils::Vector; diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc index 9416a4d451..657b522d8c 100644 --- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc +++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc @@ -1339,14 +1339,6 @@ fn f() { switch(1) { case ^: } } )"); } -TEST_F(ParserImplErrorTest, SwitchStmtInvalidCase2) { - EXPECT("fn f() { switch(1) { case false: } }", - R"(test.wgsl:1:27 error: invalid case selector must be an integer value -fn f() { switch(1) { case false: } } - ^^^^^ -)"); -} - TEST_F(ParserImplErrorTest, SwitchStmtCaseMissingLBrace) { EXPECT("fn f() { switch(1) { case 1: } }", R"(test.wgsl:1:30 error: expected '{' for case statement diff --git a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc index 076b1c3f1a..2b8b0bd70d 100644 --- a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc +++ b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc @@ -26,10 +26,42 @@ TEST_F(ParserImplTest, SwitchBody_Case) { ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); EXPECT_FALSE(e->IsDefault()); + auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 1u); - EXPECT_EQ(stmt->selectors[0]->value, 1); - EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone); + ASSERT_TRUE(stmt->selectors[0]->Is()); + + auto* expr = stmt->selectors[0]->As(); + EXPECT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); + ASSERT_EQ(e->body->statements.Length(), 1u); + EXPECT_TRUE(e->body->statements[0]->Is()); +} + +TEST_F(ParserImplTest, SwitchBody_Case_Expression) { + auto p = parser("case 1 + 2 { a = 4; }"); + auto e = p->switch_body(); + EXPECT_FALSE(p->has_error()) << p->error(); + EXPECT_TRUE(e.matched); + EXPECT_FALSE(e.errored); + ASSERT_NE(e.value, nullptr); + ASSERT_TRUE(e->Is()); + EXPECT_FALSE(e->IsDefault()); + + auto* stmt = e->As(); + ASSERT_EQ(stmt->selectors.Length(), 1u); + ASSERT_TRUE(stmt->selectors[0]->Is()); + auto* expr = stmt->selectors[0]->As(); + + EXPECT_EQ(ast::BinaryOp::kAdd, expr->op); + auto* v = expr->lhs->As(); + ASSERT_NE(nullptr, v); + EXPECT_EQ(v->value, 1u); + + v = expr->rhs->As(); + ASSERT_NE(nullptr, v); + EXPECT_EQ(v->value, 2u); + ASSERT_EQ(e->body->statements.Length(), 1u); EXPECT_TRUE(e->body->statements[0]->Is()); } @@ -43,10 +75,14 @@ TEST_F(ParserImplTest, SwitchBody_Case_WithColon) { ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); EXPECT_FALSE(e->IsDefault()); + auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 1u); - EXPECT_EQ(stmt->selectors[0]->value, 1); - EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone); + ASSERT_TRUE(stmt->selectors[0]->Is()); + + auto* expr = stmt->selectors[0]->As(); + EXPECT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); ASSERT_EQ(e->body->statements.Length(), 1u); EXPECT_TRUE(e->body->statements[0]->Is()); } @@ -62,9 +98,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma) { EXPECT_FALSE(e->IsDefault()); auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 2u); - EXPECT_EQ(stmt->selectors[0]->value, 1); - EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone); - EXPECT_EQ(stmt->selectors[1]->value, 2); + ASSERT_TRUE(stmt->selectors[0]->Is()); + + auto* expr = stmt->selectors[0]->As(); + EXPECT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(stmt->selectors[1]->Is()); + expr = stmt->selectors[1]->As(); + EXPECT_EQ(expr->value, 2); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) { @@ -76,15 +119,23 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) { ASSERT_NE(e.value, nullptr); ASSERT_TRUE(e->Is()); EXPECT_FALSE(e->IsDefault()); + auto* stmt = e->As(); ASSERT_EQ(stmt->selectors.Length(), 2u); - EXPECT_EQ(stmt->selectors[0]->value, 1); - EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone); - EXPECT_EQ(stmt->selectors[1]->value, 2); + ASSERT_TRUE(stmt->selectors[0]->Is()); + + auto* expr = stmt->selectors[0]->As(); + EXPECT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(stmt->selectors[1]->Is()); + expr = stmt->selectors[1]->As(); + EXPECT_EQ(expr->value, 2); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } -TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) { - auto p = parser("case a == 4: { a = 4; }"); +TEST_F(ParserImplTest, SwitchBody_Case_Invalid) { + auto p = parser("case if: { a = 4; }"); auto e = p->switch_body(); EXPECT_TRUE(p->has_error()); EXPECT_TRUE(e.errored); @@ -93,16 +144,6 @@ TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) { EXPECT_EQ(p->error(), "1:6: unable to parse case selectors"); } -TEST_F(ParserImplTest, SwitchBody_Case_InvalidSelector_bool) { - auto p = parser("case true: { a = 4; }"); - auto e = p->switch_body(); - EXPECT_TRUE(p->has_error()); - EXPECT_TRUE(e.errored); - EXPECT_FALSE(e.matched); - EXPECT_EQ(e.value, nullptr); - EXPECT_EQ(p->error(), "1:6: invalid case selector must be an integer value"); -} - TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) { auto p = parser("case: { a = 4; }"); auto e = p->switch_body(); @@ -164,10 +205,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors) { EXPECT_FALSE(e->IsDefault()); ASSERT_EQ(e->body->statements.Length(), 0u); ASSERT_EQ(e->selectors.Length(), 2u); - ASSERT_EQ(e->selectors[0]->value, 1); - EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone); - ASSERT_EQ(e->selectors[1]->value, 2); - EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone); + ASSERT_TRUE(e->selectors[0]->Is()); + + auto* expr = e->selectors[0]->As(); + ASSERT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(e->selectors[1]->Is()); + expr = e->selectors[1]->As(); + ASSERT_EQ(expr->value, 2); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) { @@ -181,10 +228,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) { EXPECT_FALSE(e->IsDefault()); ASSERT_EQ(e->body->statements.Length(), 0u); ASSERT_EQ(e->selectors.Length(), 2u); - ASSERT_EQ(e->selectors[0]->value, 1); - EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone); - ASSERT_EQ(e->selectors[1]->value, 2); - EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone); + ASSERT_TRUE(e->selectors[0]->Is()); + + auto* expr = e->selectors[0]->As(); + ASSERT_EQ(expr->value, 1); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); + + ASSERT_TRUE(e->selectors[1]->Is()); + expr = e->selectors[1]->As(); + ASSERT_EQ(expr->value, 2); + EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone); } TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectorsMissingComma) { diff --git a/src/tint/resolver/control_block_validation_test.cc b/src/tint/resolver/control_block_validation_test.cc index 9b0d2899bd..403d0bc4ea 100644 --- a/src/tint/resolver/control_block_validation_test.cc +++ b/src/tint/resolver/control_block_validation_test.cc @@ -25,7 +25,7 @@ namespace { class ResolverControlBlockValidationTest : public TestHelper, public testing::Test {}; -TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerType_Fail) { +TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_F32) { // var a : f32 = 3.14; // switch (a) { // default: {} @@ -43,6 +43,24 @@ TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerTy "scalar integer type"); } +TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_bool) { + // var a : bool = true; + // switch (a) { + // default: {} + // } + auto* var = Var("a", ty.bool_(), Expr(false)); + + auto* block = Block(Decl(var), Switch(Expr(Source{{12, 34}}, "a"), // + DefaultCase())); + + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: switch statement selector expression must be of a " + "scalar integer type"); +} + TEST_F(ResolverControlBlockValidationTest, SwitchWithoutDefault_Fail) { // var a : i32 = 2; // switch (a) { @@ -213,8 +231,8 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT // } auto* var = Var("a", ty.i32(), Expr(2_i)); - auto* block = Block(Decl(var), Switch("a", // - Case(Source{{12, 34}}, utils::Vector{Expr(1_u)}), // + auto* block = Block(Decl(var), Switch("a", // + Case(Expr(Source{{12, 34}}, 1_u)), // DefaultCase())); WrapInFunction(block); @@ -234,7 +252,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT auto* block = Block(Decl(var), // Switch("a", // - Case(Source{{12, 34}}, utils::Vector{Expr(-1_i)}), // + Case(utils::Vector{Expr(Source{{12, 34}}, -1_i)}), // DefaultCase())); WrapInFunction(block); @@ -332,6 +350,74 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } +TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Pass) { + // var a : i32 = 2; + // switch (a) { + // default: {} + // case 5 + 6: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = Block(Decl(var), // + Switch("a", // + DefaultCase(Source{{12, 34}}), // + Case(Add(5_i, 6_i)))); + WrapInFunction(block); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixI32_Abstract) { + // var a = 2; + // switch (a) { + // default: {} + // case 5i + 6i: {} + // } + auto* var = Var("a", Expr(2_a)); + + auto* block = Block(Decl(var), // + Switch("a", // + DefaultCase(Source{{12, 34}}), // + Case(Add(5_i, 6_i)))); + WrapInFunction(block); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixU32_Abstract) { + // var a = 2u; + // switch (a) { + // default: {} + // case 5 + 6: {} + // } + auto* var = Var("a", Expr(2_u)); + + auto* block = Block(Decl(var), // + Switch("a", // + DefaultCase(Source{{12, 34}}), // + Case(Add(5_a, 6_a)))); + WrapInFunction(block); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Multiple) { + // var a = 2u; + // switch (a) { + // default: {} + // case 5 + 6, 7+9, 2*4: {} + // } + auto* var = Var("a", Expr(2_u)); + + auto* block = Block(Decl(var), // + Switch("a", // + DefaultCase(Source{{12, 34}}), // + Case(utils::Vector{Add(5_u, 6_u), Add(7_u, 9_u), Mul(2_u, 4_u)}))); + WrapInFunction(block); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) { // type MyInt = u32; // var v: MyInt; @@ -349,5 +435,85 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } +TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelector_Expression_Fail) { + // var a : i32 = 2i; + // switch (a) { + // case 10i: {} + // case 5i+5i: {} + // default: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = Block(Decl(var), // + Switch("a", // + Case(Expr(Source{{12, 34}}, 10_i)), + Case(Add(Source{{56, 78}}, 5_i, 5_i)), DefaultCase())); + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "56:78 error: duplicate switch case '10'\n" + "12:34 note: previous case declared here"); +} + +TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSameCase_BothExpression_Fail) { + // var a : i32 = 2i; + // switch (a) { + // case 5i+5i, 6i+4i: {} + // default: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = Block(Decl(var), // + Switch("a", // + Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), + Add(Source{{12, 34}}, 6_i, 4_i)}), + DefaultCase())); + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: duplicate switch case '10'\n" + "56:78 note: previous case declared here"); +} + +TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSame_Case_Expression_Fail) { + // var a : i32 = 2i; + // switch (a) { + // case 5u+5u, 10i: {} + // default: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + + auto* block = Block( + Decl(var), // + Switch("a", // + Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), Expr(Source{{12, 34}}, 10_i)}), + DefaultCase())); + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: duplicate switch case '10'\n" + "56:78 note: previous case declared here"); +} + +TEST_F(ResolverControlBlockValidationTest, Switch_OverrideCondition_Fail) { + // override a : i32 = 2; + // switch (a) { + // default: {} + // } + auto* var = Var("a", ty.i32(), Expr(2_i)); + Override("b", ty.i32(), Expr(2_i)); + + auto* block = Block(Decl(var), // + Switch("a", // + Case(Expr(Source{{12, 34}}, "b")), DefaultCase())); + WrapInFunction(block); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: case selector must be a constant expression"); +} + } // namespace } // namespace tint::resolver diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index c2524732b4..b2961bdfb4 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1234,18 +1234,34 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) { }); } -sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) { +sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt, const sem::Type* ty) { auto* sem = builder_->create(stmt, current_compound_statement_, current_function_); return StatementScope(stmt, sem, [&] { sem->Selectors().reserve(stmt->selectors.Length()); for (auto* sel : stmt->selectors) { - auto* expr = Expression(sel); - if (!expr) { + ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"}; + TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); + + // The sem statement is created in the switch when attempting to determine the common + // type. + auto* materialized = Materialize(sem_.Get(sel), ty); + if (!materialized) { return false; } - sem->Selectors().emplace_back(expr); + if (!materialized->Type()->IsAnyOf()) { + AddError("case selector must be an i32 or u32 value", sel->source); + return false; + } + auto const_value = materialized->ConstantValue(); + if (!const_value) { + AddError("case selector must be a constant expression", sel->source); + return false; + } + + sem->Selectors().emplace_back(const_value); } + Mark(stmt->body); auto* body = BlockStatement(stmt->body); if (!body) { @@ -3082,27 +3098,16 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt auto* cond_ty = cond->Type()->UnwrapRef(); - utils::Vector types; - types.Push(cond_ty); - - utils::Vector cases; - cases.Reserve(stmt->body.Length()); - for (auto* case_stmt : stmt->body) { - Mark(case_stmt); - auto* c = CaseStatement(case_stmt); - if (!c) { - return false; - } - for (auto* expr : c->Selectors()) { - types.Push(expr->Type()->UnwrapRef()); - } - cases.Push(c); - behaviors.Add(c->Behaviors()); - sem->Cases().emplace_back(c); - } - // Determine the common type across all selectors and the switch expression // This must materialize to an integer scalar (non-abstract). + utils::Vector types; + types.Push(cond_ty); + for (auto* case_stmt : stmt->body) { + for (auto* expr : case_stmt->selectors) { + auto* sem_expr = Expression(expr); + types.Push(sem_expr->Type()->UnwrapRef()); + } + } auto* common_ty = sem::Type::Common(types); if (!common_ty || !common_ty->is_integer_scalar()) { // No common type found or the common type was abstract. @@ -3113,13 +3118,21 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt if (!cond) { return false; } - for (auto* c : cases) { - for (auto*& sel : c->Selectors()) { // Note: pointer reference - sel = Materialize(sel, common_ty); - if (!sel) { - return false; - } + + utils::Vector cases; + cases.Reserve(stmt->body.Length()); + for (auto* case_stmt : stmt->body) { + Mark(case_stmt); + auto* c = CaseStatement(case_stmt, common_ty); + if (!c) { + return false; } + for (auto* expr : c->Selectors()) { + types.Push(expr->Type()->UnwrapRef()); + } + cases.Push(c); + behaviors.Add(c->Behaviors()); + sem->Cases().emplace_back(c); } if (behaviors.Contains(sem::Behavior::kBreak)) { diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index c25b48e13a..e38d62d376 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -209,7 +209,7 @@ class Resolver { sem::BlockStatement* BlockStatement(const ast::BlockStatement*); sem::Statement* BreakStatement(const ast::BreakStatement*); sem::Statement* CallStatement(const ast::CallStatement*); - sem::CaseStatement* CaseStatement(const ast::CaseStatement*); + sem::CaseStatement* CaseStatement(const ast::CaseStatement*, const sem::Type*); sem::Statement* CompoundAssignmentStatement(const ast::CompoundAssignmentStatement*); sem::Statement* ContinueStatement(const ast::ContinueStatement*); sem::Statement* DiscardStatement(const ast::DiscardStatement*); diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index 5e4494c75b..4538bf4894 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -132,8 +132,6 @@ TEST_F(ResolverTest, Stmt_Case) { ASSERT_EQ(sem->Cases().size(), 2u); EXPECT_EQ(sem->Cases()[0]->Declaration(), cse); ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u); - EXPECT_EQ(sem->Cases()[0]->Selectors()[0]->Declaration(), sel); - EXPECT_EQ(sem->Cases()[1]->Declaration(), def); EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 0u); } diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index e14cdd6fb4..f7382743bd 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -2338,22 +2338,36 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) { has_default = true; } - for (auto* selector : case_stmt->selectors) { - if (cond_ty != sem_.TypeOf(selector)) { + auto* case_sem = sem_.Get(case_stmt); + + auto& case_selectors = case_stmt->selectors; + auto& selector_values = case_sem->Selectors(); + TINT_ASSERT(Resolver, case_selectors.Length() == selector_values.size()); + for (size_t i = 0; i < case_sem->Selectors().size(); ++i) { + auto* selector = selector_values[i]; + if (cond_ty != selector->Type()) { AddError( "the case selector values must have the same type as the selector expression.", - case_stmt->source); + case_selectors[i]->source); return false; } - auto it = selectors.find(selector->value); + auto value = selector->As(); + auto it = selectors.find(value); if (it != selectors.end()) { - auto val = std::to_string(selector->value); - AddError("duplicate switch case '" + val + "'", selector->source); + std::string err = "duplicate switch case '"; + if (selector->Type()->Is()) { + err += std::to_string(selector->As()); + } else { + err += std::to_string(value); + } + err += "'"; + + AddError(err, case_selectors[i]->source); AddNote("previous case declared here", it->second); return false; } - selectors.emplace(selector->value, selector->source); + selectors.emplace(value, case_selectors[i]->source); } } diff --git a/src/tint/sem/switch_statement.h b/src/tint/sem/switch_statement.h index a6b5c00f92..7028c052e6 100644 --- a/src/tint/sem/switch_statement.h +++ b/src/tint/sem/switch_statement.h @@ -26,6 +26,7 @@ class SwitchStatement; } // namespace tint::ast namespace tint::sem { class CaseStatement; +class Constant; class Expression; } // namespace tint::sem @@ -82,14 +83,14 @@ class CaseStatement final : public Castable { const BlockStatement* Body() const { return body_; } /// @returns the selectors for the case - std::vector& Selectors() { return selectors_; } + std::vector& Selectors() { return selectors_; } /// @returns the selectors for the case - const std::vector& Selectors() const { return selectors_; } + const std::vector& Selectors() const { return selectors_; } private: const BlockStatement* body_ = nullptr; - std::vector selectors_; + std::vector selectors_; }; } // namespace tint::sem diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index 49da7d6686..cf7567818a 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -43,6 +43,7 @@ #include "src/tint/sem/statement.h" #include "src/tint/sem/storage_texture.h" #include "src/tint/sem/struct.h" +#include "src/tint/sem/switch_statement.h" #include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_conversion.h" #include "src/tint/sem/variable.h" @@ -1689,14 +1690,15 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) { if (stmt->IsDefault()) { line() << "default: {"; } else { - for (auto* selector : stmt->selectors) { + auto* sem = builder_.Sem().Get(stmt); + for (auto* selector : sem->Selectors()) { auto out = line(); out << "case "; - if (!EmitLiteral(out, selector)) { + if (!EmitConstant(out, selector)) { return false; } out << ":"; - if (selector == stmt->selectors.Back()) { + if (selector == sem->Selectors().back()) { out << " {"; } } diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 439f0eeb19..7bdb8e95f3 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -44,6 +44,7 @@ #include "src/tint/sem/statement.h" #include "src/tint/sem/storage_texture.h" #include "src/tint/sem/struct.h" +#include "src/tint/sem/switch_statement.h" #include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_conversion.h" #include "src/tint/sem/variable.h" @@ -2564,14 +2565,15 @@ bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) { if (stmt->IsDefault()) { line() << "default: {"; } else { - for (auto* selector : stmt->selectors) { + auto* sem = builder_.Sem().Get(stmt); + for (auto* selector : sem->Selectors()) { auto out = line(); out << "case "; - if (!EmitLiteral(out, selector)) { + if (!EmitConstant(out, selector)) { return false; } out << ":"; - if (selector == stmt->selectors.Back()) { + if (selector == sem->Selectors().back()) { out << " {"; } } diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 081444f26d..f765d93960 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -52,6 +52,7 @@ #include "src/tint/sem/sampled_texture.h" #include "src/tint/sem/storage_texture.h" #include "src/tint/sem/struct.h" +#include "src/tint/sem/switch_statement.h" #include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_conversion.h" #include "src/tint/sem/u32.h" @@ -1591,14 +1592,15 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) { if (stmt->IsDefault()) { line() << "default: {"; } else { - for (auto* selector : stmt->selectors) { + auto* sem = builder_.Sem().Get(stmt); + for (auto* selector : sem->Selectors()) { auto out = line(); out << "case "; - if (!EmitLiteral(out, selector)) { + if (!EmitConstant(out, selector)) { return false; } out << ":"; - if (selector == stmt->selectors.Back()) { + if (selector == sem->Selectors().back()) { out << " {"; } } diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index 96edf6c7b5..b72de6782f 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -39,6 +39,7 @@ #include "src/tint/sem/sampled_texture.h" #include "src/tint/sem/statement.h" #include "src/tint/sem/struct.h" +#include "src/tint/sem/switch_statement.h" #include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_conversion.h" #include "src/tint/sem/variable.h" @@ -3464,14 +3465,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { auto block_id = std::get(block); case_ids.push_back(block_id); - for (auto* selector : item->selectors) { - auto* int_literal = selector->As(); - if (!int_literal) { - error_ = "expected integer literal for switch case label"; - return false; - } - params.push_back(Operand(static_cast(int_literal->value))); + auto* sem = builder_.Sem().Get(item); + for (auto* selector : sem->Selectors()) { + params.push_back(Operand(selector->As())); params.push_back(Operand(block_id)); } } diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index 7795455624..7183c74dc3 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc @@ -50,6 +50,7 @@ #include "src/tint/ast/void.h" #include "src/tint/ast/workgroup_attribute.h" #include "src/tint/sem/struct.h" +#include "src/tint/sem/switch_statement.h" #include "src/tint/utils/math.h" #include "src/tint/utils/scoped_assignment.h" #include "src/tint/writer/float_to_string.h" @@ -1030,13 +1031,13 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) { out << "case "; bool first = true; - for (auto* selector : stmt->selectors) { + for (auto* expr : stmt->selectors) { if (!first) { out << ", "; } first = false; - if (!EmitLiteral(out, selector)) { + if (!EmitExpression(out, expr)) { return false; } }