Store expressions in switch case statements.

This CL moves switch case statements to store Expression instead
of an IntLiteralExpression. The SEM is updated to store the
materialized constant instead of accessing the expression value
directly.

Bug: tint:1633
Change-Id: Id79dabb806be1049f775299732bc1c7b1bf0c05f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/106300
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-10-19 00:43:41 +00:00 committed by Dawn LUCI CQ
parent 00aa7ef462
commit d32fbe07e7
19 changed files with 360 additions and 120 deletions

View File

@ -25,7 +25,7 @@ namespace tint::ast {
CaseStatement::CaseStatement(ProgramID pid,
NodeID nid,
const Source& src,
utils::VectorRef<const IntLiteralExpression*> s,
utils::VectorRef<const Expression*> s,
const BlockStatement* b)
: Base(pid, nid, src), selectors(std::move(s)), body(b) {
TINT_ASSERT(AST, body);

View File

@ -18,7 +18,7 @@
#include <vector>
#include "src/tint/ast/block_statement.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/ast/expression.h"
namespace tint::ast {
@ -34,7 +34,7 @@ class CaseStatement final : public Castable<CaseStatement, Statement> {
CaseStatement(ProgramID pid,
NodeID nid,
const Source& src,
utils::VectorRef<const IntLiteralExpression*> selectors,
utils::VectorRef<const Expression*> selectors,
const BlockStatement* body);
/// Move constructor
CaseStatement(CaseStatement&&);
@ -50,7 +50,7 @@ class CaseStatement final : public Castable<CaseStatement, Statement> {
const CaseStatement* Clone(CloneContext* ctx) const override;
/// The case selectors, empty if none set
const utils::Vector<const IntLiteralExpression*, 4> selectors;
const utils::Vector<const Expression*, 4> selectors;
/// The case body
const BlockStatement* const body;

View File

@ -54,6 +54,7 @@
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/ast/index_accessor_expression.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/let.h"
@ -2846,7 +2847,7 @@ class ProgramBuilder {
/// @param body the case body
/// @returns the case statement pointer
const ast::CaseStatement* Case(const Source& source,
utils::VectorRef<const ast::IntLiteralExpression*> selectors,
utils::VectorRef<const ast::Expression*> selectors,
const ast::BlockStatement* body = nullptr) {
return create<ast::CaseStatement>(source, std::move(selectors), body ? body : Block());
}
@ -2855,7 +2856,7 @@ class ProgramBuilder {
/// @param selectors list of selectors
/// @param body the case body
/// @returns the case statement pointer
const ast::CaseStatement* Case(utils::VectorRef<const ast::IntLiteralExpression*> selectors,
const ast::CaseStatement* Case(utils::VectorRef<const ast::Expression*> selectors,
const ast::BlockStatement* body = nullptr) {
return create<ast::CaseStatement>(std::move(selectors), body ? body : Block());
}
@ -2864,7 +2865,7 @@ class ProgramBuilder {
/// @param selector a single case selector
/// @param body the case body
/// @returns the case statement pointer
const ast::CaseStatement* Case(const ast::IntLiteralExpression* selector,
const ast::CaseStatement* Case(const ast::Expression* selector,
const ast::BlockStatement* body = nullptr) {
return Case(utils::Vector{selector}, body);
}

View File

@ -3024,7 +3024,7 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
for (size_t i = last_clause_index;; --i) {
// Create a list of integer literals for the selector values leading to
// this case clause.
utils::Vector<const ast::IntLiteralExpression*, 4> selectors;
utils::Vector<const ast::Expression*, 4> selectors;
const bool has_selectors = clause_heads[i]->case_values.has_value();
if (has_selectors) {
auto values = clause_heads[i]->case_values.value();

View File

@ -2148,21 +2148,19 @@ Maybe<const ast::CaseStatement*> ParserImpl::switch_body() {
}
// case_selectors
// : const_literal (COMMA const_literal)* COMMA?
// : expression (COMMA expression)* COMMA?
Expect<ParserImpl::CaseSelectorList> ParserImpl::expect_case_selectors() {
CaseSelectorList selectors;
while (continue_parsing()) {
auto cond = const_literal();
if (cond.errored) {
auto expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!cond.matched) {
break;
} else if (!cond->Is<ast::IntLiteralExpression>()) {
return add_error(cond.value->source, "invalid case selector must be an integer value");
}
selectors.Push(cond.value->As<ast::IntLiteralExpression>());
if (!expr.matched) {
break;
}
selectors.Push(expr.value);
if (!match(Token::Type::kComma)) {
break;

View File

@ -74,7 +74,7 @@ class ParserImpl {
/// Pre-determined small vector sizes for AST pointers
//! @cond Doxygen_Suppress
using AttributeList = utils::Vector<const ast::Attribute*, 4>;
using CaseSelectorList = utils::Vector<const ast::IntLiteralExpression*, 4>;
using CaseSelectorList = utils::Vector<const ast::Expression*, 4>;
using CaseStatementList = utils::Vector<const ast::CaseStatement*, 4>;
using ExpressionList = utils::Vector<const ast::Expression*, 8>;
using ParameterList = utils::Vector<const ast::Parameter*, 8>;

View File

@ -1339,14 +1339,6 @@ fn f() { switch(1) { case ^: } }
)");
}
TEST_F(ParserImplErrorTest, SwitchStmtInvalidCase2) {
EXPECT("fn f() { switch(1) { case false: } }",
R"(test.wgsl:1:27 error: invalid case selector must be an integer value
fn f() { switch(1) { case false: } }
^^^^^
)");
}
TEST_F(ParserImplErrorTest, SwitchStmtCaseMissingLBrace) {
EXPECT("fn f() { switch(1) { case 1: } }",
R"(test.wgsl:1:30 error: expected '{' for case statement

View File

@ -26,10 +26,42 @@ TEST_F(ParserImplTest, SwitchBody_Case) {
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
EXPECT_EQ(stmt->selectors[0]->value, 1);
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, SwitchBody_Case_Expression) {
auto p = parser("case 1 + 2 { a = 4; }");
auto e = p->switch_body();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
ASSERT_TRUE(stmt->selectors[0]->Is<ast::BinaryExpression>());
auto* expr = stmt->selectors[0]->As<ast::BinaryExpression>();
EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
auto* v = expr->lhs->As<ast::IntLiteralExpression>();
ASSERT_NE(nullptr, v);
EXPECT_EQ(v->value, 1u);
v = expr->rhs->As<ast::IntLiteralExpression>();
ASSERT_NE(nullptr, v);
EXPECT_EQ(v->value, 2u);
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
@ -43,10 +75,14 @@ TEST_F(ParserImplTest, SwitchBody_Case_WithColon) {
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
EXPECT_EQ(stmt->selectors[0]->value, 1);
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
@ -62,9 +98,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma) {
EXPECT_FALSE(e->IsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 2u);
EXPECT_EQ(stmt->selectors[0]->value, 1);
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(stmt->selectors[1]->value, 2);
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) {
@ -76,15 +119,23 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) {
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 2u);
EXPECT_EQ(stmt->selectors[0]->value, 1);
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(stmt->selectors[1]->value, 2);
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
auto p = parser("case a == 4: { a = 4; }");
TEST_F(ParserImplTest, SwitchBody_Case_Invalid) {
auto p = parser("case if: { a = 4; }");
auto e = p->switch_body();
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(e.errored);
@ -93,16 +144,6 @@ TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
EXPECT_EQ(p->error(), "1:6: unable to parse case selectors");
}
TEST_F(ParserImplTest, SwitchBody_Case_InvalidSelector_bool) {
auto p = parser("case true: { a = 4; }");
auto e = p->switch_body();
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
EXPECT_EQ(e.value, nullptr);
EXPECT_EQ(p->error(), "1:6: invalid case selector must be an integer value");
}
TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) {
auto p = parser("case: { a = 4; }");
auto e = p->switch_body();
@ -164,10 +205,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors) {
EXPECT_FALSE(e->IsDefault());
ASSERT_EQ(e->body->statements.Length(), 0u);
ASSERT_EQ(e->selectors.Length(), 2u);
ASSERT_EQ(e->selectors[0]->value, 1);
EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->selectors[1]->value, 2);
EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
expr = e->selectors[1]->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) {
@ -181,10 +228,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) {
EXPECT_FALSE(e->IsDefault());
ASSERT_EQ(e->body->statements.Length(), 0u);
ASSERT_EQ(e->selectors.Length(), 2u);
ASSERT_EQ(e->selectors[0]->value, 1);
EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->selectors[1]->value, 2);
EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
expr = e->selectors[1]->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectorsMissingComma) {

View File

@ -25,7 +25,7 @@ namespace {
class ResolverControlBlockValidationTest : public TestHelper, public testing::Test {};
TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerType_Fail) {
TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_F32) {
// var a : f32 = 3.14;
// switch (a) {
// default: {}
@ -43,6 +43,24 @@ TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerTy
"scalar integer type");
}
TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_bool) {
// var a : bool = true;
// switch (a) {
// default: {}
// }
auto* var = Var("a", ty.bool_(), Expr(false));
auto* block = Block(Decl(var), Switch(Expr(Source{{12, 34}}, "a"), //
DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: switch statement selector expression must be of a "
"scalar integer type");
}
TEST_F(ResolverControlBlockValidationTest, SwitchWithoutDefault_Fail) {
// var a : i32 = 2;
// switch (a) {
@ -214,7 +232,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT
auto* var = Var("a", ty.i32(), Expr(2_i));
auto* block = Block(Decl(var), Switch("a", //
Case(Source{{12, 34}}, utils::Vector{Expr(1_u)}), //
Case(Expr(Source{{12, 34}}, 1_u)), //
DefaultCase()));
WrapInFunction(block);
@ -234,7 +252,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT
auto* block = Block(Decl(var), //
Switch("a", //
Case(Source{{12, 34}}, utils::Vector{Expr(-1_i)}), //
Case(utils::Vector{Expr(Source{{12, 34}}, -1_i)}), //
DefaultCase()));
WrapInFunction(block);
@ -332,6 +350,74 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Pass) {
// var a : i32 = 2;
// switch (a) {
// default: {}
// case 5 + 6: {}
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
Case(Add(5_i, 6_i))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixI32_Abstract) {
// var a = 2;
// switch (a) {
// default: {}
// case 5i + 6i: {}
// }
auto* var = Var("a", Expr(2_a));
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
Case(Add(5_i, 6_i))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixU32_Abstract) {
// var a = 2u;
// switch (a) {
// default: {}
// case 5 + 6: {}
// }
auto* var = Var("a", Expr(2_u));
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
Case(Add(5_a, 6_a))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Multiple) {
// var a = 2u;
// switch (a) {
// default: {}
// case 5 + 6, 7+9, 2*4: {}
// }
auto* var = Var("a", Expr(2_u));
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
Case(utils::Vector{Add(5_u, 6_u), Add(7_u, 9_u), Mul(2_u, 4_u)})));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) {
// type MyInt = u32;
// var v: MyInt;
@ -349,5 +435,85 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelector_Expression_Fail) {
// var a : i32 = 2i;
// switch (a) {
// case 10i: {}
// case 5i+5i: {}
// default: {}
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
auto* block = Block(Decl(var), //
Switch("a", //
Case(Expr(Source{{12, 34}}, 10_i)),
Case(Add(Source{{56, 78}}, 5_i, 5_i)), DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"56:78 error: duplicate switch case '10'\n"
"12:34 note: previous case declared here");
}
TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSameCase_BothExpression_Fail) {
// var a : i32 = 2i;
// switch (a) {
// case 5i+5i, 6i+4i: {}
// default: {}
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
auto* block = Block(Decl(var), //
Switch("a", //
Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i),
Add(Source{{12, 34}}, 6_i, 4_i)}),
DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: duplicate switch case '10'\n"
"56:78 note: previous case declared here");
}
TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSame_Case_Expression_Fail) {
// var a : i32 = 2i;
// switch (a) {
// case 5u+5u, 10i: {}
// default: {}
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
auto* block = Block(
Decl(var), //
Switch("a", //
Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), Expr(Source{{12, 34}}, 10_i)}),
DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: duplicate switch case '10'\n"
"56:78 note: previous case declared here");
}
TEST_F(ResolverControlBlockValidationTest, Switch_OverrideCondition_Fail) {
// override a : i32 = 2;
// switch (a) {
// default: {}
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
Override("b", ty.i32(), Expr(2_i));
auto* block = Block(Decl(var), //
Switch("a", //
Case(Expr(Source{{12, 34}}, "b")), DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: case selector must be a constant expression");
}
} // namespace
} // namespace tint::resolver

View File

@ -1234,18 +1234,34 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
});
}
sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt, const sem::Type* ty) {
auto* sem =
builder_->create<sem::CaseStatement>(stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
sem->Selectors().reserve(stmt->selectors.Length());
for (auto* sel : stmt->selectors) {
auto* expr = Expression(sel);
if (!expr) {
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
// The sem statement is created in the switch when attempting to determine the common
// type.
auto* materialized = Materialize(sem_.Get(sel), ty);
if (!materialized) {
return false;
}
sem->Selectors().emplace_back(expr);
if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
AddError("case selector must be an i32 or u32 value", sel->source);
return false;
}
auto const_value = materialized->ConstantValue();
if (!const_value) {
AddError("case selector must be a constant expression", sel->source);
return false;
}
sem->Selectors().emplace_back(const_value);
}
Mark(stmt->body);
auto* body = BlockStatement(stmt->body);
if (!body) {
@ -3082,27 +3098,16 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt
auto* cond_ty = cond->Type()->UnwrapRef();
utils::Vector<const sem::Type*, 8> types;
types.Push(cond_ty);
utils::Vector<sem::CaseStatement*, 4> cases;
cases.Reserve(stmt->body.Length());
for (auto* case_stmt : stmt->body) {
Mark(case_stmt);
auto* c = CaseStatement(case_stmt);
if (!c) {
return false;
}
for (auto* expr : c->Selectors()) {
types.Push(expr->Type()->UnwrapRef());
}
cases.Push(c);
behaviors.Add(c->Behaviors());
sem->Cases().emplace_back(c);
}
// Determine the common type across all selectors and the switch expression
// This must materialize to an integer scalar (non-abstract).
utils::Vector<const sem::Type*, 8> types;
types.Push(cond_ty);
for (auto* case_stmt : stmt->body) {
for (auto* expr : case_stmt->selectors) {
auto* sem_expr = Expression(expr);
types.Push(sem_expr->Type()->UnwrapRef());
}
}
auto* common_ty = sem::Type::Common(types);
if (!common_ty || !common_ty->is_integer_scalar()) {
// No common type found or the common type was abstract.
@ -3113,13 +3118,21 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt
if (!cond) {
return false;
}
for (auto* c : cases) {
for (auto*& sel : c->Selectors()) { // Note: pointer reference
sel = Materialize(sel, common_ty);
if (!sel) {
utils::Vector<sem::CaseStatement*, 4> cases;
cases.Reserve(stmt->body.Length());
for (auto* case_stmt : stmt->body) {
Mark(case_stmt);
auto* c = CaseStatement(case_stmt, common_ty);
if (!c) {
return false;
}
for (auto* expr : c->Selectors()) {
types.Push(expr->Type()->UnwrapRef());
}
cases.Push(c);
behaviors.Add(c->Behaviors());
sem->Cases().emplace_back(c);
}
if (behaviors.Contains(sem::Behavior::kBreak)) {

View File

@ -209,7 +209,7 @@ class Resolver {
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*);
sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
sem::CaseStatement* CaseStatement(const ast::CaseStatement*, const sem::Type*);
sem::Statement* CompoundAssignmentStatement(const ast::CompoundAssignmentStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*);

View File

@ -132,8 +132,6 @@ TEST_F(ResolverTest, Stmt_Case) {
ASSERT_EQ(sem->Cases().size(), 2u);
EXPECT_EQ(sem->Cases()[0]->Declaration(), cse);
ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u);
EXPECT_EQ(sem->Cases()[0]->Selectors()[0]->Declaration(), sel);
EXPECT_EQ(sem->Cases()[1]->Declaration(), def);
EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 0u);
}

View File

@ -2338,22 +2338,36 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
has_default = true;
}
for (auto* selector : case_stmt->selectors) {
if (cond_ty != sem_.TypeOf(selector)) {
auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
auto& case_selectors = case_stmt->selectors;
auto& selector_values = case_sem->Selectors();
TINT_ASSERT(Resolver, case_selectors.Length() == selector_values.size());
for (size_t i = 0; i < case_sem->Selectors().size(); ++i) {
auto* selector = selector_values[i];
if (cond_ty != selector->Type()) {
AddError(
"the case selector values must have the same type as the selector expression.",
case_stmt->source);
case_selectors[i]->source);
return false;
}
auto it = selectors.find(selector->value);
auto value = selector->As<uint32_t>();
auto it = selectors.find(value);
if (it != selectors.end()) {
auto val = std::to_string(selector->value);
AddError("duplicate switch case '" + val + "'", selector->source);
std::string err = "duplicate switch case '";
if (selector->Type()->Is<sem::I32>()) {
err += std::to_string(selector->As<int32_t>());
} else {
err += std::to_string(value);
}
err += "'";
AddError(err, case_selectors[i]->source);
AddNote("previous case declared here", it->second);
return false;
}
selectors.emplace(selector->value, selector->source);
selectors.emplace(value, case_selectors[i]->source);
}
}

View File

@ -26,6 +26,7 @@ class SwitchStatement;
} // namespace tint::ast
namespace tint::sem {
class CaseStatement;
class Constant;
class Expression;
} // namespace tint::sem
@ -82,14 +83,14 @@ class CaseStatement final : public Castable<CaseStatement, CompoundStatement> {
const BlockStatement* Body() const { return body_; }
/// @returns the selectors for the case
std::vector<const Expression*>& Selectors() { return selectors_; }
std::vector<const Constant*>& Selectors() { return selectors_; }
/// @returns the selectors for the case
const std::vector<const Expression*>& Selectors() const { return selectors_; }
const std::vector<const Constant*>& Selectors() const { return selectors_; }
private:
const BlockStatement* body_ = nullptr;
std::vector<const Expression*> selectors_;
std::vector<const Constant*> selectors_;
};
} // namespace tint::sem

View File

@ -43,6 +43,7 @@
#include "src/tint/sem/statement.h"
#include "src/tint/sem/storage_texture.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
@ -1689,14 +1690,15 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
if (stmt->IsDefault()) {
line() << "default: {";
} else {
for (auto* selector : stmt->selectors) {
auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
for (auto* selector : sem->Selectors()) {
auto out = line();
out << "case ";
if (!EmitLiteral(out, selector)) {
if (!EmitConstant(out, selector)) {
return false;
}
out << ":";
if (selector == stmt->selectors.Back()) {
if (selector == sem->Selectors().back()) {
out << " {";
}
}

View File

@ -44,6 +44,7 @@
#include "src/tint/sem/statement.h"
#include "src/tint/sem/storage_texture.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
@ -2564,14 +2565,15 @@ bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) {
if (stmt->IsDefault()) {
line() << "default: {";
} else {
for (auto* selector : stmt->selectors) {
auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
for (auto* selector : sem->Selectors()) {
auto out = line();
out << "case ";
if (!EmitLiteral(out, selector)) {
if (!EmitConstant(out, selector)) {
return false;
}
out << ":";
if (selector == stmt->selectors.Back()) {
if (selector == sem->Selectors().back()) {
out << " {";
}
}

View File

@ -52,6 +52,7 @@
#include "src/tint/sem/sampled_texture.h"
#include "src/tint/sem/storage_texture.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/u32.h"
@ -1591,14 +1592,15 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
if (stmt->IsDefault()) {
line() << "default: {";
} else {
for (auto* selector : stmt->selectors) {
auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
for (auto* selector : sem->Selectors()) {
auto out = line();
out << "case ";
if (!EmitLiteral(out, selector)) {
if (!EmitConstant(out, selector)) {
return false;
}
out << ":";
if (selector == stmt->selectors.Back()) {
if (selector == sem->Selectors().back()) {
out << " {";
}
}

View File

@ -39,6 +39,7 @@
#include "src/tint/sem/sampled_texture.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
@ -3464,14 +3465,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
auto block_id = std::get<uint32_t>(block);
case_ids.push_back(block_id);
for (auto* selector : item->selectors) {
auto* int_literal = selector->As<ast::IntLiteralExpression>();
if (!int_literal) {
error_ = "expected integer literal for switch case label";
return false;
}
params.push_back(Operand(static_cast<uint32_t>(int_literal->value)));
auto* sem = builder_.Sem().Get<sem::CaseStatement>(item);
for (auto* selector : sem->Selectors()) {
params.push_back(Operand(selector->As<uint32_t>()));
params.push_back(Operand(block_id));
}
}

View File

@ -50,6 +50,7 @@
#include "src/tint/ast/void.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/utils/math.h"
#include "src/tint/utils/scoped_assignment.h"
#include "src/tint/writer/float_to_string.h"
@ -1030,13 +1031,13 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
out << "case ";
bool first = true;
for (auto* selector : stmt->selectors) {
for (auto* expr : stmt->selectors) {
if (!first) {
out << ", ";
}
first = false;
if (!EmitLiteral(out, selector)) {
if (!EmitExpression(out, expr)) {
return false;
}
}