Use the new Switch() inferred types

Change-Id: I48ecd18957101631caa27480e7b1937a10791118
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81106
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-02-25 23:02:22 +00:00 committed by Tint LUCI CQ
parent f33f1b41ff
commit 2e6269acb0
3 changed files with 61 additions and 105 deletions

View File

@ -1125,7 +1125,7 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
// Recursively flatten matrices, arrays, and structures.
return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
[&](const Matrix* matrix_type) {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);

View File

@ -121,20 +121,14 @@ bool Resolver::ResolveInternal() {
for (auto* decl : dependencies_.ordered_globals) {
Mark(decl);
if (!Switch(
decl, //
[&](const ast::TypeDecl* td) { //
return TypeDecl(td) != nullptr;
},
[&](const ast::Function* func) {
return Function(func) != nullptr;
},
[&](const ast::Variable* var) {
return GlobalVariable(var) != nullptr;
},
decl, //
[&](const ast::TypeDecl* td) { return TypeDecl(td); },
[&](const ast::Function* func) { return Function(func); },
[&](const ast::Variable* var) { return GlobalVariable(var); },
[&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "unhandled global declaration: " << decl->TypeInfo().name;
return false;
return nullptr;
})) {
return false;
}
@ -165,23 +159,13 @@ bool Resolver::ResolveInternal() {
sem::Type* Resolver::Type(const ast::Type* ty) {
Mark(ty);
auto* s = Switch(
ty,
[&](const ast::Void*) -> sem::Type* {
return builder_->create<sem::Void>();
},
[&](const ast::Bool*) -> sem::Type* {
return builder_->create<sem::Bool>();
},
[&](const ast::I32*) -> sem::Type* {
return builder_->create<sem::I32>();
},
[&](const ast::U32*) -> sem::Type* {
return builder_->create<sem::U32>();
},
[&](const ast::F32*) -> sem::Type* {
return builder_->create<sem::F32>();
},
[&](const ast::Vector* t) -> sem::Type* {
ty, //
[&](const ast::Void*) { return builder_->create<sem::Void>(); },
[&](const ast::Bool*) { return builder_->create<sem::Bool>(); },
[&](const ast::I32*) { return builder_->create<sem::I32>(); },
[&](const ast::U32*) { return builder_->create<sem::U32>(); },
[&](const ast::F32*) { return builder_->create<sem::F32>(); },
[&](const ast::Vector* t) -> sem::Vector* {
if (!t->type) {
AddError("missing vector element type", t->source.End());
return nullptr;
@ -195,7 +179,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
},
[&](const ast::Matrix* t) -> sem::Type* {
[&](const ast::Matrix* t) -> sem::Matrix* {
if (!t->type) {
AddError("missing matrix element type", t->source.End());
return nullptr;
@ -212,8 +196,8 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
},
[&](const ast::Array* t) -> sem::Type* { return Array(t); },
[&](const ast::Atomic* t) -> sem::Type* {
[&](const ast::Array* t) { return Array(t); },
[&](const ast::Atomic* t) -> sem::Atomic* {
if (auto* el = Type(t->type)) {
auto* a = builder_->create<sem::Atomic>(el);
if (!ValidateAtomic(t, a)) {
@ -223,7 +207,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
},
[&](const ast::Pointer* t) -> sem::Type* {
[&](const ast::Pointer* t) -> sem::Pointer* {
if (auto* el = Type(t->type)) {
auto access = t->access;
if (access == ast::kUndefined) {
@ -233,28 +217,28 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
},
[&](const ast::Sampler* t) -> sem::Type* {
[&](const ast::Sampler* t) {
return builder_->create<sem::Sampler>(t->kind);
},
[&](const ast::SampledTexture* t) -> sem::Type* {
[&](const ast::SampledTexture* t) -> sem::SampledTexture* {
if (auto* el = Type(t->type)) {
return builder_->create<sem::SampledTexture>(t->dim, el);
}
return nullptr;
},
[&](const ast::MultisampledTexture* t) -> sem::Type* {
[&](const ast::MultisampledTexture* t) -> sem::MultisampledTexture* {
if (auto* el = Type(t->type)) {
return builder_->create<sem::MultisampledTexture>(t->dim, el);
}
return nullptr;
},
[&](const ast::DepthTexture* t) -> sem::Type* {
[&](const ast::DepthTexture* t) {
return builder_->create<sem::DepthTexture>(t->dim);
},
[&](const ast::DepthMultisampledTexture* t) -> sem::Type* {
[&](const ast::DepthMultisampledTexture* t) {
return builder_->create<sem::DepthMultisampledTexture>(t->dim);
},
[&](const ast::StorageTexture* t) -> sem::Type* {
[&](const ast::StorageTexture* t) -> sem::StorageTexture* {
if (auto* el = Type(t->type)) {
if (!ValidateStorageTexture(t)) {
return nullptr;
@ -264,10 +248,10 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
},
[&](const ast::ExternalTexture*) -> sem::Type* {
[&](const ast::ExternalTexture*) {
return builder_->create<sem::ExternalTexture>();
},
[&](Default) -> sem::Type* {
[&](Default) {
auto* resolved = ResolvedSymbol(ty);
return Switch(
resolved, //
@ -858,62 +842,40 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
stmt,
// Compound statements. These create their own sem::CompoundStatement
// bindings.
[&](const ast::BlockStatement* b) -> sem::Statement* {
return BlockStatement(b);
},
[&](const ast::ForLoopStatement* l) -> sem::Statement* {
return ForLoopStatement(l);
},
[&](const ast::LoopStatement* l) -> sem::Statement* {
return LoopStatement(l);
},
[&](const ast::IfStatement* i) -> sem::Statement* {
return IfStatement(i);
},
[&](const ast::SwitchStatement* s) -> sem::Statement* {
return SwitchStatement(s);
},
[&](const ast::BlockStatement* b) { return BlockStatement(b); },
[&](const ast::ForLoopStatement* l) { return ForLoopStatement(l); },
[&](const ast::LoopStatement* l) { return LoopStatement(l); },
[&](const ast::IfStatement* i) { return IfStatement(i); },
[&](const ast::SwitchStatement* s) { return SwitchStatement(s); },
// Non-Compound statements
[&](const ast::AssignmentStatement* a) -> sem::Statement* {
return AssignmentStatement(a);
},
[&](const ast::BreakStatement* b) -> sem::Statement* {
return BreakStatement(b);
},
[&](const ast::CallStatement* c) -> sem::Statement* {
return CallStatement(c);
},
[&](const ast::ContinueStatement* c) -> sem::Statement* {
return ContinueStatement(c);
},
[&](const ast::DiscardStatement* d) -> sem::Statement* {
return DiscardStatement(d);
},
[&](const ast::FallthroughStatement* f) -> sem::Statement* {
[&](const ast::AssignmentStatement* a) { return AssignmentStatement(a); },
[&](const ast::BreakStatement* b) { return BreakStatement(b); },
[&](const ast::CallStatement* c) { return CallStatement(c); },
[&](const ast::ContinueStatement* c) { return ContinueStatement(c); },
[&](const ast::DiscardStatement* d) { return DiscardStatement(d); },
[&](const ast::FallthroughStatement* f) {
return FallthroughStatement(f);
},
[&](const ast::ReturnStatement* r) -> sem::Statement* {
return ReturnStatement(r);
},
[&](const ast::VariableDeclStatement* v) -> sem::Statement* {
[&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
[&](const ast::VariableDeclStatement* v) {
return VariableDeclStatement(v);
},
// Error cases
[&](const ast::CaseStatement*) -> sem::Statement* {
[&](const ast::CaseStatement*) {
AddError("case statement can only be used inside a switch statement",
stmt->source);
return nullptr;
},
[&](const ast::ElseStatement*) -> sem::Statement* {
[&](const ast::ElseStatement*) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Statement() encountered an Else statement. Else "
"statements are embedded in If statements, so should never be "
"encountered as top-level statements";
return nullptr;
},
[&](Default) -> sem::Statement* {
[&](Default) {
AddError(
"unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
@ -1196,16 +1158,12 @@ sem::Expression* Resolver::IndexAccessor(
auto* obj_ty = obj_raw_ty->UnwrapRef();
auto* ty = Switch(
obj_ty, //
[&](const sem::Array* arr) -> const sem::Type* {
return arr->ElemType();
},
[&](const sem::Vector* vec) -> const sem::Type* { //
return vec->type();
},
[&](const sem::Matrix* mat) -> const sem::Type* {
[&](const sem::Array* arr) { return arr->ElemType(); },
[&](const sem::Vector* vec) { return vec->type(); },
[&](const sem::Matrix* mat) {
return builder_->create<sem::Vector>(mat->type(), mat->rows());
},
[&](Default) -> const sem::Type* {
[&](Default) {
AddError("cannot index type '" + TypeNameOf(obj_ty) + "'",
expr->source);
return nullptr;
@ -2188,19 +2146,19 @@ std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
return Switch(
lit,
[&](const ast::SintLiteralExpression*) -> sem::Type* {
[&](const ast::SintLiteralExpression*) {
return builder_->create<sem::I32>();
},
[&](const ast::UintLiteralExpression*) -> sem::Type* {
[&](const ast::UintLiteralExpression*) {
return builder_->create<sem::U32>();
},
[&](const ast::FloatLiteralExpression*) -> sem::Type* {
[&](const ast::FloatLiteralExpression*) {
return builder_->create<sem::F32>();
},
[&](const ast::BoolLiteralExpression*) -> sem::Type* {
[&](const ast::BoolLiteralExpression*) {
return builder_->create<sem::Bool>();
},
[&](Default) -> sem::Type* {
[&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Unhandled literal type: " << lit->TypeInfo().name;
return nullptr;

View File

@ -575,31 +575,29 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
[&](const ast::IndexAccessorExpression* a) {
return GenerateAccessorExpression(a);
},
[&](const ast::BinaryExpression* b) { //
[&](const ast::BinaryExpression* b) {
return GenerateBinaryExpression(b);
},
[&](const ast::BitcastExpression* b) { //
[&](const ast::BitcastExpression* b) {
return GenerateBitcastExpression(b);
},
[&](const ast::CallExpression* c) { //
return GenerateCallExpression(c);
},
[&](const ast::IdentifierExpression* i) { //
[&](const ast::CallExpression* c) { return GenerateCallExpression(c); },
[&](const ast::IdentifierExpression* i) {
return GenerateIdentifierExpression(i);
},
[&](const ast::LiteralExpression* l) { //
[&](const ast::LiteralExpression* l) {
return GenerateLiteralIfNeeded(nullptr, l);
},
[&](const ast::MemberAccessorExpression* m) { //
[&](const ast::MemberAccessorExpression* m) {
return GenerateAccessorExpression(m);
},
[&](const ast::UnaryOpExpression* u) { //
[&](const ast::UnaryOpExpression* u) {
return GenerateUnaryOpExpression(u);
},
[&](Default) -> uint32_t {
[&](Default) {
error_ =
"unknown expression type: " + std::string(expr->TypeInfo().name);
return 0;
@ -2271,7 +2269,7 @@ uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
[&](const sem::TypeConstructor*) {
return GenerateTypeConstructorOrConversion(call, nullptr);
},
[&](Default) -> uint32_t {
[&](Default) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name;
return 0;
@ -4101,7 +4099,7 @@ bool Builder::GenerateTextureType(const sem::Texture* texture,
[&](const sem::StorageTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](Default) -> uint32_t { //
[&](Default) {
return 0u;
});
if (type_id == 0u) {