Use the new Switch() inferred types

Change-Id: I48ecd18957101631caa27480e7b1937a10791118
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81106
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-02-25 23:02:22 +00:00 committed by Tint LUCI CQ
parent f33f1b41ff
commit 2e6269acb0
3 changed files with 61 additions and 105 deletions

View File

@ -1125,7 +1125,7 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
// Recursively flatten matrices, arrays, and structures. // Recursively flatten matrices, arrays, and structures.
return Switch( return Switch(
tip_type, tip_type,
[&](const Matrix* matrix_type) -> bool { [&](const Matrix* matrix_type) {
index_prefix.push_back(0); index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns); const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);

View File

@ -122,19 +122,13 @@ bool Resolver::ResolveInternal() {
Mark(decl); Mark(decl);
if (!Switch( if (!Switch(
decl, // decl, //
[&](const ast::TypeDecl* td) { // [&](const ast::TypeDecl* td) { return TypeDecl(td); },
return TypeDecl(td) != nullptr; [&](const ast::Function* func) { return Function(func); },
}, [&](const ast::Variable* var) { return GlobalVariable(var); },
[&](const ast::Function* func) {
return Function(func) != nullptr;
},
[&](const ast::Variable* var) {
return GlobalVariable(var) != nullptr;
},
[&](Default) { [&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_) TINT_UNREACHABLE(Resolver, diagnostics_)
<< "unhandled global declaration: " << decl->TypeInfo().name; << "unhandled global declaration: " << decl->TypeInfo().name;
return false; return nullptr;
})) { })) {
return false; return false;
} }
@ -165,23 +159,13 @@ bool Resolver::ResolveInternal() {
sem::Type* Resolver::Type(const ast::Type* ty) { sem::Type* Resolver::Type(const ast::Type* ty) {
Mark(ty); Mark(ty);
auto* s = Switch( auto* s = Switch(
ty, ty, //
[&](const ast::Void*) -> sem::Type* { [&](const ast::Void*) { return builder_->create<sem::Void>(); },
return builder_->create<sem::Void>(); [&](const ast::Bool*) { return builder_->create<sem::Bool>(); },
}, [&](const ast::I32*) { return builder_->create<sem::I32>(); },
[&](const ast::Bool*) -> sem::Type* { [&](const ast::U32*) { return builder_->create<sem::U32>(); },
return builder_->create<sem::Bool>(); [&](const ast::F32*) { return builder_->create<sem::F32>(); },
}, [&](const ast::Vector* t) -> sem::Vector* {
[&](const ast::I32*) -> sem::Type* {
return builder_->create<sem::I32>();
},
[&](const ast::U32*) -> sem::Type* {
return builder_->create<sem::U32>();
},
[&](const ast::F32*) -> sem::Type* {
return builder_->create<sem::F32>();
},
[&](const ast::Vector* t) -> sem::Type* {
if (!t->type) { if (!t->type) {
AddError("missing vector element type", t->source.End()); AddError("missing vector element type", t->source.End());
return nullptr; return nullptr;
@ -195,7 +179,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
} }
return nullptr; return nullptr;
}, },
[&](const ast::Matrix* t) -> sem::Type* { [&](const ast::Matrix* t) -> sem::Matrix* {
if (!t->type) { if (!t->type) {
AddError("missing matrix element type", t->source.End()); AddError("missing matrix element type", t->source.End());
return nullptr; return nullptr;
@ -212,8 +196,8 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
} }
return nullptr; return nullptr;
}, },
[&](const ast::Array* t) -> sem::Type* { return Array(t); }, [&](const ast::Array* t) { return Array(t); },
[&](const ast::Atomic* t) -> sem::Type* { [&](const ast::Atomic* t) -> sem::Atomic* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
auto* a = builder_->create<sem::Atomic>(el); auto* a = builder_->create<sem::Atomic>(el);
if (!ValidateAtomic(t, a)) { if (!ValidateAtomic(t, a)) {
@ -223,7 +207,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
} }
return nullptr; return nullptr;
}, },
[&](const ast::Pointer* t) -> sem::Type* { [&](const ast::Pointer* t) -> sem::Pointer* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
auto access = t->access; auto access = t->access;
if (access == ast::kUndefined) { if (access == ast::kUndefined) {
@ -233,28 +217,28 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
} }
return nullptr; return nullptr;
}, },
[&](const ast::Sampler* t) -> sem::Type* { [&](const ast::Sampler* t) {
return builder_->create<sem::Sampler>(t->kind); return builder_->create<sem::Sampler>(t->kind);
}, },
[&](const ast::SampledTexture* t) -> sem::Type* { [&](const ast::SampledTexture* t) -> sem::SampledTexture* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
return builder_->create<sem::SampledTexture>(t->dim, el); return builder_->create<sem::SampledTexture>(t->dim, el);
} }
return nullptr; return nullptr;
}, },
[&](const ast::MultisampledTexture* t) -> sem::Type* { [&](const ast::MultisampledTexture* t) -> sem::MultisampledTexture* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
return builder_->create<sem::MultisampledTexture>(t->dim, el); return builder_->create<sem::MultisampledTexture>(t->dim, el);
} }
return nullptr; return nullptr;
}, },
[&](const ast::DepthTexture* t) -> sem::Type* { [&](const ast::DepthTexture* t) {
return builder_->create<sem::DepthTexture>(t->dim); return builder_->create<sem::DepthTexture>(t->dim);
}, },
[&](const ast::DepthMultisampledTexture* t) -> sem::Type* { [&](const ast::DepthMultisampledTexture* t) {
return builder_->create<sem::DepthMultisampledTexture>(t->dim); return builder_->create<sem::DepthMultisampledTexture>(t->dim);
}, },
[&](const ast::StorageTexture* t) -> sem::Type* { [&](const ast::StorageTexture* t) -> sem::StorageTexture* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
if (!ValidateStorageTexture(t)) { if (!ValidateStorageTexture(t)) {
return nullptr; return nullptr;
@ -264,10 +248,10 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
} }
return nullptr; return nullptr;
}, },
[&](const ast::ExternalTexture*) -> sem::Type* { [&](const ast::ExternalTexture*) {
return builder_->create<sem::ExternalTexture>(); return builder_->create<sem::ExternalTexture>();
}, },
[&](Default) -> sem::Type* { [&](Default) {
auto* resolved = ResolvedSymbol(ty); auto* resolved = ResolvedSymbol(ty);
return Switch( return Switch(
resolved, // resolved, //
@ -858,62 +842,40 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
stmt, stmt,
// Compound statements. These create their own sem::CompoundStatement // Compound statements. These create their own sem::CompoundStatement
// bindings. // bindings.
[&](const ast::BlockStatement* b) -> sem::Statement* { [&](const ast::BlockStatement* b) { return BlockStatement(b); },
return BlockStatement(b); [&](const ast::ForLoopStatement* l) { return ForLoopStatement(l); },
}, [&](const ast::LoopStatement* l) { return LoopStatement(l); },
[&](const ast::ForLoopStatement* l) -> sem::Statement* { [&](const ast::IfStatement* i) { return IfStatement(i); },
return ForLoopStatement(l); [&](const ast::SwitchStatement* s) { return SwitchStatement(s); },
},
[&](const ast::LoopStatement* l) -> sem::Statement* {
return LoopStatement(l);
},
[&](const ast::IfStatement* i) -> sem::Statement* {
return IfStatement(i);
},
[&](const ast::SwitchStatement* s) -> sem::Statement* {
return SwitchStatement(s);
},
// Non-Compound statements // Non-Compound statements
[&](const ast::AssignmentStatement* a) -> sem::Statement* { [&](const ast::AssignmentStatement* a) { return AssignmentStatement(a); },
return AssignmentStatement(a); [&](const ast::BreakStatement* b) { return BreakStatement(b); },
}, [&](const ast::CallStatement* c) { return CallStatement(c); },
[&](const ast::BreakStatement* b) -> sem::Statement* { [&](const ast::ContinueStatement* c) { return ContinueStatement(c); },
return BreakStatement(b); [&](const ast::DiscardStatement* d) { return DiscardStatement(d); },
}, [&](const ast::FallthroughStatement* f) {
[&](const ast::CallStatement* c) -> sem::Statement* {
return CallStatement(c);
},
[&](const ast::ContinueStatement* c) -> sem::Statement* {
return ContinueStatement(c);
},
[&](const ast::DiscardStatement* d) -> sem::Statement* {
return DiscardStatement(d);
},
[&](const ast::FallthroughStatement* f) -> sem::Statement* {
return FallthroughStatement(f); return FallthroughStatement(f);
}, },
[&](const ast::ReturnStatement* r) -> sem::Statement* { [&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
return ReturnStatement(r); [&](const ast::VariableDeclStatement* v) {
},
[&](const ast::VariableDeclStatement* v) -> sem::Statement* {
return VariableDeclStatement(v); return VariableDeclStatement(v);
}, },
// Error cases // Error cases
[&](const ast::CaseStatement*) -> sem::Statement* { [&](const ast::CaseStatement*) {
AddError("case statement can only be used inside a switch statement", AddError("case statement can only be used inside a switch statement",
stmt->source); stmt->source);
return nullptr; return nullptr;
}, },
[&](const ast::ElseStatement*) -> sem::Statement* { [&](const ast::ElseStatement*) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Statement() encountered an Else statement. Else " << "Resolver::Statement() encountered an Else statement. Else "
"statements are embedded in If statements, so should never be " "statements are embedded in If statements, so should never be "
"encountered as top-level statements"; "encountered as top-level statements";
return nullptr; return nullptr;
}, },
[&](Default) -> sem::Statement* { [&](Default) {
AddError( AddError(
"unknown statement type: " + std::string(stmt->TypeInfo().name), "unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source); stmt->source);
@ -1196,16 +1158,12 @@ sem::Expression* Resolver::IndexAccessor(
auto* obj_ty = obj_raw_ty->UnwrapRef(); auto* obj_ty = obj_raw_ty->UnwrapRef();
auto* ty = Switch( auto* ty = Switch(
obj_ty, // obj_ty, //
[&](const sem::Array* arr) -> const sem::Type* { [&](const sem::Array* arr) { return arr->ElemType(); },
return arr->ElemType(); [&](const sem::Vector* vec) { return vec->type(); },
}, [&](const sem::Matrix* mat) {
[&](const sem::Vector* vec) -> const sem::Type* { //
return vec->type();
},
[&](const sem::Matrix* mat) -> const sem::Type* {
return builder_->create<sem::Vector>(mat->type(), mat->rows()); return builder_->create<sem::Vector>(mat->type(), mat->rows());
}, },
[&](Default) -> const sem::Type* { [&](Default) {
AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", AddError("cannot index type '" + TypeNameOf(obj_ty) + "'",
expr->source); expr->source);
return nullptr; return nullptr;
@ -2188,19 +2146,19 @@ std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) { sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
return Switch( return Switch(
lit, lit,
[&](const ast::SintLiteralExpression*) -> sem::Type* { [&](const ast::SintLiteralExpression*) {
return builder_->create<sem::I32>(); return builder_->create<sem::I32>();
}, },
[&](const ast::UintLiteralExpression*) -> sem::Type* { [&](const ast::UintLiteralExpression*) {
return builder_->create<sem::U32>(); return builder_->create<sem::U32>();
}, },
[&](const ast::FloatLiteralExpression*) -> sem::Type* { [&](const ast::FloatLiteralExpression*) {
return builder_->create<sem::F32>(); return builder_->create<sem::F32>();
}, },
[&](const ast::BoolLiteralExpression*) -> sem::Type* { [&](const ast::BoolLiteralExpression*) {
return builder_->create<sem::Bool>(); return builder_->create<sem::Bool>();
}, },
[&](Default) -> sem::Type* { [&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_) TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Unhandled literal type: " << lit->TypeInfo().name; << "Unhandled literal type: " << lit->TypeInfo().name;
return nullptr; return nullptr;

View File

@ -575,31 +575,29 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
uint32_t Builder::GenerateExpression(const ast::Expression* expr) { uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
return Switch( return Switch(
expr, expr,
[&](const ast::IndexAccessorExpression* a) { // [&](const ast::IndexAccessorExpression* a) {
return GenerateAccessorExpression(a); return GenerateAccessorExpression(a);
}, },
[&](const ast::BinaryExpression* b) { // [&](const ast::BinaryExpression* b) {
return GenerateBinaryExpression(b); return GenerateBinaryExpression(b);
}, },
[&](const ast::BitcastExpression* b) { // [&](const ast::BitcastExpression* b) {
return GenerateBitcastExpression(b); return GenerateBitcastExpression(b);
}, },
[&](const ast::CallExpression* c) { // [&](const ast::CallExpression* c) { return GenerateCallExpression(c); },
return GenerateCallExpression(c); [&](const ast::IdentifierExpression* i) {
},
[&](const ast::IdentifierExpression* i) { //
return GenerateIdentifierExpression(i); return GenerateIdentifierExpression(i);
}, },
[&](const ast::LiteralExpression* l) { // [&](const ast::LiteralExpression* l) {
return GenerateLiteralIfNeeded(nullptr, l); return GenerateLiteralIfNeeded(nullptr, l);
}, },
[&](const ast::MemberAccessorExpression* m) { // [&](const ast::MemberAccessorExpression* m) {
return GenerateAccessorExpression(m); return GenerateAccessorExpression(m);
}, },
[&](const ast::UnaryOpExpression* u) { // [&](const ast::UnaryOpExpression* u) {
return GenerateUnaryOpExpression(u); return GenerateUnaryOpExpression(u);
}, },
[&](Default) -> uint32_t { [&](Default) {
error_ = error_ =
"unknown expression type: " + std::string(expr->TypeInfo().name); "unknown expression type: " + std::string(expr->TypeInfo().name);
return 0; return 0;
@ -2271,7 +2269,7 @@ uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
[&](const sem::TypeConstructor*) { [&](const sem::TypeConstructor*) {
return GenerateTypeConstructorOrConversion(call, nullptr); return GenerateTypeConstructorOrConversion(call, nullptr);
}, },
[&](Default) -> uint32_t { [&](Default) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name; << "unhandled call target: " << target->TypeInfo().name;
return 0; return 0;
@ -4101,7 +4099,7 @@ bool Builder::GenerateTextureType(const sem::Texture* texture,
[&](const sem::StorageTexture* t) { [&](const sem::StorageTexture* t) {
return GenerateTypeIfNeeded(t->type()); return GenerateTypeIfNeeded(t->type());
}, },
[&](Default) -> uint32_t { // [&](Default) {
return 0u; return 0u;
}); });
if (type_id == 0u) { if (type_id == 0u) {