Store expressions in switch case statements.
This CL moves switch case statements to store Expression instead of an IntLiteralExpression. The SEM is updated to store the materialized constant instead of accessing the expression value directly. Bug: tint:1633 Change-Id: Id79dabb806be1049f775299732bc1c7b1bf0c05f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/106300 Commit-Queue: Dan Sinclair <dsinclair@chromium.org> Reviewed-by: Ben Clayton <bclayton@google.com> Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
00aa7ef462
commit
d32fbe07e7
|
@ -25,7 +25,7 @@ namespace tint::ast {
|
|||
CaseStatement::CaseStatement(ProgramID pid,
|
||||
NodeID nid,
|
||||
const Source& src,
|
||||
utils::VectorRef<const IntLiteralExpression*> s,
|
||||
utils::VectorRef<const Expression*> s,
|
||||
const BlockStatement* b)
|
||||
: Base(pid, nid, src), selectors(std::move(s)), body(b) {
|
||||
TINT_ASSERT(AST, body);
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/block_statement.h"
|
||||
#include "src/tint/ast/int_literal_expression.h"
|
||||
#include "src/tint/ast/expression.h"
|
||||
|
||||
namespace tint::ast {
|
||||
|
||||
|
@ -34,7 +34,7 @@ class CaseStatement final : public Castable<CaseStatement, Statement> {
|
|||
CaseStatement(ProgramID pid,
|
||||
NodeID nid,
|
||||
const Source& src,
|
||||
utils::VectorRef<const IntLiteralExpression*> selectors,
|
||||
utils::VectorRef<const Expression*> selectors,
|
||||
const BlockStatement* body);
|
||||
/// Move constructor
|
||||
CaseStatement(CaseStatement&&);
|
||||
|
@ -50,7 +50,7 @@ class CaseStatement final : public Castable<CaseStatement, Statement> {
|
|||
const CaseStatement* Clone(CloneContext* ctx) const override;
|
||||
|
||||
/// The case selectors, empty if none set
|
||||
const utils::Vector<const IntLiteralExpression*, 4> selectors;
|
||||
const utils::Vector<const Expression*, 4> selectors;
|
||||
|
||||
/// The case body
|
||||
const BlockStatement* const body;
|
||||
|
|
|
@ -54,6 +54,7 @@
|
|||
#include "src/tint/ast/if_statement.h"
|
||||
#include "src/tint/ast/increment_decrement_statement.h"
|
||||
#include "src/tint/ast/index_accessor_expression.h"
|
||||
#include "src/tint/ast/int_literal_expression.h"
|
||||
#include "src/tint/ast/interpolate_attribute.h"
|
||||
#include "src/tint/ast/invariant_attribute.h"
|
||||
#include "src/tint/ast/let.h"
|
||||
|
@ -2846,7 +2847,7 @@ class ProgramBuilder {
|
|||
/// @param body the case body
|
||||
/// @returns the case statement pointer
|
||||
const ast::CaseStatement* Case(const Source& source,
|
||||
utils::VectorRef<const ast::IntLiteralExpression*> selectors,
|
||||
utils::VectorRef<const ast::Expression*> selectors,
|
||||
const ast::BlockStatement* body = nullptr) {
|
||||
return create<ast::CaseStatement>(source, std::move(selectors), body ? body : Block());
|
||||
}
|
||||
|
@ -2855,7 +2856,7 @@ class ProgramBuilder {
|
|||
/// @param selectors list of selectors
|
||||
/// @param body the case body
|
||||
/// @returns the case statement pointer
|
||||
const ast::CaseStatement* Case(utils::VectorRef<const ast::IntLiteralExpression*> selectors,
|
||||
const ast::CaseStatement* Case(utils::VectorRef<const ast::Expression*> selectors,
|
||||
const ast::BlockStatement* body = nullptr) {
|
||||
return create<ast::CaseStatement>(std::move(selectors), body ? body : Block());
|
||||
}
|
||||
|
@ -2864,7 +2865,7 @@ class ProgramBuilder {
|
|||
/// @param selector a single case selector
|
||||
/// @param body the case body
|
||||
/// @returns the case statement pointer
|
||||
const ast::CaseStatement* Case(const ast::IntLiteralExpression* selector,
|
||||
const ast::CaseStatement* Case(const ast::Expression* selector,
|
||||
const ast::BlockStatement* body = nullptr) {
|
||||
return Case(utils::Vector{selector}, body);
|
||||
}
|
||||
|
|
|
@ -3024,7 +3024,7 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
|
|||
for (size_t i = last_clause_index;; --i) {
|
||||
// Create a list of integer literals for the selector values leading to
|
||||
// this case clause.
|
||||
utils::Vector<const ast::IntLiteralExpression*, 4> selectors;
|
||||
utils::Vector<const ast::Expression*, 4> selectors;
|
||||
const bool has_selectors = clause_heads[i]->case_values.has_value();
|
||||
if (has_selectors) {
|
||||
auto values = clause_heads[i]->case_values.value();
|
||||
|
|
|
@ -2148,21 +2148,19 @@ Maybe<const ast::CaseStatement*> ParserImpl::switch_body() {
|
|||
}
|
||||
|
||||
// case_selectors
|
||||
// : const_literal (COMMA const_literal)* COMMA?
|
||||
// : expression (COMMA expression)* COMMA?
|
||||
Expect<ParserImpl::CaseSelectorList> ParserImpl::expect_case_selectors() {
|
||||
CaseSelectorList selectors;
|
||||
|
||||
while (continue_parsing()) {
|
||||
auto cond = const_literal();
|
||||
if (cond.errored) {
|
||||
auto expr = expression();
|
||||
if (expr.errored) {
|
||||
return Failure::kErrored;
|
||||
} else if (!cond.matched) {
|
||||
break;
|
||||
} else if (!cond->Is<ast::IntLiteralExpression>()) {
|
||||
return add_error(cond.value->source, "invalid case selector must be an integer value");
|
||||
}
|
||||
|
||||
selectors.Push(cond.value->As<ast::IntLiteralExpression>());
|
||||
if (!expr.matched) {
|
||||
break;
|
||||
}
|
||||
selectors.Push(expr.value);
|
||||
|
||||
if (!match(Token::Type::kComma)) {
|
||||
break;
|
||||
|
|
|
@ -74,7 +74,7 @@ class ParserImpl {
|
|||
/// Pre-determined small vector sizes for AST pointers
|
||||
//! @cond Doxygen_Suppress
|
||||
using AttributeList = utils::Vector<const ast::Attribute*, 4>;
|
||||
using CaseSelectorList = utils::Vector<const ast::IntLiteralExpression*, 4>;
|
||||
using CaseSelectorList = utils::Vector<const ast::Expression*, 4>;
|
||||
using CaseStatementList = utils::Vector<const ast::CaseStatement*, 4>;
|
||||
using ExpressionList = utils::Vector<const ast::Expression*, 8>;
|
||||
using ParameterList = utils::Vector<const ast::Parameter*, 8>;
|
||||
|
|
|
@ -1339,14 +1339,6 @@ fn f() { switch(1) { case ^: } }
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(ParserImplErrorTest, SwitchStmtInvalidCase2) {
|
||||
EXPECT("fn f() { switch(1) { case false: } }",
|
||||
R"(test.wgsl:1:27 error: invalid case selector must be an integer value
|
||||
fn f() { switch(1) { case false: } }
|
||||
^^^^^
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(ParserImplErrorTest, SwitchStmtCaseMissingLBrace) {
|
||||
EXPECT("fn f() { switch(1) { case 1: } }",
|
||||
R"(test.wgsl:1:30 error: expected '{' for case statement
|
||||
|
|
|
@ -26,10 +26,42 @@ TEST_F(ParserImplTest, SwitchBody_Case) {
|
|||
ASSERT_NE(e.value, nullptr);
|
||||
ASSERT_TRUE(e->Is<ast::CaseStatement>());
|
||||
EXPECT_FALSE(e->IsDefault());
|
||||
|
||||
auto* stmt = e->As<ast::CaseStatement>();
|
||||
ASSERT_EQ(stmt->selectors.Length(), 1u);
|
||||
EXPECT_EQ(stmt->selectors[0]->value, 1);
|
||||
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
|
||||
|
||||
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
|
||||
EXPECT_EQ(expr->value, 1);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_EQ(e->body->statements.Length(), 1u);
|
||||
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_Expression) {
|
||||
auto p = parser("case 1 + 2 { a = 4; }");
|
||||
auto e = p->switch_body();
|
||||
EXPECT_FALSE(p->has_error()) << p->error();
|
||||
EXPECT_TRUE(e.matched);
|
||||
EXPECT_FALSE(e.errored);
|
||||
ASSERT_NE(e.value, nullptr);
|
||||
ASSERT_TRUE(e->Is<ast::CaseStatement>());
|
||||
EXPECT_FALSE(e->IsDefault());
|
||||
|
||||
auto* stmt = e->As<ast::CaseStatement>();
|
||||
ASSERT_EQ(stmt->selectors.Length(), 1u);
|
||||
ASSERT_TRUE(stmt->selectors[0]->Is<ast::BinaryExpression>());
|
||||
auto* expr = stmt->selectors[0]->As<ast::BinaryExpression>();
|
||||
|
||||
EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
|
||||
auto* v = expr->lhs->As<ast::IntLiteralExpression>();
|
||||
ASSERT_NE(nullptr, v);
|
||||
EXPECT_EQ(v->value, 1u);
|
||||
|
||||
v = expr->rhs->As<ast::IntLiteralExpression>();
|
||||
ASSERT_NE(nullptr, v);
|
||||
EXPECT_EQ(v->value, 2u);
|
||||
|
||||
ASSERT_EQ(e->body->statements.Length(), 1u);
|
||||
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
|
||||
}
|
||||
|
@ -43,10 +75,14 @@ TEST_F(ParserImplTest, SwitchBody_Case_WithColon) {
|
|||
ASSERT_NE(e.value, nullptr);
|
||||
ASSERT_TRUE(e->Is<ast::CaseStatement>());
|
||||
EXPECT_FALSE(e->IsDefault());
|
||||
|
||||
auto* stmt = e->As<ast::CaseStatement>();
|
||||
ASSERT_EQ(stmt->selectors.Length(), 1u);
|
||||
EXPECT_EQ(stmt->selectors[0]->value, 1);
|
||||
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
|
||||
|
||||
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
|
||||
EXPECT_EQ(expr->value, 1);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_EQ(e->body->statements.Length(), 1u);
|
||||
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
|
||||
}
|
||||
|
@ -62,9 +98,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma) {
|
|||
EXPECT_FALSE(e->IsDefault());
|
||||
auto* stmt = e->As<ast::CaseStatement>();
|
||||
ASSERT_EQ(stmt->selectors.Length(), 2u);
|
||||
EXPECT_EQ(stmt->selectors[0]->value, 1);
|
||||
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
EXPECT_EQ(stmt->selectors[1]->value, 2);
|
||||
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
|
||||
|
||||
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
|
||||
EXPECT_EQ(expr->value, 1);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
|
||||
expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
|
||||
EXPECT_EQ(expr->value, 2);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) {
|
||||
|
@ -76,15 +119,23 @@ TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) {
|
|||
ASSERT_NE(e.value, nullptr);
|
||||
ASSERT_TRUE(e->Is<ast::CaseStatement>());
|
||||
EXPECT_FALSE(e->IsDefault());
|
||||
|
||||
auto* stmt = e->As<ast::CaseStatement>();
|
||||
ASSERT_EQ(stmt->selectors.Length(), 2u);
|
||||
EXPECT_EQ(stmt->selectors[0]->value, 1);
|
||||
EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
EXPECT_EQ(stmt->selectors[1]->value, 2);
|
||||
ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
|
||||
|
||||
auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
|
||||
EXPECT_EQ(expr->value, 1);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
|
||||
expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
|
||||
EXPECT_EQ(expr->value, 2);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
|
||||
auto p = parser("case a == 4: { a = 4; }");
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_Invalid) {
|
||||
auto p = parser("case if: { a = 4; }");
|
||||
auto e = p->switch_body();
|
||||
EXPECT_TRUE(p->has_error());
|
||||
EXPECT_TRUE(e.errored);
|
||||
|
@ -93,16 +144,6 @@ TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
|
|||
EXPECT_EQ(p->error(), "1:6: unable to parse case selectors");
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_InvalidSelector_bool) {
|
||||
auto p = parser("case true: { a = 4; }");
|
||||
auto e = p->switch_body();
|
||||
EXPECT_TRUE(p->has_error());
|
||||
EXPECT_TRUE(e.errored);
|
||||
EXPECT_FALSE(e.matched);
|
||||
EXPECT_EQ(e.value, nullptr);
|
||||
EXPECT_EQ(p->error(), "1:6: invalid case selector must be an integer value");
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) {
|
||||
auto p = parser("case: { a = 4; }");
|
||||
auto e = p->switch_body();
|
||||
|
@ -164,10 +205,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors) {
|
|||
EXPECT_FALSE(e->IsDefault());
|
||||
ASSERT_EQ(e->body->statements.Length(), 0u);
|
||||
ASSERT_EQ(e->selectors.Length(), 2u);
|
||||
ASSERT_EQ(e->selectors[0]->value, 1);
|
||||
EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_EQ(e->selectors[1]->value, 2);
|
||||
EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
|
||||
|
||||
auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
|
||||
ASSERT_EQ(expr->value, 1);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
|
||||
expr = e->selectors[1]->As<ast::IntLiteralExpression>();
|
||||
ASSERT_EQ(expr->value, 2);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) {
|
||||
|
@ -181,10 +228,16 @@ TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) {
|
|||
EXPECT_FALSE(e->IsDefault());
|
||||
ASSERT_EQ(e->body->statements.Length(), 0u);
|
||||
ASSERT_EQ(e->selectors.Length(), 2u);
|
||||
ASSERT_EQ(e->selectors[0]->value, 1);
|
||||
EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_EQ(e->selectors[1]->value, 2);
|
||||
EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
|
||||
|
||||
auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
|
||||
ASSERT_EQ(expr->value, 1);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
|
||||
ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
|
||||
expr = e->selectors[1]->As<ast::IntLiteralExpression>();
|
||||
ASSERT_EQ(expr->value, 2);
|
||||
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectorsMissingComma) {
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace {
|
|||
|
||||
class ResolverControlBlockValidationTest : public TestHelper, public testing::Test {};
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerType_Fail) {
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_F32) {
|
||||
// var a : f32 = 3.14;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
|
@ -43,6 +43,24 @@ TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerTy
|
|||
"scalar integer type");
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_bool) {
|
||||
// var a : bool = true;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
// }
|
||||
auto* var = Var("a", ty.bool_(), Expr(false));
|
||||
|
||||
auto* block = Block(Decl(var), Switch(Expr(Source{{12, 34}}, "a"), //
|
||||
DefaultCase()));
|
||||
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: switch statement selector expression must be of a "
|
||||
"scalar integer type");
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchWithoutDefault_Fail) {
|
||||
// var a : i32 = 2;
|
||||
// switch (a) {
|
||||
|
@ -214,7 +232,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT
|
|||
auto* var = Var("a", ty.i32(), Expr(2_i));
|
||||
|
||||
auto* block = Block(Decl(var), Switch("a", //
|
||||
Case(Source{{12, 34}}, utils::Vector{Expr(1_u)}), //
|
||||
Case(Expr(Source{{12, 34}}, 1_u)), //
|
||||
DefaultCase()));
|
||||
WrapInFunction(block);
|
||||
|
||||
|
@ -234,7 +252,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchConditionTypeMustMatchSelectorT
|
|||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
Case(Source{{12, 34}}, utils::Vector{Expr(-1_i)}), //
|
||||
Case(utils::Vector{Expr(Source{{12, 34}}, -1_i)}), //
|
||||
DefaultCase()));
|
||||
WrapInFunction(block);
|
||||
|
||||
|
@ -332,6 +350,74 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {
|
|||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Pass) {
|
||||
// var a : i32 = 2;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
// case 5 + 6: {}
|
||||
// }
|
||||
auto* var = Var("a", ty.i32(), Expr(2_i));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
DefaultCase(Source{{12, 34}}), //
|
||||
Case(Add(5_i, 6_i))));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixI32_Abstract) {
|
||||
// var a = 2;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
// case 5i + 6i: {}
|
||||
// }
|
||||
auto* var = Var("a", Expr(2_a));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
DefaultCase(Source{{12, 34}}), //
|
||||
Case(Add(5_i, 6_i))));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixU32_Abstract) {
|
||||
// var a = 2u;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
// case 5 + 6: {}
|
||||
// }
|
||||
auto* var = Var("a", Expr(2_u));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
DefaultCase(Source{{12, 34}}), //
|
||||
Case(Add(5_a, 6_a))));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Multiple) {
|
||||
// var a = 2u;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
// case 5 + 6, 7+9, 2*4: {}
|
||||
// }
|
||||
auto* var = Var("a", Expr(2_u));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
DefaultCase(Source{{12, 34}}), //
|
||||
Case(utils::Vector{Add(5_u, 6_u), Add(7_u, 9_u), Mul(2_u, 4_u)})));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) {
|
||||
// type MyInt = u32;
|
||||
// var v: MyInt;
|
||||
|
@ -349,5 +435,85 @@ TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) {
|
|||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelector_Expression_Fail) {
|
||||
// var a : i32 = 2i;
|
||||
// switch (a) {
|
||||
// case 10i: {}
|
||||
// case 5i+5i: {}
|
||||
// default: {}
|
||||
// }
|
||||
auto* var = Var("a", ty.i32(), Expr(2_i));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
Case(Expr(Source{{12, 34}}, 10_i)),
|
||||
Case(Add(Source{{56, 78}}, 5_i, 5_i)), DefaultCase()));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"56:78 error: duplicate switch case '10'\n"
|
||||
"12:34 note: previous case declared here");
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSameCase_BothExpression_Fail) {
|
||||
// var a : i32 = 2i;
|
||||
// switch (a) {
|
||||
// case 5i+5i, 6i+4i: {}
|
||||
// default: {}
|
||||
// }
|
||||
auto* var = Var("a", ty.i32(), Expr(2_i));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i),
|
||||
Add(Source{{12, 34}}, 6_i, 4_i)}),
|
||||
DefaultCase()));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: duplicate switch case '10'\n"
|
||||
"56:78 note: previous case declared here");
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSame_Case_Expression_Fail) {
|
||||
// var a : i32 = 2i;
|
||||
// switch (a) {
|
||||
// case 5u+5u, 10i: {}
|
||||
// default: {}
|
||||
// }
|
||||
auto* var = Var("a", ty.i32(), Expr(2_i));
|
||||
|
||||
auto* block = Block(
|
||||
Decl(var), //
|
||||
Switch("a", //
|
||||
Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), Expr(Source{{12, 34}}, 10_i)}),
|
||||
DefaultCase()));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: duplicate switch case '10'\n"
|
||||
"56:78 note: previous case declared here");
|
||||
}
|
||||
|
||||
TEST_F(ResolverControlBlockValidationTest, Switch_OverrideCondition_Fail) {
|
||||
// override a : i32 = 2;
|
||||
// switch (a) {
|
||||
// default: {}
|
||||
// }
|
||||
auto* var = Var("a", ty.i32(), Expr(2_i));
|
||||
Override("b", ty.i32(), Expr(2_i));
|
||||
|
||||
auto* block = Block(Decl(var), //
|
||||
Switch("a", //
|
||||
Case(Expr(Source{{12, 34}}, "b")), DefaultCase()));
|
||||
WrapInFunction(block);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: case selector must be a constant expression");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::resolver
|
||||
|
|
|
@ -1234,18 +1234,34 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
|
|||
});
|
||||
}
|
||||
|
||||
sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
|
||||
sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt, const sem::Type* ty) {
|
||||
auto* sem =
|
||||
builder_->create<sem::CaseStatement>(stmt, current_compound_statement_, current_function_);
|
||||
return StatementScope(stmt, sem, [&] {
|
||||
sem->Selectors().reserve(stmt->selectors.Length());
|
||||
for (auto* sel : stmt->selectors) {
|
||||
auto* expr = Expression(sel);
|
||||
if (!expr) {
|
||||
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"};
|
||||
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
|
||||
|
||||
// The sem statement is created in the switch when attempting to determine the common
|
||||
// type.
|
||||
auto* materialized = Materialize(sem_.Get(sel), ty);
|
||||
if (!materialized) {
|
||||
return false;
|
||||
}
|
||||
sem->Selectors().emplace_back(expr);
|
||||
if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
|
||||
AddError("case selector must be an i32 or u32 value", sel->source);
|
||||
return false;
|
||||
}
|
||||
auto const_value = materialized->ConstantValue();
|
||||
if (!const_value) {
|
||||
AddError("case selector must be a constant expression", sel->source);
|
||||
return false;
|
||||
}
|
||||
|
||||
sem->Selectors().emplace_back(const_value);
|
||||
}
|
||||
|
||||
Mark(stmt->body);
|
||||
auto* body = BlockStatement(stmt->body);
|
||||
if (!body) {
|
||||
|
@ -3082,27 +3098,16 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt
|
|||
|
||||
auto* cond_ty = cond->Type()->UnwrapRef();
|
||||
|
||||
utils::Vector<const sem::Type*, 8> types;
|
||||
types.Push(cond_ty);
|
||||
|
||||
utils::Vector<sem::CaseStatement*, 4> cases;
|
||||
cases.Reserve(stmt->body.Length());
|
||||
for (auto* case_stmt : stmt->body) {
|
||||
Mark(case_stmt);
|
||||
auto* c = CaseStatement(case_stmt);
|
||||
if (!c) {
|
||||
return false;
|
||||
}
|
||||
for (auto* expr : c->Selectors()) {
|
||||
types.Push(expr->Type()->UnwrapRef());
|
||||
}
|
||||
cases.Push(c);
|
||||
behaviors.Add(c->Behaviors());
|
||||
sem->Cases().emplace_back(c);
|
||||
}
|
||||
|
||||
// Determine the common type across all selectors and the switch expression
|
||||
// This must materialize to an integer scalar (non-abstract).
|
||||
utils::Vector<const sem::Type*, 8> types;
|
||||
types.Push(cond_ty);
|
||||
for (auto* case_stmt : stmt->body) {
|
||||
for (auto* expr : case_stmt->selectors) {
|
||||
auto* sem_expr = Expression(expr);
|
||||
types.Push(sem_expr->Type()->UnwrapRef());
|
||||
}
|
||||
}
|
||||
auto* common_ty = sem::Type::Common(types);
|
||||
if (!common_ty || !common_ty->is_integer_scalar()) {
|
||||
// No common type found or the common type was abstract.
|
||||
|
@ -3113,13 +3118,21 @@ sem::SwitchStatement* Resolver::SwitchStatement(const ast::SwitchStatement* stmt
|
|||
if (!cond) {
|
||||
return false;
|
||||
}
|
||||
for (auto* c : cases) {
|
||||
for (auto*& sel : c->Selectors()) { // Note: pointer reference
|
||||
sel = Materialize(sel, common_ty);
|
||||
if (!sel) {
|
||||
|
||||
utils::Vector<sem::CaseStatement*, 4> cases;
|
||||
cases.Reserve(stmt->body.Length());
|
||||
for (auto* case_stmt : stmt->body) {
|
||||
Mark(case_stmt);
|
||||
auto* c = CaseStatement(case_stmt, common_ty);
|
||||
if (!c) {
|
||||
return false;
|
||||
}
|
||||
for (auto* expr : c->Selectors()) {
|
||||
types.Push(expr->Type()->UnwrapRef());
|
||||
}
|
||||
cases.Push(c);
|
||||
behaviors.Add(c->Behaviors());
|
||||
sem->Cases().emplace_back(c);
|
||||
}
|
||||
|
||||
if (behaviors.Contains(sem::Behavior::kBreak)) {
|
||||
|
|
|
@ -209,7 +209,7 @@ class Resolver {
|
|||
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
|
||||
sem::Statement* BreakStatement(const ast::BreakStatement*);
|
||||
sem::Statement* CallStatement(const ast::CallStatement*);
|
||||
sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
|
||||
sem::CaseStatement* CaseStatement(const ast::CaseStatement*, const sem::Type*);
|
||||
sem::Statement* CompoundAssignmentStatement(const ast::CompoundAssignmentStatement*);
|
||||
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
|
||||
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
|
||||
|
|
|
@ -132,8 +132,6 @@ TEST_F(ResolverTest, Stmt_Case) {
|
|||
ASSERT_EQ(sem->Cases().size(), 2u);
|
||||
EXPECT_EQ(sem->Cases()[0]->Declaration(), cse);
|
||||
ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u);
|
||||
EXPECT_EQ(sem->Cases()[0]->Selectors()[0]->Declaration(), sel);
|
||||
EXPECT_EQ(sem->Cases()[1]->Declaration(), def);
|
||||
EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 0u);
|
||||
}
|
||||
|
||||
|
|
|
@ -2338,22 +2338,36 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
|
|||
has_default = true;
|
||||
}
|
||||
|
||||
for (auto* selector : case_stmt->selectors) {
|
||||
if (cond_ty != sem_.TypeOf(selector)) {
|
||||
auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
|
||||
|
||||
auto& case_selectors = case_stmt->selectors;
|
||||
auto& selector_values = case_sem->Selectors();
|
||||
TINT_ASSERT(Resolver, case_selectors.Length() == selector_values.size());
|
||||
for (size_t i = 0; i < case_sem->Selectors().size(); ++i) {
|
||||
auto* selector = selector_values[i];
|
||||
if (cond_ty != selector->Type()) {
|
||||
AddError(
|
||||
"the case selector values must have the same type as the selector expression.",
|
||||
case_stmt->source);
|
||||
case_selectors[i]->source);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto it = selectors.find(selector->value);
|
||||
auto value = selector->As<uint32_t>();
|
||||
auto it = selectors.find(value);
|
||||
if (it != selectors.end()) {
|
||||
auto val = std::to_string(selector->value);
|
||||
AddError("duplicate switch case '" + val + "'", selector->source);
|
||||
std::string err = "duplicate switch case '";
|
||||
if (selector->Type()->Is<sem::I32>()) {
|
||||
err += std::to_string(selector->As<int32_t>());
|
||||
} else {
|
||||
err += std::to_string(value);
|
||||
}
|
||||
err += "'";
|
||||
|
||||
AddError(err, case_selectors[i]->source);
|
||||
AddNote("previous case declared here", it->second);
|
||||
return false;
|
||||
}
|
||||
selectors.emplace(selector->value, selector->source);
|
||||
selectors.emplace(value, case_selectors[i]->source);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ class SwitchStatement;
|
|||
} // namespace tint::ast
|
||||
namespace tint::sem {
|
||||
class CaseStatement;
|
||||
class Constant;
|
||||
class Expression;
|
||||
} // namespace tint::sem
|
||||
|
||||
|
@ -82,14 +83,14 @@ class CaseStatement final : public Castable<CaseStatement, CompoundStatement> {
|
|||
const BlockStatement* Body() const { return body_; }
|
||||
|
||||
/// @returns the selectors for the case
|
||||
std::vector<const Expression*>& Selectors() { return selectors_; }
|
||||
std::vector<const Constant*>& Selectors() { return selectors_; }
|
||||
|
||||
/// @returns the selectors for the case
|
||||
const std::vector<const Expression*>& Selectors() const { return selectors_; }
|
||||
const std::vector<const Constant*>& Selectors() const { return selectors_; }
|
||||
|
||||
private:
|
||||
const BlockStatement* body_ = nullptr;
|
||||
std::vector<const Expression*> selectors_;
|
||||
std::vector<const Constant*> selectors_;
|
||||
};
|
||||
|
||||
} // namespace tint::sem
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/storage_texture.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/switch_statement.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/sem/type_conversion.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
@ -1689,14 +1690,15 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
|
|||
if (stmt->IsDefault()) {
|
||||
line() << "default: {";
|
||||
} else {
|
||||
for (auto* selector : stmt->selectors) {
|
||||
auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
|
||||
for (auto* selector : sem->Selectors()) {
|
||||
auto out = line();
|
||||
out << "case ";
|
||||
if (!EmitLiteral(out, selector)) {
|
||||
if (!EmitConstant(out, selector)) {
|
||||
return false;
|
||||
}
|
||||
out << ":";
|
||||
if (selector == stmt->selectors.Back()) {
|
||||
if (selector == sem->Selectors().back()) {
|
||||
out << " {";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/storage_texture.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/switch_statement.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/sem/type_conversion.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
@ -2564,14 +2565,15 @@ bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) {
|
|||
if (stmt->IsDefault()) {
|
||||
line() << "default: {";
|
||||
} else {
|
||||
for (auto* selector : stmt->selectors) {
|
||||
auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
|
||||
for (auto* selector : sem->Selectors()) {
|
||||
auto out = line();
|
||||
out << "case ";
|
||||
if (!EmitLiteral(out, selector)) {
|
||||
if (!EmitConstant(out, selector)) {
|
||||
return false;
|
||||
}
|
||||
out << ":";
|
||||
if (selector == stmt->selectors.Back()) {
|
||||
if (selector == sem->Selectors().back()) {
|
||||
out << " {";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@
|
|||
#include "src/tint/sem/sampled_texture.h"
|
||||
#include "src/tint/sem/storage_texture.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/switch_statement.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/sem/type_conversion.h"
|
||||
#include "src/tint/sem/u32.h"
|
||||
|
@ -1591,14 +1592,15 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
|
|||
if (stmt->IsDefault()) {
|
||||
line() << "default: {";
|
||||
} else {
|
||||
for (auto* selector : stmt->selectors) {
|
||||
auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
|
||||
for (auto* selector : sem->Selectors()) {
|
||||
auto out = line();
|
||||
out << "case ";
|
||||
if (!EmitLiteral(out, selector)) {
|
||||
if (!EmitConstant(out, selector)) {
|
||||
return false;
|
||||
}
|
||||
out << ":";
|
||||
if (selector == stmt->selectors.Back()) {
|
||||
if (selector == sem->Selectors().back()) {
|
||||
out << " {";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@
|
|||
#include "src/tint/sem/sampled_texture.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/switch_statement.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/sem/type_conversion.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
@ -3464,14 +3465,10 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
|
|||
auto block_id = std::get<uint32_t>(block);
|
||||
|
||||
case_ids.push_back(block_id);
|
||||
for (auto* selector : item->selectors) {
|
||||
auto* int_literal = selector->As<ast::IntLiteralExpression>();
|
||||
if (!int_literal) {
|
||||
error_ = "expected integer literal for switch case label";
|
||||
return false;
|
||||
}
|
||||
|
||||
params.push_back(Operand(static_cast<uint32_t>(int_literal->value)));
|
||||
auto* sem = builder_.Sem().Get<sem::CaseStatement>(item);
|
||||
for (auto* selector : sem->Selectors()) {
|
||||
params.push_back(Operand(selector->As<uint32_t>()));
|
||||
params.push_back(Operand(block_id));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@
|
|||
#include "src/tint/ast/void.h"
|
||||
#include "src/tint/ast/workgroup_attribute.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/switch_statement.h"
|
||||
#include "src/tint/utils/math.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
#include "src/tint/writer/float_to_string.h"
|
||||
|
@ -1030,13 +1031,13 @@ bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
|
|||
out << "case ";
|
||||
|
||||
bool first = true;
|
||||
for (auto* selector : stmt->selectors) {
|
||||
for (auto* expr : stmt->selectors) {
|
||||
if (!first) {
|
||||
out << ", ";
|
||||
}
|
||||
|
||||
first = false;
|
||||
if (!EmitLiteral(out, selector)) {
|
||||
if (!EmitExpression(out, expr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue