From de857e1c581cfa9b7b2b511a44424713cbe47822 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 4 Feb 2022 15:38:23 +0000 Subject: [PATCH] Add tint::Switch() A type dispatch helper with replaces chains of: if (auto* a = obj->As()) { ... } else if (auto* b = obj->As()) { ... } else { ... } with: Switch(obj, [&](A* a) { ... }, [&](B* b) { ... }, [&](Default) { ... }); This new helper provides greater opportunities for optimizations, avoids scoping issues with if-else blocks, and is slightly cleaner (IMO). Bug: tint:1383 Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78543 Kokoro: Kokoro Reviewed-by: David Neto Commit-Queue: Ben Clayton --- src/CMakeLists.txt | 1 + src/ast/module.cc | 50 +- src/ast/traverse_expressions.h | 65 +- src/castable.h | 99 +++ src/castable_bench.cc | 270 ++++++++ src/castable_test.cc | 145 +++++ src/reader/spirv/function.cc | 402 ++++++------ src/writer/hlsl/generator_impl.cc | 742 +++++++++++---------- src/writer/msl/generator_impl.cc | 1002 +++++++++++++++-------------- src/writer/spirv/builder.cc | 499 +++++++------- src/writer/wgsl/generator_impl.cc | 708 +++++++++++--------- 11 files changed, 2413 insertions(+), 1570 deletions(-) create mode 100644 src/castable_bench.cc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8c0f20a1fe..600e91f7e2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1160,6 +1160,7 @@ if(TINT_BUILD_BENCHMARKS) endif() set(TINT_BENCHMARK_SRC + "castable_bench.cc" "bench/benchmark.cc" "reader/wgsl/parser_bench.cc" ) diff --git a/src/ast/module.cc b/src/ast/module.cc index 24999d2f1d..3f06a31855 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -35,16 +35,15 @@ Module::Module(ProgramID pid, continue; } - if (auto* ty = decl->As()) { - type_decls_.push_back(ty); - } else if (auto* func = decl->As()) { - functions_.push_back(func); - } else if (auto* var = decl->As()) { - global_variables_.push_back(var); - } else { - diag::List diagnostics; - TINT_ICE(AST, diagnostics) << "Unknown global declaration type"; - } + Switch( + decl, // + [&](const ast::TypeDecl* type) { type_decls_.push_back(type); }, + [&](const Function* func) { functions_.push_back(func); }, + [&](const Variable* var) { global_variables_.push_back(var); }, + [&](Default) { + diag::List diagnostics; + TINT_ICE(AST, diagnostics) << "Unknown global declaration type"; + }); } } @@ -101,19 +100,24 @@ void Module::Copy(CloneContext* ctx, const Module* src) { << "src global declaration was nullptr"; continue; } - if (auto* type = decl->As()) { - TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id); - type_decls_.push_back(type); - } else if (auto* func = decl->As()) { - TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id); - functions_.push_back(func); - } else if (auto* var = decl->As()) { - TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id); - global_variables_.push_back(var); - } else { - TINT_ICE(AST, ctx->dst->Diagnostics()) - << "Unknown global declaration type"; - } + Switch( + decl, + [&](const ast::TypeDecl* type) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id); + type_decls_.push_back(type); + }, + [&](const Function* func) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id); + functions_.push_back(func); + }, + [&](const Variable* var) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id); + global_variables_.push_back(var); + }, + [&](Default) { + TINT_ICE(AST, ctx->dst->Diagnostics()) + << "Unknown global declaration type"; + }); } } diff --git a/src/ast/traverse_expressions.h b/src/ast/traverse_expressions.h index 88d3dfc24d..b5789410cd 100644 --- a/src/ast/traverse_expressions.h +++ b/src/ast/traverse_expressions.h @@ -101,30 +101,47 @@ bool TraverseExpressions(const ast::Expression* root, } } - if (auto* idx = expr->As()) { - push_pair(idx->object, idx->index); - } else if (auto* bin_op = expr->As()) { - push_pair(bin_op->lhs, bin_op->rhs); - } else if (auto* bitcast = expr->As()) { - to_visit.push_back(bitcast->expr); - } else if (auto* call = expr->As()) { - // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the - // function name in the traversal. - // to_visit.push_back(call->func); - push_list(call->args); - } else if (auto* member = expr->As()) { - // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the - // member name in the traversal. - // push_pair(member->structure, member->member); - to_visit.push_back(member->structure); - } else if (auto* unary = expr->As()) { - to_visit.push_back(unary->expr); - } else if (expr->IsAnyOf()) { - // Leaf expression - } else { - TINT_ICE(AST, diags) << "unhandled expression type: " - << expr->TypeInfo().name; + bool ok = Switch( + expr, + [&](const IndexAccessorExpression* idx) { + push_pair(idx->object, idx->index); + return true; + }, + [&](const BinaryExpression* bin_op) { + push_pair(bin_op->lhs, bin_op->rhs); + return true; + }, + [&](const BitcastExpression* bitcast) { + to_visit.push_back(bitcast->expr); + return true; + }, + [&](const CallExpression* call) { + // TODO(crbug.com/tint/1257): Resolver breaks if we actually include + // the function name in the traversal. to_visit.push_back(call->func); + push_list(call->args); + return true; + }, + [&](const MemberAccessorExpression* member) { + // TODO(crbug.com/tint/1257): Resolver breaks if we actually include + // the member name in the traversal. push_pair(member->structure, + // member->member); + to_visit.push_back(member->structure); + return true; + }, + [&](const UnaryOpExpression* unary) { + to_visit.push_back(unary->expr); + return true; + }, + [&](Default) { + if (expr->IsAnyOf()) { + return true; // Leaf expression + } + TINT_ICE(AST, diags) + << "unhandled expression type: " << expr->TypeInfo().name; + return false; + }); + if (!ok) { return false; } } diff --git a/src/castable.h b/src/castable.h index 3104492965..04f2dcc789 100644 --- a/src/castable.h +++ b/src/castable.h @@ -453,6 +453,105 @@ class Castable : public BASE { } }; +/// Default can be used as the default case for a Switch(), when all previous +/// cases failed to match. +/// +/// Example: +/// ``` +/// Switch(object, +/// [&](TypeA*) { /* ... */ }, +/// [&](TypeB*) { /* ... */ }, +/// [&](Default) { /* If not TypeA or TypeB */ }); +/// ``` +struct Default {}; + +/// Switch is used to dispatch one of the provided callback case handler +/// functions based on the type of `object` and the parameter type of the case +/// handlers. Switch will sequentially check the type of `object` against each +/// of the switch case handler functions, and will invoke the first case handler +/// function which has a parameter type that matches the object type. When a +/// case handler is matched, it will be called with the single argument of +/// `object` cast to the case handler's parameter type. Switch will invoke at +/// most one case handler. Each of the case functions must have the signature +/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R` +/// is the return type, consistent across all case handlers. +/// +/// An optional default case function with the signature `R(Default)` can be +/// used as the last case. This default case will be called if all previous +/// cases failed to match. +/// +/// Example: +/// ``` +/// Switch(object, +/// [&](TypeA*) { /* ... */ }, +/// [&](TypeB*) { /* ... */ }); +/// +/// Switch(object, +/// [&](TypeA*) { /* ... */ }, +/// [&](TypeB*) { /* ... */ }, +/// [&](Default) { /* Called if object is not TypeA or TypeB */ }); +/// ``` +/// +/// @param object the object who's type is used to +/// @param first_case the first switch case +/// @param other_cases additional switch cases (optional) +/// @return the value returned by the called case. If no cases matched, then the +/// zero value for the consistent case type. +template +traits::ReturnType // +Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) { + using ReturnType = traits::ReturnType; + using CaseType = std::remove_pointer_t>; + static constexpr bool kHasReturnType = !std::is_same_v; + static_assert(traits::SignatureOfT::parameter_count == 1, + "Switch case must have a single parameter"); + if constexpr (std::is_same_v) { + // Default case. Must be last. + (void)object; // 'object' is not used by the Default case. + static_assert(sizeof...(other_cases) == 0, + "Switch Default case must come last"); + if constexpr (kHasReturnType) { + return first_case({}); + } else { + first_case({}); + return; + } + } else { + // Regular case. + static_assert(traits::IsTypeOrDerived::value, + "Switch case parameter is not a Castable pointer"); + // Does the case match? + if (auto* ptr = As(object)) { + if constexpr (kHasReturnType) { + return first_case(ptr); + } else { + first_case(ptr); + return; + } + } + // Case did not match. Got any more cases to try? + if constexpr (sizeof...(other_cases) > 0) { + // Try the next cases... + if constexpr (kHasReturnType) { + auto res = Switch(object, std::forward(other_cases)...); + static_assert(std::is_same_v, + "Switch case types do not have consistent return type"); + return res; + } else { + Switch(object, std::forward(other_cases)...); + return; + } + } else { + // That was the last case. No cases matched. + if constexpr (kHasReturnType) { + return {}; + } else { + return; + } + } + } +} + } // namespace tint TINT_CASTABLE_POP_DISABLE_WARNINGS(); diff --git a/src/castable_bench.cc b/src/castable_bench.cc new file mode 100644 index 0000000000..839a932f5c --- /dev/null +++ b/src/castable_bench.cc @@ -0,0 +1,270 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bench/benchmark.h" + +namespace tint { +namespace { + +struct Base : public tint::Castable {}; +struct A : public tint::Castable {}; +struct AA : public tint::Castable {}; +struct AAA : public tint::Castable {}; +struct AAB : public tint::Castable {}; +struct AAC : public tint::Castable {}; +struct AB : public tint::Castable {}; +struct ABA : public tint::Castable {}; +struct ABB : public tint::Castable {}; +struct ABC : public tint::Castable {}; +struct AC : public tint::Castable {}; +struct ACA : public tint::Castable {}; +struct ACB : public tint::Castable {}; +struct ACC : public tint::Castable {}; +struct B : public tint::Castable {}; +struct BA : public tint::Castable {}; +struct BAA : public tint::Castable {}; +struct BAB : public tint::Castable {}; +struct BAC : public tint::Castable {}; +struct BB : public tint::Castable {}; +struct BBA : public tint::Castable {}; +struct BBB : public tint::Castable {}; +struct BBC : public tint::Castable {}; +struct BC : public tint::Castable {}; +struct BCA : public tint::Castable {}; +struct BCB : public tint::Castable {}; +struct BCC : public tint::Castable {}; +struct C : public tint::Castable {}; +struct CA : public tint::Castable {}; +struct CAA : public tint::Castable {}; +struct CAB : public tint::Castable {}; +struct CAC : public tint::Castable {}; +struct CB : public tint::Castable {}; +struct CBA : public tint::Castable {}; +struct CBB : public tint::Castable {}; +struct CBC : public tint::Castable {}; +struct CC : public tint::Castable {}; +struct CCA : public tint::Castable {}; +struct CCB : public tint::Castable {}; +struct CCC : public tint::Castable {}; + +using AllTypes = std::tuple; + +std::vector> MakeObjects() { + std::vector> out; + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + out.emplace_back(std::make_unique()); + return out; +} + +void CastableLargeSwitch(::benchmark::State& state) { + auto objects = MakeObjects(); + size_t i = 0; + for (auto _ : state) { + auto* object = objects[i % objects.size()].get(); + Switch( + object, // + [&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); }, + [&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); }, + [&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); }, + [&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); }, + [&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); }, + [&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); }, + [&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); }, + [&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); }, + [&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); }, + [&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); }, + [&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); }, + [&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); }, + [&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); }, + [&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); }, + [&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); }, + [&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); }, + [&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); }, + [&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); }, + [&](const CA*) { ::benchmark::DoNotOptimize(i += 290); }, + [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); }, + [&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); }, + [&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); }, + [&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); }, + [&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); }, + [&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); }, + [&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); }, + [&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); }, + [&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); }, + [&](Default) { ::benchmark::DoNotOptimize(i += 123); }); + i = (i * 31) ^ (i << 5); + } +} + +BENCHMARK(CastableLargeSwitch); + +void CastableMediumSwitch(::benchmark::State& state) { + auto objects = MakeObjects(); + size_t i = 0; + for (auto _ : state) { + auto* object = objects[i % objects.size()].get(); + Switch( + object, // + [&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); }, + [&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); }, + [&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); }, + [&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); }, + [&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); }, + [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); }, + [&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); }, + [&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); }, + [&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); }, + [&](Default) { ::benchmark::DoNotOptimize(i += 123); }); + i = (i * 31) ^ (i << 5); + } +} + +BENCHMARK(CastableMediumSwitch); + +void CastableSmallSwitch(::benchmark::State& state) { + auto objects = MakeObjects(); + size_t i = 0; + for (auto _ : state) { + auto* object = objects[i % objects.size()].get(); + Switch( + object, // + [&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); }, + [&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); }, + [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); }); + i = (i * 31) ^ (i << 5); + } +} + +BENCHMARK(CastableSmallSwitch); + +} // namespace +} // namespace tint + +TINT_INSTANTIATE_TYPEINFO(tint::Base); +TINT_INSTANTIATE_TYPEINFO(tint::A); +TINT_INSTANTIATE_TYPEINFO(tint::AA); +TINT_INSTANTIATE_TYPEINFO(tint::AAA); +TINT_INSTANTIATE_TYPEINFO(tint::AAB); +TINT_INSTANTIATE_TYPEINFO(tint::AAC); +TINT_INSTANTIATE_TYPEINFO(tint::AB); +TINT_INSTANTIATE_TYPEINFO(tint::ABA); +TINT_INSTANTIATE_TYPEINFO(tint::ABB); +TINT_INSTANTIATE_TYPEINFO(tint::ABC); +TINT_INSTANTIATE_TYPEINFO(tint::AC); +TINT_INSTANTIATE_TYPEINFO(tint::ACA); +TINT_INSTANTIATE_TYPEINFO(tint::ACB); +TINT_INSTANTIATE_TYPEINFO(tint::ACC); +TINT_INSTANTIATE_TYPEINFO(tint::B); +TINT_INSTANTIATE_TYPEINFO(tint::BA); +TINT_INSTANTIATE_TYPEINFO(tint::BAA); +TINT_INSTANTIATE_TYPEINFO(tint::BAB); +TINT_INSTANTIATE_TYPEINFO(tint::BAC); +TINT_INSTANTIATE_TYPEINFO(tint::BB); +TINT_INSTANTIATE_TYPEINFO(tint::BBA); +TINT_INSTANTIATE_TYPEINFO(tint::BBB); +TINT_INSTANTIATE_TYPEINFO(tint::BBC); +TINT_INSTANTIATE_TYPEINFO(tint::BC); +TINT_INSTANTIATE_TYPEINFO(tint::BCA); +TINT_INSTANTIATE_TYPEINFO(tint::BCB); +TINT_INSTANTIATE_TYPEINFO(tint::BCC); +TINT_INSTANTIATE_TYPEINFO(tint::C); +TINT_INSTANTIATE_TYPEINFO(tint::CA); +TINT_INSTANTIATE_TYPEINFO(tint::CAA); +TINT_INSTANTIATE_TYPEINFO(tint::CAB); +TINT_INSTANTIATE_TYPEINFO(tint::CAC); +TINT_INSTANTIATE_TYPEINFO(tint::CB); +TINT_INSTANTIATE_TYPEINFO(tint::CBA); +TINT_INSTANTIATE_TYPEINFO(tint::CBB); +TINT_INSTANTIATE_TYPEINFO(tint::CBC); +TINT_INSTANTIATE_TYPEINFO(tint::CC); +TINT_INSTANTIATE_TYPEINFO(tint::CCA); +TINT_INSTANTIATE_TYPEINFO(tint::CCB); +TINT_INSTANTIATE_TYPEINFO(tint::CCC); diff --git a/src/castable_test.cc b/src/castable_test.cc index e44983ba8f..2a9a71a742 100644 --- a/src/castable_test.cc +++ b/src/castable_test.cc @@ -252,6 +252,151 @@ TEST(Castable, As) { ASSERT_EQ(gecko->As(), static_cast(gecko.get())); } +TEST(Castable, SwitchNoDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + bool frog_matched_amphibian = false; + Switch( + frog.get(), // + [&](Reptile*) { FAIL() << "frog is not reptile"; }, + [&](Mammal*) { FAIL() << "frog is not mammal"; }, + [&](Amphibian* amphibian) { + EXPECT_EQ(amphibian, frog.get()); + frog_matched_amphibian = true; + }); + EXPECT_TRUE(frog_matched_amphibian); + } + { + bool bear_matched_mammal = false; + Switch( + bear.get(), // + [&](Reptile*) { FAIL() << "bear is not reptile"; }, + [&](Amphibian*) { FAIL() << "bear is not amphibian"; }, + [&](Mammal* mammal) { + EXPECT_EQ(mammal, bear.get()); + bear_matched_mammal = true; + }); + EXPECT_TRUE(bear_matched_mammal); + } + { + bool gecko_matched_reptile = false; + Switch( + gecko.get(), // + [&](Mammal*) { FAIL() << "gecko is not mammal"; }, + [&](Amphibian*) { FAIL() << "gecko is not amphibian"; }, + [&](Reptile* reptile) { + EXPECT_EQ(reptile, gecko.get()); + gecko_matched_reptile = true; + }); + EXPECT_TRUE(gecko_matched_reptile); + } +} + +TEST(Castable, SwitchWithUnusedDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + bool frog_matched_amphibian = false; + Switch( + frog.get(), // + [&](Reptile*) { FAIL() << "frog is not reptile"; }, + [&](Mammal*) { FAIL() << "frog is not mammal"; }, + [&](Amphibian* amphibian) { + EXPECT_EQ(amphibian, frog.get()); + frog_matched_amphibian = true; + }, + [&](Default) { FAIL() << "default should not have been selected"; }); + EXPECT_TRUE(frog_matched_amphibian); + } + { + bool bear_matched_mammal = false; + Switch( + bear.get(), // + [&](Reptile*) { FAIL() << "bear is not reptile"; }, + [&](Amphibian*) { FAIL() << "bear is not amphibian"; }, + [&](Mammal* mammal) { + EXPECT_EQ(mammal, bear.get()); + bear_matched_mammal = true; + }, + [&](Default) { FAIL() << "default should not have been selected"; }); + EXPECT_TRUE(bear_matched_mammal); + } + { + bool gecko_matched_reptile = false; + Switch( + gecko.get(), // + [&](Mammal*) { FAIL() << "gecko is not mammal"; }, + [&](Amphibian*) { FAIL() << "gecko is not amphibian"; }, + [&](Reptile* reptile) { + EXPECT_EQ(reptile, gecko.get()); + gecko_matched_reptile = true; + }, + [&](Default) { FAIL() << "default should not have been selected"; }); + EXPECT_TRUE(gecko_matched_reptile); + } +} + +TEST(Castable, SwitchDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + bool frog_matched_default = false; + Switch( + frog.get(), // + [&](Reptile*) { FAIL() << "frog is not reptile"; }, + [&](Mammal*) { FAIL() << "frog is not mammal"; }, + [&](Default) { frog_matched_default = true; }); + EXPECT_TRUE(frog_matched_default); + } + { + bool bear_matched_default = false; + Switch( + bear.get(), // + [&](Reptile*) { FAIL() << "bear is not reptile"; }, + [&](Amphibian*) { FAIL() << "bear is not amphibian"; }, + [&](Default) { bear_matched_default = true; }); + EXPECT_TRUE(bear_matched_default); + } + { + bool gecko_matched_default = false; + Switch( + gecko.get(), // + [&](Mammal*) { FAIL() << "gecko is not mammal"; }, + [&](Amphibian*) { FAIL() << "gecko is not amphibian"; }, + [&](Default) { gecko_matched_default = true; }); + EXPECT_TRUE(gecko_matched_default); + } +} +TEST(Castable, SwitchMatchFirst) { + std::unique_ptr frog = std::make_unique(); + { + bool frog_matched_animal = false; + Switch( + frog.get(), + [&](Animal* animal) { + EXPECT_EQ(animal, frog.get()); + frog_matched_animal = true; + }, + [&](Amphibian*) { FAIL() << "animal should have been matched first"; }); + EXPECT_TRUE(frog_matched_animal); + } + { + bool frog_matched_amphibian = false; + Switch( + frog.get(), + [&](Amphibian* amphibain) { + EXPECT_EQ(amphibain, frog.get()); + frog_matched_amphibian = true; + }, + [&](Animal*) { FAIL() << "amphibian should have been matched first"; }); + EXPECT_TRUE(frog_matched_amphibian); + } +} + } // namespace TINT_INSTANTIATE_TYPEINFO(Animal); diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index b88916ecab..ed35ca4e0d 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -953,7 +953,7 @@ const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() { bool FunctionEmitter::EmitPipelineInput(std::string var_name, const Type* var_type, - ast::AttributeList* decos, + ast::AttributeList* attrs, std::vector index_prefix, const Type* tip_type, const Type* forced_param_type, @@ -966,105 +966,121 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name, } // Recursively flatten matrices, arrays, and structures. - if (auto* matrix_type = tip_type->As()) { - index_prefix.push_back(0); - const auto num_columns = static_cast(matrix_type->columns); - const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); - for (int col = 0; col < num_columns; col++) { - index_prefix.back() = col; - if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty, - forced_param_type, params, statements)) { - return false; - } - } - return success(); - } else if (auto* array_type = tip_type->As()) { - if (array_type->size == 0) { - return Fail() << "runtime-size array not allowed on pipeline IO"; - } - index_prefix.push_back(0); - const Type* elem_ty = array_type->type; - for (int i = 0; i < static_cast(array_type->size); i++) { - index_prefix.back() = i; - if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty, - forced_param_type, params, statements)) { - return false; - } - } - return success(); - } else if (auto* struct_type = tip_type->As()) { - const auto& members = struct_type->members; - index_prefix.push_back(0); - for (int i = 0; i < static_cast(members.size()); ++i) { - index_prefix.back() = i; - ast::AttributeList member_decos(*decos); - if (!parser_impl_.ConvertPipelineDecorations( - struct_type, - parser_impl_.GetMemberPipelineDecorations(*struct_type, i), - &member_decos)) { - return false; - } - if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix, - members[i], forced_param_type, params, - statements)) { - return false; - } - // Copy the location as updated by nested expansion of the member. - parser_impl_.SetLocation(decos, GetLocation(member_decos)); - } - return success(); - } + return Switch( + tip_type, + [&](const Matrix* matrix_type) -> bool { + index_prefix.push_back(0); + const auto num_columns = static_cast(matrix_type->columns); + const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); + for (int col = 0; col < num_columns; col++) { + index_prefix.back() = col; + if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix, + vec_ty, forced_param_type, params, + statements)) { + return false; + } + } + return success(); + }, + [&](const Array* array_type) -> bool { + if (array_type->size == 0) { + return Fail() << "runtime-size array not allowed on pipeline IO"; + } + index_prefix.push_back(0); + const Type* elem_ty = array_type->type; + for (int i = 0; i < static_cast(array_type->size); i++) { + index_prefix.back() = i; + if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix, + elem_ty, forced_param_type, params, + statements)) { + return false; + } + } + return success(); + }, + [&](const Struct* struct_type) -> bool { + const auto& members = struct_type->members; + index_prefix.push_back(0); + for (int i = 0; i < static_cast(members.size()); ++i) { + index_prefix.back() = i; + ast::AttributeList member_attrs(*attrs); + if (!parser_impl_.ConvertPipelineDecorations( + struct_type, + parser_impl_.GetMemberPipelineDecorations(*struct_type, i), + &member_attrs)) { + return false; + } + if (!EmitPipelineInput(var_name, var_type, &member_attrs, + index_prefix, members[i], forced_param_type, + params, statements)) { + return false; + } + // Copy the location as updated by nested expansion of the member. + parser_impl_.SetLocation(attrs, GetLocation(member_attrs)); + } + return success(); + }, + [&](Default) { + const bool is_builtin = + ast::HasAttribute(*attrs); - const bool is_builtin = ast::HasAttribute(*decos); + const Type* param_type = is_builtin ? forced_param_type : tip_type; - const Type* param_type = is_builtin ? forced_param_type : tip_type; + const auto param_name = namer_.MakeDerivedName(var_name + "_param"); + // Create the parameter. + // TODO(dneto): Note: If the parameter has non-location decorations, + // then those decoration AST nodes will be reused between multiple + // elements of a matrix, array, or structure. Normally that's + // disallowed but currently the SPIR-V reader will make duplicates when + // the entire AST is cloned at the top level of the SPIR-V reader flow. + // Consider rewriting this to avoid this node-sharing. + params->push_back( + builder_.Param(param_name, param_type->Build(builder_), *attrs)); - const auto param_name = namer_.MakeDerivedName(var_name + "_param"); - // Create the parameter. - // TODO(dneto): Note: If the parameter has non-location decorations, - // then those decoration AST nodes will be reused between multiple elements - // of a matrix, array, or structure. Normally that's disallowed but currently - // the SPIR-V reader will make duplicates when the entire AST is cloned - // at the top level of the SPIR-V reader flow. Consider rewriting this - // to avoid this node-sharing. - params->push_back( - builder_.Param(param_name, param_type->Build(builder_), *decos)); + // Add a body statement to copy the parameter to the corresponding + // private variable. + const ast::Expression* param_value = builder_.Expr(param_name); + const ast::Expression* store_dest = builder_.Expr(var_name); - // Add a body statement to copy the parameter to the corresponding private - // variable. - const ast::Expression* param_value = builder_.Expr(param_name); - const ast::Expression* store_dest = builder_.Expr(var_name); + // Index into the LHS as needed. + auto* current_type = + var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); + for (auto index : index_prefix) { + Switch( + current_type, + [&](const Matrix* matrix_type) { + store_dest = + builder_.IndexAccessor(store_dest, builder_.Expr(index)); + current_type = ty_.Vector(matrix_type->type, matrix_type->rows); + }, + [&](const Array* array_type) { + store_dest = + builder_.IndexAccessor(store_dest, builder_.Expr(index)); + current_type = array_type->type->UnwrapAlias(); + }, + [&](const Struct* struct_type) { + store_dest = builder_.MemberAccessor( + store_dest, builder_.Expr(parser_impl_.GetMemberName( + *struct_type, index))); + current_type = struct_type->members[index]; + }); + } - // Index into the LHS as needed. - auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); - for (auto index : index_prefix) { - if (auto* matrix_type = current_type->As()) { - store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index)); - current_type = ty_.Vector(matrix_type->type, matrix_type->rows); - } else if (auto* array_type = current_type->As()) { - store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index)); - current_type = array_type->type->UnwrapAlias(); - } else if (auto* struct_type = current_type->As()) { - store_dest = builder_.MemberAccessor( - store_dest, - builder_.Expr(parser_impl_.GetMemberName(*struct_type, index))); - current_type = struct_type->members[index]; - } - } + if (is_builtin && (tip_type != forced_param_type)) { + // The parameter will have the WGSL type, but we need bitcast to + // the variable store type. + param_value = create( + tip_type->Build(builder_), param_value); + } - if (is_builtin && (tip_type != forced_param_type)) { - // The parameter will have the WGSL type, but we need bitcast to - // the variable store type. - param_value = - create(tip_type->Build(builder_), param_value); - } + statements->push_back(builder_.Assign(store_dest, param_value)); - statements->push_back(builder_.Assign(store_dest, param_value)); + // Increment the location attribute, in case more parameters will + // follow. + IncrementLocation(attrs); - // Increment the location attribute, in case more parameters will follow. - IncrementLocation(decos); - - return success(); + return success(); + }); } void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) { @@ -1102,106 +1118,120 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name, } // Recursively flatten matrices, arrays, and structures. - if (auto* matrix_type = tip_type->As()) { - index_prefix.push_back(0); - const auto num_columns = static_cast(matrix_type->columns); - const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); - for (int col = 0; col < num_columns; col++) { - index_prefix.back() = col; - if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty, - forced_member_type, return_members, - return_exprs)) { - return false; - } - } - return success(); - } else if (auto* array_type = tip_type->As()) { - if (array_type->size == 0) { - return Fail() << "runtime-size array not allowed on pipeline IO"; - } - index_prefix.push_back(0); - const Type* elem_ty = array_type->type; - for (int i = 0; i < static_cast(array_type->size); i++) { - index_prefix.back() = i; - if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty, - forced_member_type, return_members, - return_exprs)) { - return false; - } - } - return success(); - } else if (auto* struct_type = tip_type->As()) { - const auto& members = struct_type->members; - index_prefix.push_back(0); - for (int i = 0; i < static_cast(members.size()); ++i) { - index_prefix.back() = i; - ast::AttributeList member_decos(*decos); - if (!parser_impl_.ConvertPipelineDecorations( - struct_type, - parser_impl_.GetMemberPipelineDecorations(*struct_type, i), - &member_decos)) { - return false; - } - if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix, - members[i], forced_member_type, return_members, - return_exprs)) { - return false; - } - // Copy the location as updated by nested expansion of the member. - parser_impl_.SetLocation(decos, GetLocation(member_decos)); - } - return success(); - } + return Switch( + tip_type, + [&](const Matrix* matrix_type) -> bool { + index_prefix.push_back(0); + const auto num_columns = static_cast(matrix_type->columns); + const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); + for (int col = 0; col < num_columns; col++) { + index_prefix.back() = col; + if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, + vec_ty, forced_member_type, return_members, + return_exprs)) { + return false; + } + } + return success(); + }, + [&](const Array* array_type) -> bool { + if (array_type->size == 0) { + return Fail() << "runtime-size array not allowed on pipeline IO"; + } + index_prefix.push_back(0); + const Type* elem_ty = array_type->type; + for (int i = 0; i < static_cast(array_type->size); i++) { + index_prefix.back() = i; + if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, + elem_ty, forced_member_type, return_members, + return_exprs)) { + return false; + } + } + return success(); + }, + [&](const Struct* struct_type) -> bool { + const auto& members = struct_type->members; + index_prefix.push_back(0); + for (int i = 0; i < static_cast(members.size()); ++i) { + index_prefix.back() = i; + ast::AttributeList member_attrs(*decos); + if (!parser_impl_.ConvertPipelineDecorations( + struct_type, + parser_impl_.GetMemberPipelineDecorations(*struct_type, i), + &member_attrs)) { + return false; + } + if (!EmitPipelineOutput(var_name, var_type, &member_attrs, + index_prefix, members[i], forced_member_type, + return_members, return_exprs)) { + return false; + } + // Copy the location as updated by nested expansion of the member. + parser_impl_.SetLocation(decos, GetLocation(member_attrs)); + } + return success(); + }, + [&](Default) { + const bool is_builtin = + ast::HasAttribute(*decos); - const bool is_builtin = ast::HasAttribute(*decos); + const Type* member_type = is_builtin ? forced_member_type : tip_type; + // Derive the member name directly from the variable name. They can't + // collide. + const auto member_name = namer_.MakeDerivedName(var_name); + // Create the member. + // TODO(dneto): Note: If the parameter has non-location decorations, + // then those decoration AST nodes will be reused between multiple + // elements of a matrix, array, or structure. Normally that's + // disallowed but currently the SPIR-V reader will make duplicates when + // the entire AST is cloned at the top level of the SPIR-V reader flow. + // Consider rewriting this to avoid this node-sharing. + return_members->push_back( + builder_.Member(member_name, member_type->Build(builder_), *decos)); - const Type* member_type = is_builtin ? forced_member_type : tip_type; - // Derive the member name directly from the variable name. They can't - // collide. - const auto member_name = namer_.MakeDerivedName(var_name); - // Create the member. - // TODO(dneto): Note: If the parameter has non-location decorations, - // then those decoration AST nodes will be reused between multiple elements - // of a matrix, array, or structure. Normally that's disallowed but currently - // the SPIR-V reader will make duplicates when the entire AST is cloned - // at the top level of the SPIR-V reader flow. Consider rewriting this - // to avoid this node-sharing. - return_members->push_back( - builder_.Member(member_name, member_type->Build(builder_), *decos)); + // Create an expression to evaluate the part of the variable indexed by + // the index_prefix. + const ast::Expression* load_source = builder_.Expr(var_name); - // Create an expression to evaluate the part of the variable indexed by - // the index_prefix. - const ast::Expression* load_source = builder_.Expr(var_name); + // Index into the variable as needed to pick out the flattened member. + auto* current_type = + var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); + for (auto index : index_prefix) { + Switch( + current_type, + [&](const Matrix* matrix_type) { + load_source = + builder_.IndexAccessor(load_source, builder_.Expr(index)); + current_type = ty_.Vector(matrix_type->type, matrix_type->rows); + }, + [&](const Array* array_type) { + load_source = + builder_.IndexAccessor(load_source, builder_.Expr(index)); + current_type = array_type->type->UnwrapAlias(); + }, + [&](const Struct* struct_type) { + load_source = builder_.MemberAccessor( + load_source, builder_.Expr(parser_impl_.GetMemberName( + *struct_type, index))); + current_type = struct_type->members[index]; + }); + } - // Index into the variable as needed to pick out the flattened member. - auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); - for (auto index : index_prefix) { - if (auto* matrix_type = current_type->As()) { - load_source = builder_.IndexAccessor(load_source, builder_.Expr(index)); - current_type = ty_.Vector(matrix_type->type, matrix_type->rows); - } else if (auto* array_type = current_type->As()) { - load_source = builder_.IndexAccessor(load_source, builder_.Expr(index)); - current_type = array_type->type->UnwrapAlias(); - } else if (auto* struct_type = current_type->As()) { - load_source = builder_.MemberAccessor( - load_source, - builder_.Expr(parser_impl_.GetMemberName(*struct_type, index))); - current_type = struct_type->members[index]; - } - } + if (is_builtin && (tip_type != forced_member_type)) { + // The member will have the WGSL type, but we need bitcast to + // the variable store type. + load_source = create( + forced_member_type->Build(builder_), load_source); + } + return_exprs->push_back(load_source); - if (is_builtin && (tip_type != forced_member_type)) { - // The member will have the WGSL type, but we need bitcast to - // the variable store type. - load_source = create( - forced_member_type->Build(builder_), load_source); - } - return_exprs->push_back(load_source); + // Increment the location attribute, in case more parameters will + // follow. + IncrementLocation(decos); - // Increment the location attribute, in case more parameters will follow. - IncrementLocation(decos); - - return success(); + return success(); + }); } bool FunctionEmitter::EmitEntryPointAsWrapper() { diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 6d20aceaa8..8502427d76 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -239,39 +239,41 @@ bool GeneratorImpl::Generate() { } last_kind = kind; - if (auto* global = decl->As()) { - if (!EmitGlobalVariable(global)) { - return false; - } - } else if (auto* str = decl->As()) { - auto* ty = builder_.Sem().Get(str); - auto storage_class_uses = ty->StorageClassUsage(); - if (storage_class_uses.size() != - (storage_class_uses.count(ast::StorageClass::kStorage) + - storage_class_uses.count(ast::StorageClass::kUniform))) { - // The structure is used as something other than a storage buffer or - // uniform buffer, so it needs to be emitted. - // Storage buffer are read and written to via a ByteAddressBuffer - // instead of true structure. - // Structures used as uniform buffer are read from an array of vectors - // instead of true structure. - if (!EmitStructType(current_buffer_, ty)) { + bool ok = Switch( + decl, + [&](const ast::Variable* global) { // + return EmitGlobalVariable(global); + }, + [&](const ast::Struct* str) { + auto* ty = builder_.Sem().Get(str); + auto storage_class_uses = ty->StorageClassUsage(); + if (storage_class_uses.size() != + (storage_class_uses.count(ast::StorageClass::kStorage) + + storage_class_uses.count(ast::StorageClass::kUniform))) { + // The structure is used as something other than a storage buffer or + // uniform buffer, so it needs to be emitted. + // Storage buffer are read and written to via a ByteAddressBuffer + // instead of true structure. + // Structures used as uniform buffer are read from an array of + // vectors instead of true structure. + return EmitStructType(current_buffer_, ty); + } + return true; + }, + [&](const ast::Function* func) { + if (func->IsEntryPoint()) { + return EmitEntryPointFunction(func); + } + return EmitFunction(func); + }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "unhandled module-scope declaration: " + << decl->TypeInfo().name; return false; - } - } - } else if (auto* func = decl->As()) { - if (func->IsEntryPoint()) { - if (!EmitEntryPointFunction(func)) { - return false; - } - } else { - if (!EmitFunction(func)) { - return false; - } - } - } else { - TINT_ICE(Writer, diagnostics_) - << "unhandled module-scope declaration: " << decl->TypeInfo().name; + }); + + if (!ok) { return false; } } @@ -929,22 +931,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) { auto* call = builder_.Sem().Get(expr); auto* target = call->Target(); - - if (auto* func = target->As()) { - return EmitFunctionCall(out, call, func); - } - if (auto* builtin = target->As()) { - return EmitBuiltinCall(out, call, builtin); - } - if (auto* conv = target->As()) { - return EmitTypeConversion(out, call, conv); - } - if (auto* ctor = target->As()) { - return EmitTypeConstructor(out, call, ctor); - } - TINT_ICE(Writer, diagnostics_) - << "unhandled call target: " << target->TypeInfo().name; - return false; + return Switch( + target, + [&](const sem::Function* func) { + return EmitFunctionCall(out, call, func); + }, + [&](const sem::Builtin* builtin) { + return EmitBuiltinCall(out, call, builtin); + }, + [&](const sem::TypeConversion* conv) { + return EmitTypeConversion(out, call, conv); + }, + [&](const sem::TypeConstructor* ctor) { + return EmitTypeConstructor(out, call, ctor); + }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "unhandled call target: " << target->TypeInfo().name; + return false; + }); } bool GeneratorImpl::EmitFunctionCall(std::ostream& out, @@ -2639,35 +2644,38 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) { bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) { - if (auto* a = expr->As()) { - return EmitIndexAccessor(out, a); - } - if (auto* b = expr->As()) { - return EmitBinary(out, b); - } - if (auto* b = expr->As()) { - return EmitBitcast(out, b); - } - if (auto* c = expr->As()) { - return EmitCall(out, c); - } - if (auto* i = expr->As()) { - return EmitIdentifier(out, i); - } - if (auto* l = expr->As()) { - return EmitLiteral(out, l); - } - if (auto* m = expr->As()) { - return EmitMemberAccessor(out, m); - } - if (auto* u = expr->As()) { - return EmitUnaryOp(out, u); - } - - diagnostics_.add_error( - diag::System::Writer, - "unknown expression type: " + std::string(expr->TypeInfo().name)); - return false; + return Switch( + expr, + [&](const ast::IndexAccessorExpression* a) { // + return EmitIndexAccessor(out, a); + }, + [&](const ast::BinaryExpression* b) { // + return EmitBinary(out, b); + }, + [&](const ast::BitcastExpression* b) { // + return EmitBitcast(out, b); + }, + [&](const ast::CallExpression* c) { // + return EmitCall(out, c); + }, + [&](const ast::IdentifierExpression* i) { // + return EmitIdentifier(out, i); + }, + [&](const ast::LiteralExpression* l) { // + return EmitLiteral(out, l); + }, + [&](const ast::MemberAccessorExpression* m) { // + return EmitMemberAccessor(out, m); + }, + [&](const ast::UnaryOpExpression* u) { // + return EmitUnaryOp(out, u); + }, + [&](Default) { // + diagnostics_.add_error( + diag::System::Writer, + "unknown expression type: " + std::string(expr->TypeInfo().name)); + return false; + }); } bool GeneratorImpl::EmitIdentifier(std::ostream& out, @@ -3127,80 +3135,108 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* lit) { - if (auto* l = lit->As()) { - out << (l->value ? "true" : "false"); - } else if (auto* fl = lit->As()) { - if (std::isinf(fl->value)) { - out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)"); - } else if (std::isnan(fl->value)) { - out << "asfloat(0x7fc00000u)"; - } else { - out << FloatToString(fl->value) << "f"; - } - } else if (auto* sl = lit->As()) { - out << sl->value; - } else if (auto* ul = lit->As()) { - out << ul->value << "u"; - } else { - diagnostics_.add_error(diag::System::Writer, "unknown literal type"); - return false; - } - return true; + return Switch( + lit, + [&](const ast::BoolLiteralExpression* l) { + out << (l->value ? "true" : "false"); + return true; + }, + [&](const ast::FloatLiteralExpression* fl) { + if (std::isinf(fl->value)) { + out << (fl->value >= 0 ? "asfloat(0x7f800000u)" + : "asfloat(0xff800000u)"); + } else if (std::isnan(fl->value)) { + out << "asfloat(0x7fc00000u)"; + } else { + out << FloatToString(fl->value) << "f"; + } + return true; + }, + [&](const ast::SintLiteralExpression* sl) { + out << sl->value; + return true; + }, + [&](const ast::UintLiteralExpression* ul) { + out << ul->value << "u"; + return true; + }, + [&](Default) { + diagnostics_.add_error(diag::System::Writer, "unknown literal type"); + return false; + }); } bool GeneratorImpl::EmitValue(std::ostream& out, const sem::Type* type, int value) { - if (type->Is()) { - out << (value == 0 ? "false" : "true"); - } else if (type->Is()) { - out << value << ".0f"; - } else if (type->Is()) { - out << value; - } else if (type->Is()) { - out << value << "u"; - } else if (auto* vec = type->As()) { - if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, - "")) { - return false; - } - ScopedParen sp(out); - for (uint32_t i = 0; i < vec->Width(); i++) { - if (i != 0) { - out << ", "; - } - if (!EmitValue(out, vec->type(), value)) { + return Switch( + type, + [&](const sem::Bool*) { + out << (value == 0 ? "false" : "true"); + return true; + }, + [&](const sem::F32*) { + out << value << ".0f"; + return true; + }, + [&](const sem::I32*) { + out << value; + return true; + }, + [&](const sem::U32*) { + out << value << "u"; + return true; + }, + [&](const sem::Vector* vec) { + if (!EmitType(out, type, ast::StorageClass::kNone, + ast::Access::kReadWrite, "")) { + return false; + } + ScopedParen sp(out); + for (uint32_t i = 0; i < vec->Width(); i++) { + if (i != 0) { + out << ", "; + } + if (!EmitValue(out, vec->type(), value)) { + return false; + } + } + return true; + }, + [&](const sem::Matrix* mat) { + if (!EmitType(out, type, ast::StorageClass::kNone, + ast::Access::kReadWrite, "")) { + return false; + } + ScopedParen sp(out); + for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) { + if (i != 0) { + out << ", "; + } + if (!EmitValue(out, mat->type(), value)) { + return false; + } + } + return true; + }, + [&](const sem::Struct*) { + out << "("; + TINT_DEFER(out << ")" << value); + return EmitType(out, type, ast::StorageClass::kNone, + ast::Access::kUndefined, ""); + }, + [&](const sem::Array*) { + out << "("; + TINT_DEFER(out << ")" << value); + return EmitType(out, type, ast::StorageClass::kNone, + ast::Access::kUndefined, ""); + }, + [&](Default) { + diagnostics_.add_error( + diag::System::Writer, + "Invalid type for value emission: " + type->type_name()); return false; - } - } - } else if (auto* mat = type->As()) { - if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, - "")) { - return false; - } - ScopedParen sp(out); - for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) { - if (i != 0) { - out << ", "; - } - if (!EmitValue(out, mat->type(), value)) { - return false; - } - } - } else if (type->IsAnyOf()) { - out << "("; - if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, - "")) { - return false; - } - out << ")" << value; - } else { - diagnostics_.add_error( - diag::System::Writer, - "Invalid type for value emission: " + type->type_name()); - return false; - } - return true; + }); } bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { @@ -3375,56 +3411,59 @@ bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) { } bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { - if (auto* a = stmt->As()) { - return EmitAssign(a); - } - if (auto* b = stmt->As()) { - return EmitBlock(b); - } - if (auto* b = stmt->As()) { - return EmitBreak(b); - } - if (auto* c = stmt->As()) { - auto out = line(); - if (!EmitCall(out, c->expr)) { - return false; - } - out << ";"; - return true; - } - if (auto* c = stmt->As()) { - return EmitContinue(c); - } - if (auto* d = stmt->As()) { - return EmitDiscard(d); - } - if (stmt->As()) { - line() << "/* fallthrough */"; - return true; - } - if (auto* i = stmt->As()) { - return EmitIf(i); - } - if (auto* l = stmt->As()) { - return EmitLoop(l); - } - if (auto* l = stmt->As()) { - return EmitForLoop(l); - } - if (auto* r = stmt->As()) { - return EmitReturn(r); - } - if (auto* s = stmt->As()) { - return EmitSwitch(s); - } - if (auto* v = stmt->As()) { - return EmitVariable(v->variable); - } - - diagnostics_.add_error( - diag::System::Writer, - "unknown statement type: " + std::string(stmt->TypeInfo().name)); - return false; + return Switch( + stmt, + [&](const ast::AssignmentStatement* a) { // + return EmitAssign(a); + }, + [&](const ast::BlockStatement* b) { // + return EmitBlock(b); + }, + [&](const ast::BreakStatement* b) { // + return EmitBreak(b); + }, + [&](const ast::CallStatement* c) { // + auto out = line(); + if (!EmitCall(out, c->expr)) { + return false; + } + out << ";"; + return true; + }, + [&](const ast::ContinueStatement* c) { // + return EmitContinue(c); + }, + [&](const ast::DiscardStatement* d) { // + return EmitDiscard(d); + }, + [&](const ast::FallthroughStatement*) { // + line() << "/* fallthrough */"; + return true; + }, + [&](const ast::IfStatement* i) { // + return EmitIf(i); + }, + [&](const ast::LoopStatement* l) { // + return EmitLoop(l); + }, + [&](const ast::ForLoopStatement* l) { // + return EmitForLoop(l); + }, + [&](const ast::ReturnStatement* r) { // + return EmitReturn(r); + }, + [&](const ast::SwitchStatement* s) { // + return EmitSwitch(s); + }, + [&](const ast::VariableDeclStatement* v) { // + return EmitVariable(v->variable); + }, + [&](Default) { // + diagnostics_.add_error( + diag::System::Writer, + "unknown statement type: " + std::string(stmt->TypeInfo().name)); + return false; + }); } bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) { @@ -3516,156 +3555,181 @@ bool GeneratorImpl::EmitType(std::ostream& out, break; } - if (auto* ary = type->As()) { - const sem::Type* base_type = ary; - std::vector sizes; - while (auto* arr = base_type->As()) { - if (arr->IsRuntimeSized()) { + return Switch( + type, + [&](const sem::Array* ary) { + const sem::Type* base_type = ary; + std::vector sizes; + while (auto* arr = base_type->As()) { + if (arr->IsRuntimeSized()) { + TINT_ICE(Writer, diagnostics_) + << "Runtime arrays may only exist in storage buffers, which " + "should " + "have been transformed into a ByteAddressBuffer"; + return false; + } + sizes.push_back(arr->Count()); + base_type = arr->ElemType(); + } + if (!EmitType(out, base_type, storage_class, access, "")) { + return false; + } + if (!name.empty()) { + out << " " << name; + if (name_printed) { + *name_printed = true; + } + } + for (uint32_t size : sizes) { + out << "[" << size << "]"; + } + return true; + }, + [&](const sem::Bool*) { + out << "bool"; + return true; + }, + [&](const sem::F32*) { + out << "float"; + return true; + }, + [&](const sem::I32*) { + out << "int"; + return true; + }, + [&](const sem::Matrix* mat) { + if (!EmitType(out, mat->type(), storage_class, access, "")) { + return false; + } + // Note: HLSL's matrices are declared as NxM, where N is the + // number of rows and M is the number of columns. Despite HLSL's + // matrices being column-major by default, the index operator and + // constructors actually operate on row-vectors, where as WGSL operates + // on column vectors. To simplify everything we use the transpose of the + // matrices. See: + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering + out << mat->columns() << "x" << mat->rows(); + return true; + }, + [&](const sem::Pointer*) { TINT_ICE(Writer, diagnostics_) - << "Runtime arrays may only exist in storage buffers, which should " - "have been transformed into a ByteAddressBuffer"; + << "Attempting to emit pointer type. These should have been " + "removed with the InlinePointerLets transform"; return false; - } - sizes.push_back(arr->Count()); - base_type = arr->ElemType(); - } - if (!EmitType(out, base_type, storage_class, access, "")) { - return false; - } - if (!name.empty()) { - out << " " << name; - if (name_printed) { - *name_printed = true; - } - } - for (uint32_t size : sizes) { - out << "[" << size << "]"; - } - } else if (type->Is()) { - out << "bool"; - } else if (type->Is()) { - out << "float"; - } else if (type->Is()) { - out << "int"; - } else if (auto* mat = type->As()) { - if (!EmitType(out, mat->type(), storage_class, access, "")) { - return false; - } - // Note: HLSL's matrices are declared as NxM, where N is the number of - // rows and M is the number of columns. Despite HLSL's matrices being - // column-major by default, the index operator and constructors actually - // operate on row-vectors, where as WGSL operates on column vectors. - // To simplify everything we use the transpose of the matrices. - // See: - // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering - out << mat->columns() << "x" << mat->rows(); - } else if (type->Is()) { - TINT_ICE(Writer, diagnostics_) - << "Attempting to emit pointer type. These should have been removed " - "with the InlinePointerLets transform"; - return false; - } else if (auto* sampler = type->As()) { - out << "Sampler"; - if (sampler->IsComparison()) { - out << "Comparison"; - } - out << "State"; - } else if (auto* str = type->As()) { - out << StructName(str); - } else if (auto* tex = type->As()) { - auto* storage = tex->As(); - auto* ms = tex->As(); - auto* depth_ms = tex->As(); - auto* sampled = tex->As(); + }, + [&](const sem::Sampler* sampler) { + out << "Sampler"; + if (sampler->IsComparison()) { + out << "Comparison"; + } + out << "State"; + return true; + }, + [&](const sem::Struct* str) { + out << StructName(str); + return true; + }, + [&](const sem::Texture* tex) { + auto* storage = tex->As(); + auto* ms = tex->As(); + auto* depth_ms = tex->As(); + auto* sampled = tex->As(); - if (storage && storage->access() != ast::Access::kRead) { - out << "RW"; - } - out << "Texture"; + if (storage && storage->access() != ast::Access::kRead) { + out << "RW"; + } + out << "Texture"; - switch (tex->dim()) { - case ast::TextureDimension::k1d: - out << "1D"; - break; - case ast::TextureDimension::k2d: - out << ((ms || depth_ms) ? "2DMS" : "2D"); - break; - case ast::TextureDimension::k2dArray: - out << ((ms || depth_ms) ? "2DMSArray" : "2DArray"); - break; - case ast::TextureDimension::k3d: - out << "3D"; - break; - case ast::TextureDimension::kCube: - out << "Cube"; - break; - case ast::TextureDimension::kCubeArray: - out << "CubeArray"; - break; - default: - TINT_UNREACHABLE(Writer, diagnostics_) - << "unexpected TextureDimension " << tex->dim(); - return false; - } + switch (tex->dim()) { + case ast::TextureDimension::k1d: + out << "1D"; + break; + case ast::TextureDimension::k2d: + out << ((ms || depth_ms) ? "2DMS" : "2D"); + break; + case ast::TextureDimension::k2dArray: + out << ((ms || depth_ms) ? "2DMSArray" : "2DArray"); + break; + case ast::TextureDimension::k3d: + out << "3D"; + break; + case ast::TextureDimension::kCube: + out << "Cube"; + break; + case ast::TextureDimension::kCubeArray: + out << "CubeArray"; + break; + default: + TINT_UNREACHABLE(Writer, diagnostics_) + << "unexpected TextureDimension " << tex->dim(); + return false; + } - if (storage) { - auto* component = image_format_to_rwtexture_type(storage->texel_format()); - if (component == nullptr) { - TINT_ICE(Writer, diagnostics_) - << "Unsupported StorageTexture TexelFormat: " - << static_cast(storage->texel_format()); + if (storage) { + auto* component = + image_format_to_rwtexture_type(storage->texel_format()); + if (component == nullptr) { + TINT_ICE(Writer, diagnostics_) + << "Unsupported StorageTexture TexelFormat: " + << static_cast(storage->texel_format()); + return false; + } + out << "<" << component << ">"; + } else if (depth_ms) { + out << ""; + } else if (sampled || ms) { + auto* subtype = sampled ? sampled->type() : ms->type(); + out << "<"; + if (subtype->Is()) { + out << "float4"; + } else if (subtype->Is()) { + out << "int4"; + } else if (subtype->Is()) { + out << "uint4"; + } else { + TINT_ICE(Writer, diagnostics_) + << "Unsupported multisampled texture type"; + return false; + } + out << ">"; + } + return true; + }, + [&](const sem::U32*) { + out << "uint"; + return true; + }, + [&](const sem::Vector* vec) { + auto width = vec->Width(); + if (vec->type()->Is() && width >= 1 && width <= 4) { + out << "float" << width; + } else if (vec->type()->Is() && width >= 1 && width <= 4) { + out << "int" << width; + } else if (vec->type()->Is() && width >= 1 && width <= 4) { + out << "uint" << width; + } else if (vec->type()->Is() && width >= 1 && width <= 4) { + out << "bool" << width; + } else { + out << "vector<"; + if (!EmitType(out, vec->type(), storage_class, access, "")) { + return false; + } + out << ", " << width << ">"; + } + return true; + }, + [&](const sem::Atomic* atomic) { + return EmitType(out, atomic->Type(), storage_class, access, name); + }, + [&](const sem::Void*) { + out << "void"; + return true; + }, + [&](Default) { + diagnostics_.add_error(diag::System::Writer, + "unknown type in EmitType"); return false; - } - out << "<" << component << ">"; - } else if (depth_ms) { - out << ""; - } else if (sampled || ms) { - auto* subtype = sampled ? sampled->type() : ms->type(); - out << "<"; - if (subtype->Is()) { - out << "float4"; - } else if (subtype->Is()) { - out << "int4"; - } else if (subtype->Is()) { - out << "uint4"; - } else { - TINT_ICE(Writer, diagnostics_) - << "Unsupported multisampled texture type"; - return false; - } - out << ">"; - } - } else if (type->Is()) { - out << "uint"; - } else if (auto* vec = type->As()) { - auto width = vec->Width(); - if (vec->type()->Is() && width >= 1 && width <= 4) { - out << "float" << width; - } else if (vec->type()->Is() && width >= 1 && width <= 4) { - out << "int" << width; - } else if (vec->type()->Is() && width >= 1 && width <= 4) { - out << "uint" << width; - } else if (vec->type()->Is() && width >= 1 && width <= 4) { - out << "bool" << width; - } else { - out << "vector<"; - if (!EmitType(out, vec->type(), storage_class, access, "")) { - return false; - } - out << ", " << width << ">"; - } - } else if (auto* atomic = type->As()) { - if (!EmitType(out, atomic->Type(), storage_class, access, name)) { - return false; - } - } else if (type->Is()) { - out << "void"; - } else { - diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType"); - return false; - } - - return true; + }); } bool GeneratorImpl::EmitTypeAndName(std::ostream& out, diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index ec2d748dbd..129e9692a4 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -538,23 +538,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) { auto* call = program_->Sem().Get(expr); auto* target = call->Target(); - - if (auto* func = target->As()) { - return EmitFunctionCall(out, call, func); - } - if (auto* builtin = target->As()) { - return EmitBuiltinCall(out, call, builtin); - } - if (auto* conv = target->As()) { - return EmitTypeConversion(out, call, conv); - } - if (auto* ctor = target->As()) { - return EmitTypeConstructor(out, call, ctor); - } - - TINT_ICE(Writer, diagnostics_) - << "unhandled call target: " << target->TypeInfo().name; - return false; + return Switch( + target, + [&](const sem::Function* func) { + return EmitFunctionCall(out, call, func); + }, + [&](const sem::Builtin* builtin) { + return EmitBuiltinCall(out, call, builtin); + }, + [&](const sem::TypeConversion* conv) { + return EmitTypeConversion(out, call, conv); + }, + [&](const sem::TypeConstructor* ctor) { + return EmitTypeConstructor(out, call, ctor); + }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "unhandled call target: " << target->TypeInfo().name; + return false; + }); } bool GeneratorImpl::EmitFunctionCall(std::ostream& out, @@ -1476,106 +1478,128 @@ bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) { } bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { - if (type->Is()) { - out << "false"; - } else if (type->Is()) { - out << "0.0f"; - } else if (type->Is()) { - out << "0"; - } else if (type->Is()) { - out << "0u"; - } else if (auto* vec = type->As()) { - return EmitZeroValue(out, vec->type()); - } else if (auto* mat = type->As()) { - if (!EmitType(out, mat, "")) { - return false; - } - out << "("; - if (!EmitZeroValue(out, mat->type())) { - return false; - } - out << ")"; - } else if (auto* arr = type->As()) { - out << "{"; - if (!EmitZeroValue(out, arr->ElemType())) { - return false; - } - out << "}"; - } else if (type->As()) { - out << "{}"; - } else { - diagnostics_.add_error( - diag::System::Writer, - "Invalid type for zero emission: " + type->type_name()); - return false; - } - return true; + return Switch( + type, + [&](const sem::Bool*) { + out << "false"; + return true; + }, + [&](const sem::F32*) { + out << "0.0f"; + return true; + }, + [&](const sem::I32*) { + out << "0"; + return true; + }, + [&](const sem::U32*) { + out << "0u"; + return true; + }, + [&](const sem::Vector* vec) { // + return EmitZeroValue(out, vec->type()); + }, + [&](const sem::Matrix* mat) { + if (!EmitType(out, mat, "")) { + return false; + } + out << "("; + TINT_DEFER(out << ")"); + return EmitZeroValue(out, mat->type()); + }, + [&](const sem::Array* arr) { + out << "{"; + TINT_DEFER(out << "}"); + return EmitZeroValue(out, arr->ElemType()); + }, + [&](const sem::Struct*) { + out << "{}"; + return true; + }, + [&](Default) { + diagnostics_.add_error( + diag::System::Writer, + "Invalid type for zero emission: " + type->type_name()); + return false; + }); } bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* lit) { - if (auto* l = lit->As()) { - out << (l->value ? "true" : "false"); - } else if (auto* fl = lit->As()) { - if (std::isinf(fl->value)) { - out << (fl->value >= 0 ? "INFINITY" : "-INFINITY"); - } else if (std::isnan(fl->value)) { - out << "NAN"; - } else { - out << FloatToString(fl->value) << "f"; - } - } else if (auto* sl = lit->As()) { - // MSL (and C++) parse `-2147483648` as a `long` because it parses unary - // minus and `2147483648` as separate tokens, and the latter doesn't - // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid - // issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which - // ensures the expression type is `int`. - const auto int_min = std::numeric_limits::min(); - if (sl->ValueAsI32() == int_min) { - out << "(" << int_min + 1 << " - 1)"; - } else { - out << sl->value; - } - } else if (auto* ul = lit->As()) { - out << ul->value << "u"; - } else { - diagnostics_.add_error(diag::System::Writer, "unknown literal type"); - return false; - } - return true; + return Switch( + lit, + [&](const ast::BoolLiteralExpression* l) { + out << (l->value ? "true" : "false"); + return true; + }, + [&](const ast::FloatLiteralExpression* l) { + if (std::isinf(l->value)) { + out << (l->value >= 0 ? "INFINITY" : "-INFINITY"); + } else if (std::isnan(l->value)) { + out << "NAN"; + } else { + out << FloatToString(l->value) << "f"; + } + return true; + }, + [&](const ast::SintLiteralExpression* l) { + // MSL (and C++) parse `-2147483648` as a `long` because it parses unary + // minus and `2147483648` as separate tokens, and the latter doesn't + // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To + // avoid issues with `long` to `int` casts, emit `(2147483647 - 1)` + // instead, which ensures the expression type is `int`. + const auto int_min = std::numeric_limits::min(); + if (l->ValueAsI32() == int_min) { + out << "(" << int_min + 1 << " - 1)"; + } else { + out << l->value; + } + return true; + }, + [&](const ast::UintLiteralExpression* l) { + out << l->value << "u"; + return true; + }, + [&](Default) { + diagnostics_.add_error(diag::System::Writer, "unknown literal type"); + return false; + }); } bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) { - if (auto* a = expr->As()) { - return EmitIndexAccessor(out, a); - } - if (auto* b = expr->As()) { - return EmitBinary(out, b); - } - if (auto* b = expr->As()) { - return EmitBitcast(out, b); - } - if (auto* c = expr->As()) { - return EmitCall(out, c); - } - if (auto* i = expr->As()) { - return EmitIdentifier(out, i); - } - if (auto* l = expr->As()) { - return EmitLiteral(out, l); - } - if (auto* m = expr->As()) { - return EmitMemberAccessor(out, m); - } - if (auto* u = expr->As()) { - return EmitUnaryOp(out, u); - } - - diagnostics_.add_error( - diag::System::Writer, - "unknown expression type: " + std::string(expr->TypeInfo().name)); - return false; + return Switch( + expr, + [&](const ast::IndexAccessorExpression* a) { // + return EmitIndexAccessor(out, a); + }, + [&](const ast::BinaryExpression* b) { // + return EmitBinary(out, b); + }, + [&](const ast::BitcastExpression* b) { // + return EmitBitcast(out, b); + }, + [&](const ast::CallExpression* c) { // + return EmitCall(out, c); + }, + [&](const ast::IdentifierExpression* i) { // + return EmitIdentifier(out, i); + }, + [&](const ast::LiteralExpression* l) { // + return EmitLiteral(out, l); + }, + [&](const ast::MemberAccessorExpression* m) { // + return EmitMemberAccessor(out, m); + }, + [&](const ast::UnaryOpExpression* u) { // + return EmitUnaryOp(out, u); + }, + [&](Default) { // + diagnostics_.add_error( + diag::System::Writer, + "unknown expression type: " + std::string(expr->TypeInfo().name)); + return false; + }); } void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) { @@ -2106,57 +2130,60 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) { } bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { - if (auto* a = stmt->As()) { - return EmitAssign(a); - } - if (auto* b = stmt->As()) { - return EmitBlock(b); - } - if (auto* b = stmt->As()) { - return EmitBreak(b); - } - if (auto* c = stmt->As()) { - auto out = line(); - if (!EmitCall(out, c->expr)) { - return false; - } - out << ";"; - return true; - } - if (auto* c = stmt->As()) { - return EmitContinue(c); - } - if (auto* d = stmt->As()) { - return EmitDiscard(d); - } - if (stmt->As()) { - line() << "/* fallthrough */"; - return true; - } - if (auto* i = stmt->As()) { - return EmitIf(i); - } - if (auto* l = stmt->As()) { - return EmitLoop(l); - } - if (auto* l = stmt->As()) { - return EmitForLoop(l); - } - if (auto* r = stmt->As()) { - return EmitReturn(r); - } - if (auto* s = stmt->As()) { - return EmitSwitch(s); - } - if (auto* v = stmt->As()) { - auto* var = program_->Sem().Get(v->variable); - return EmitVariable(var); - } - - diagnostics_.add_error( - diag::System::Writer, - "unknown statement type: " + std::string(stmt->TypeInfo().name)); - return false; + return Switch( + stmt, + [&](const ast::AssignmentStatement* a) { // + return EmitAssign(a); + }, + [&](const ast::BlockStatement* b) { // + return EmitBlock(b); + }, + [&](const ast::BreakStatement* b) { // + return EmitBreak(b); + }, + [&](const ast::CallStatement* c) { // + auto out = line(); + if (!EmitCall(out, c->expr)) { // + return false; + } + out << ";"; + return true; + }, + [&](const ast::ContinueStatement* c) { // + return EmitContinue(c); + }, + [&](const ast::DiscardStatement* d) { // + return EmitDiscard(d); + }, + [&](const ast::FallthroughStatement*) { // + line() << "/* fallthrough */"; + return true; + }, + [&](const ast::IfStatement* i) { // + return EmitIf(i); + }, + [&](const ast::LoopStatement* l) { // + return EmitLoop(l); + }, + [&](const ast::ForLoopStatement* l) { // + return EmitForLoop(l); + }, + [&](const ast::ReturnStatement* r) { // + return EmitReturn(r); + }, + [&](const ast::SwitchStatement* s) { // + return EmitSwitch(s); + }, + [&](const ast::VariableDeclStatement* v) { // + auto* var = program_->Sem().Get(v->variable); + return EmitVariable(var); + }, + [&](Default) { + diagnostics_.add_error( + diag::System::Writer, + "unknown statement type: " + std::string(stmt->TypeInfo().name)); + return false; + }); } bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) { @@ -2204,203 +2231,210 @@ bool GeneratorImpl::EmitType(std::ostream& out, if (name_printed) { *name_printed = false; } - if (auto* atomic = type->As()) { - if (atomic->Type()->Is()) { - out << "atomic_int"; - return true; - } - if (atomic->Type()->Is()) { - out << "atomic_uint"; - return true; - } - TINT_ICE(Writer, diagnostics_) - << "unhandled atomic type " << atomic->Type()->type_name(); - return false; - } - if (auto* ary = type->As()) { - const sem::Type* base_type = ary; - std::vector sizes; - while (auto* arr = base_type->As()) { - if (arr->IsRuntimeSized()) { - sizes.push_back(1); - } else { - sizes.push_back(arr->Count()); - } - base_type = arr->ElemType(); - } - if (!EmitType(out, base_type, "")) { - return false; - } - if (!name.empty()) { - out << " " << name; - if (name_printed) { - *name_printed = true; - } - } - for (uint32_t size : sizes) { - out << "[" << size << "]"; - } - return true; - } - - if (type->Is()) { - out << "bool"; - return true; - } - - if (type->Is()) { - out << "float"; - return true; - } - - if (type->Is()) { - out << "int"; - return true; - } - - if (auto* mat = type->As()) { - if (!EmitType(out, mat->type(), "")) { - return false; - } - out << mat->columns() << "x" << mat->rows(); - return true; - } - - if (auto* ptr = type->As()) { - if (ptr->Access() == ast::Access::kRead) { - out << "const "; - } - if (!EmitStorageClass(out, ptr->StorageClass())) { - return false; - } - out << " "; - if (ptr->StoreType()->Is()) { - std::string inner = "(*" + name + ")"; - if (!EmitType(out, ptr->StoreType(), inner)) { + return Switch( + type, + [&](const sem::Atomic* atomic) { + if (atomic->Type()->Is()) { + out << "atomic_int"; + return true; + } + if (atomic->Type()->Is()) { + out << "atomic_uint"; + return true; + } + TINT_ICE(Writer, diagnostics_) + << "unhandled atomic type " << atomic->Type()->type_name(); return false; - } - if (name_printed) { - *name_printed = true; - } - } else { - if (!EmitType(out, ptr->StoreType(), "")) { + }, + [&](const sem::Array* ary) { + const sem::Type* base_type = ary; + std::vector sizes; + while (auto* arr = base_type->As()) { + if (arr->IsRuntimeSized()) { + sizes.push_back(1); + } else { + sizes.push_back(arr->Count()); + } + base_type = arr->ElemType(); + } + if (!EmitType(out, base_type, "")) { + return false; + } + if (!name.empty()) { + out << " " << name; + if (name_printed) { + *name_printed = true; + } + } + for (uint32_t size : sizes) { + out << "[" << size << "]"; + } + return true; + }, + [&](const sem::Bool*) { + out << "bool"; + return true; + }, + [&](const sem::F32*) { + out << "float"; + return true; + }, + [&](const sem::I32*) { + out << "int"; + return true; + }, + [&](const sem::Matrix* mat) { + if (!EmitType(out, mat->type(), "")) { + return false; + } + out << mat->columns() << "x" << mat->rows(); + return true; + }, + [&](const sem::Pointer* ptr) { + if (ptr->Access() == ast::Access::kRead) { + out << "const "; + } + if (!EmitStorageClass(out, ptr->StorageClass())) { + return false; + } + out << " "; + if (ptr->StoreType()->Is()) { + std::string inner = "(*" + name + ")"; + if (!EmitType(out, ptr->StoreType(), inner)) { + return false; + } + if (name_printed) { + *name_printed = true; + } + } else { + if (!EmitType(out, ptr->StoreType(), "")) { + return false; + } + out << "* " << name; + if (name_printed) { + *name_printed = true; + } + } + return true; + }, + [&](const sem::Sampler*) { + out << "sampler"; + return true; + }, + [&](const sem::Struct* str) { + // The struct type emits as just the name. The declaration would be + // emitted as part of emitting the declared types. + out << StructName(str); + return true; + }, + [&](const sem::Texture* tex) { + if (tex->IsAnyOf()) { + out << "depth"; + } else { + out << "texture"; + } + + switch (tex->dim()) { + case ast::TextureDimension::k1d: + out << "1d"; + break; + case ast::TextureDimension::k2d: + out << "2d"; + break; + case ast::TextureDimension::k2dArray: + out << "2d_array"; + break; + case ast::TextureDimension::k3d: + out << "3d"; + break; + case ast::TextureDimension::kCube: + out << "cube"; + break; + case ast::TextureDimension::kCubeArray: + out << "cube_array"; + break; + default: + diagnostics_.add_error(diag::System::Writer, + "Invalid texture dimensions"); + return false; + } + if (tex->IsAnyOf()) { + out << "_ms"; + } + out << "<"; + TINT_DEFER(out << ">"); + + return Switch( + tex, + [&](const sem::DepthTexture*) { + out << "float, access::sample"; + return true; + }, + [&](const sem::DepthMultisampledTexture*) { + out << "float, access::read"; + return true; + }, + [&](const sem::StorageTexture* storage) { + if (!EmitType(out, storage->type(), "")) { + return false; + } + + std::string access_str; + if (storage->access() == ast::Access::kRead) { + out << ", access::read"; + } else if (storage->access() == ast::Access::kWrite) { + out << ", access::write"; + } else { + diagnostics_.add_error( + diag::System::Writer, + "Invalid access control for storage texture"); + return false; + } + return true; + }, + [&](const sem::MultisampledTexture* ms) { + if (!EmitType(out, ms->type(), "")) { + return false; + } + out << ", access::read"; + return true; + }, + [&](const sem::SampledTexture* sampled) { + if (!EmitType(out, sampled->type(), "")) { + return false; + } + out << ", access::sample"; + return true; + }, + [&](Default) { + diagnostics_.add_error(diag::System::Writer, + "invalid texture type"); + return false; + }); + }, + [&](const sem::U32*) { + out << "uint"; + return true; + }, + [&](const sem::Vector* vec) { + if (!EmitType(out, vec->type(), "")) { + return false; + } + out << vec->Width(); + return true; + }, + [&](const sem::Void*) { + out << "void"; + return true; + }, + [&](Default) { + diagnostics_.add_error( + diag::System::Writer, + "unknown type in EmitType: " + type->type_name()); return false; - } - out << "* " << name; - if (name_printed) { - *name_printed = true; - } - } - return true; - } - - if (type->Is()) { - out << "sampler"; - return true; - } - - if (auto* str = type->As()) { - // The struct type emits as just the name. The declaration would be emitted - // as part of emitting the declared types. - out << StructName(str); - return true; - } - - if (auto* tex = type->As()) { - if (tex->IsAnyOf()) { - out << "depth"; - } else { - out << "texture"; - } - - switch (tex->dim()) { - case ast::TextureDimension::k1d: - out << "1d"; - break; - case ast::TextureDimension::k2d: - out << "2d"; - break; - case ast::TextureDimension::k2dArray: - out << "2d_array"; - break; - case ast::TextureDimension::k3d: - out << "3d"; - break; - case ast::TextureDimension::kCube: - out << "cube"; - break; - case ast::TextureDimension::kCubeArray: - out << "cube_array"; - break; - default: - diagnostics_.add_error(diag::System::Writer, - "Invalid texture dimensions"); - return false; - } - if (tex->IsAnyOf()) { - out << "_ms"; - } - out << "<"; - if (tex->Is()) { - out << "float, access::sample"; - } else if (tex->Is()) { - out << "float, access::read"; - } else if (auto* storage = tex->As()) { - if (!EmitType(out, storage->type(), "")) { - return false; - } - - std::string access_str; - if (storage->access() == ast::Access::kRead) { - out << ", access::read"; - } else if (storage->access() == ast::Access::kWrite) { - out << ", access::write"; - } else { - diagnostics_.add_error(diag::System::Writer, - "Invalid access control for storage texture"); - return false; - } - } else if (auto* ms = tex->As()) { - if (!EmitType(out, ms->type(), "")) { - return false; - } - out << ", access::read"; - } else if (auto* sampled = tex->As()) { - if (!EmitType(out, sampled->type(), "")) { - return false; - } - out << ", access::sample"; - } else { - diagnostics_.add_error(diag::System::Writer, "invalid texture type"); - return false; - } - out << ">"; - return true; - } - - if (type->Is()) { - out << "uint"; - return true; - } - - if (auto* vec = type->As()) { - if (!EmitType(out, vec->type(), "")) { - return false; - } - out << vec->Width(); - return true; - } - - if (type->Is()) { - out << "void"; - return true; - } - - diagnostics_.add_error(diag::System::Writer, - "unknown type in EmitType: " + type->type_name()); - return false; + }); } bool GeneratorImpl::EmitTypeAndName(std::ostream& out, @@ -2542,55 +2576,72 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) { // Emit attributes if (auto* decl = mem->Declaration()) { for (auto* attr : decl->attributes) { - if (auto* builtin = attr->As()) { - auto name = builtin_to_attribute(builtin->builtin); - if (name.empty()) { - diagnostics_.add_error(diag::System::Writer, "unknown builtin"); - return false; - } - out << " [[" << name << "]]"; - } else if (auto* loc = attr->As()) { - auto& pipeline_stage_uses = str->PipelineStageUses(); - if (pipeline_stage_uses.size() != 1) { - TINT_ICE(Writer, diagnostics_) - << "invalid entry point IO struct uses"; - } + bool ok = Switch( + attr, + [&](const ast::BuiltinAttribute* builtin) { + auto name = builtin_to_attribute(builtin->builtin); + if (name.empty()) { + diagnostics_.add_error(diag::System::Writer, "unknown builtin"); + return false; + } + out << " [[" << name << "]]"; + return true; + }, + [&](const ast::LocationAttribute* loc) { + auto& pipeline_stage_uses = str->PipelineStageUses(); + if (pipeline_stage_uses.size() != 1) { + TINT_ICE(Writer, diagnostics_) + << "invalid entry point IO struct uses"; + return false; + } - if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kVertexInput)) { - out << " [[attribute(" + std::to_string(loc->value) + ")]]"; - } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kVertexOutput)) { - out << " [[user(locn" + std::to_string(loc->value) + ")]]"; - } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kFragmentInput)) { - out << " [[user(locn" + std::to_string(loc->value) + ")]]"; - } else if (pipeline_stage_uses.count( - sem::PipelineStageUsage::kFragmentOutput)) { - out << " [[color(" + std::to_string(loc->value) + ")]]"; - } else { - TINT_ICE(Writer, diagnostics_) - << "invalid use of location attribute"; - } - } else if (auto* interpolate = attr->As()) { - auto name = interpolation_to_attribute(interpolate->type, - interpolate->sampling); - if (name.empty()) { - diagnostics_.add_error(diag::System::Writer, - "unknown interpolation attribute"); - return false; - } - out << " [[" << name << "]]"; - } else if (attr->Is()) { - if (invariant_define_name_.empty()) { - invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT"); - } - out << " " << invariant_define_name_; - } else if (!attr->IsAnyOf()) { - TINT_ICE(Writer, diagnostics_) - << "unhandled struct member attribute: " << attr->Name(); + if (pipeline_stage_uses.count( + sem::PipelineStageUsage::kVertexInput)) { + out << " [[attribute(" + std::to_string(loc->value) + ")]]"; + } else if (pipeline_stage_uses.count( + sem::PipelineStageUsage::kVertexOutput)) { + out << " [[user(locn" + std::to_string(loc->value) + ")]]"; + } else if (pipeline_stage_uses.count( + sem::PipelineStageUsage::kFragmentInput)) { + out << " [[user(locn" + std::to_string(loc->value) + ")]]"; + } else if (pipeline_stage_uses.count( + sem::PipelineStageUsage::kFragmentOutput)) { + out << " [[color(" + std::to_string(loc->value) + ")]]"; + } else { + TINT_ICE(Writer, diagnostics_) + << "invalid use of location decoration"; + return false; + } + return true; + }, + [&](const ast::InterpolateAttribute* interpolate) { + auto name = interpolation_to_attribute(interpolate->type, + interpolate->sampling); + if (name.empty()) { + diagnostics_.add_error(diag::System::Writer, + "unknown interpolation attribute"); + return false; + } + out << " [[" << name << "]]"; + return true; + }, + [&](const ast::InvariantAttribute*) { + if (invariant_define_name_.empty()) { + invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT"); + } + out << " " << invariant_define_name_; + return true; + }, + [&](const ast::StructMemberOffsetAttribute*) { return true; }, + [&](const ast::StructMemberAlignAttribute*) { return true; }, + [&](const ast::StructMemberSizeAttribute*) { return true; }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "unhandled struct member attribute: " << attr->Name(); + return false; + }); + if (!ok) { + return false; } } } @@ -2796,77 +2847,96 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign( const sem::Type* ty) { - if (ty->IsAnyOf()) { - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // 2.1 Scalar Data Types - return {4, 4}; - } + return Switch( + ty, - if (auto* vec = ty->As()) { - auto num_els = vec->Width(); - auto* el_ty = vec->type(); - if (el_ty->IsAnyOf()) { - // Use a packed_vec type for 3-element vectors only. - if (num_els == 3) { + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // 2.1 Scalar Data Types + [&](const sem::U32*) { + return SizeAndAlign{4, 4}; + }, + [&](const sem::I32*) { + return SizeAndAlign{4, 4}; + }, + [&](const sem::F32*) { + return SizeAndAlign{4, 4}; + }, + + [&](const sem::Vector* vec) { + auto num_els = vec->Width(); + auto* el_ty = vec->type(); + if (el_ty->IsAnyOf()) { + // Use a packed_vec type for 3-element vectors only. + if (num_els == 3) { + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // 2.2.3 Packed Vector Types + return SizeAndAlign{num_els * 4, 4}; + } else { + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // 2.2 Vector Data Types + return SizeAndAlign{num_els * 4, num_els * 4}; + } + } + TINT_UNREACHABLE(Writer, diagnostics_) + << "Unhandled vector element type " << el_ty->TypeInfo().name; + return SizeAndAlign{}; + }, + + [&](const sem::Matrix* mat) { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // 2.2.3 Packed Vector Types - return SizeAndAlign{num_els * 4, 4}; - } else { - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // 2.2 Vector Data Types - return SizeAndAlign{num_els * 4, num_els * 4}; - } - } - } + // 2.3 Matrix Data Types + auto cols = mat->columns(); + auto rows = mat->rows(); + auto* el_ty = mat->type(); + if (el_ty->IsAnyOf()) { + static constexpr SizeAndAlign table[] = { + /* float2x2 */ {16, 8}, + /* float2x3 */ {32, 16}, + /* float2x4 */ {32, 16}, + /* float3x2 */ {24, 8}, + /* float3x3 */ {48, 16}, + /* float3x4 */ {48, 16}, + /* float4x2 */ {32, 8}, + /* float4x3 */ {64, 16}, + /* float4x4 */ {64, 16}, + }; + if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) { + return table[(3 * (cols - 2)) + (rows - 2)]; + } + } - if (auto* mat = ty->As()) { - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // 2.3 Matrix Data Types - auto cols = mat->columns(); - auto rows = mat->rows(); - auto* el_ty = mat->type(); - if (el_ty->IsAnyOf()) { - static constexpr SizeAndAlign table[] = { - /* float2x2 */ {16, 8}, - /* float2x3 */ {32, 16}, - /* float2x4 */ {32, 16}, - /* float3x2 */ {24, 8}, - /* float3x3 */ {48, 16}, - /* float3x4 */ {48, 16}, - /* float4x2 */ {32, 8}, - /* float4x3 */ {64, 16}, - /* float4x4 */ {64, 16}, - }; - if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) { - return table[(3 * (cols - 2)) + (rows - 2)]; - } - } - } + TINT_UNREACHABLE(Writer, diagnostics_) + << "Unhandled matrix element type " << el_ty->TypeInfo().name; + return SizeAndAlign{}; + }, - if (auto* arr = ty->As()) { - if (!arr->IsStrideImplicit()) { - TINT_ICE(Writer, diagnostics_) - << "arrays with explicit strides should have " - "removed with the PadArrayElements transform"; - return {}; - } - auto num_els = std::max(arr->Count(), 1); - return SizeAndAlign{arr->Stride() * num_els, arr->Align()}; - } + [&](const sem::Array* arr) { + if (!arr->IsStrideImplicit()) { + TINT_ICE(Writer, diagnostics_) + << "arrays with explicit strides should have " + "removed with the PadArrayElements transform"; + return SizeAndAlign{}; + } + auto num_els = std::max(arr->Count(), 1); + return SizeAndAlign{arr->Stride() * num_els, arr->Align()}; + }, - if (auto* str = ty->As()) { - // TODO(crbug.com/tint/650): There's an assumption here that MSL's default - // structure size and alignment matches WGSL's. We need to confirm this. - return SizeAndAlign{str->Size(), str->Align()}; - } + [&](const sem::Struct* str) { + // TODO(crbug.com/tint/650): There's an assumption here that MSL's + // default structure size and alignment matches WGSL's. We need to + // confirm this. + return SizeAndAlign{str->Size(), str->Align()}; + }, - if (auto* atomic = ty->As()) { - return MslPackedTypeSizeAndAlign(atomic->Type()); - } + [&](const sem::Atomic* atomic) { + return MslPackedTypeSizeAndAlign(atomic->Type()); + }, - TINT_UNREACHABLE(Writer, diagnostics_) - << "Unhandled type " << ty->TypeInfo().name; - return {}; + [&](Default) { + TINT_UNREACHABLE(Writer, diagnostics_) + << "Unhandled type " << ty->TypeInfo().name; + return SizeAndAlign{}; + }); } template diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 19707c0e32..932ae8c093 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -560,33 +560,37 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) { } uint32_t Builder::GenerateExpression(const ast::Expression* expr) { - if (auto* a = expr->As()) { - return GenerateAccessorExpression(a); - } - if (auto* b = expr->As()) { - return GenerateBinaryExpression(b); - } - if (auto* b = expr->As()) { - return GenerateBitcastExpression(b); - } - if (auto* c = expr->As()) { - return GenerateCallExpression(c); - } - if (auto* i = expr->As()) { - return GenerateIdentifierExpression(i); - } - if (auto* l = expr->As()) { - return GenerateLiteralIfNeeded(nullptr, l); - } - if (auto* m = expr->As()) { - return GenerateAccessorExpression(m); - } - if (auto* u = expr->As()) { - return GenerateUnaryOpExpression(u); - } - - error_ = "unknown expression type: " + std::string(expr->TypeInfo().name); - return 0; + return Switch( + expr, + [&](const ast::IndexAccessorExpression* a) { // + return GenerateAccessorExpression(a); + }, + [&](const ast::BinaryExpression* b) { // + return GenerateBinaryExpression(b); + }, + [&](const ast::BitcastExpression* b) { // + return GenerateBitcastExpression(b); + }, + [&](const ast::CallExpression* c) { // + return GenerateCallExpression(c); + }, + [&](const ast::IdentifierExpression* i) { // + return GenerateIdentifierExpression(i); + }, + [&](const ast::LiteralExpression* l) { // + return GenerateLiteralIfNeeded(nullptr, l); + }, + [&](const ast::MemberAccessorExpression* m) { // + return GenerateAccessorExpression(m); + }, + [&](const ast::UnaryOpExpression* u) { // + return GenerateUnaryOpExpression(u); + }, + [&](Default) -> uint32_t { + error_ = + "unknown expression type: " + std::string(expr->TypeInfo().name); + return 0; + }); } bool Builder::GenerateFunction(const ast::Function* func_ast) { @@ -861,33 +865,56 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) { push_type(spv::Op::OpVariable, std::move(ops)); for (auto* attr : var->attributes) { - if (auto* builtin = attr->As()) { - push_annot(spv::Op::OpDecorate, - {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn), - Operand::Int( - ConvertBuiltin(builtin->builtin, sem->StorageClass()))}); - } else if (auto* location = attr->As()) { - push_annot(spv::Op::OpDecorate, - {Operand::Int(var_id), Operand::Int(SpvDecorationLocation), - Operand::Int(location->value)}); - } else if (auto* interpolate = attr->As()) { - AddInterpolationDecorations(var_id, interpolate->type, - interpolate->sampling); - } else if (attr->Is()) { - push_annot(spv::Op::OpDecorate, - {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)}); - } else if (auto* binding = attr->As()) { - push_annot(spv::Op::OpDecorate, - {Operand::Int(var_id), Operand::Int(SpvDecorationBinding), - Operand::Int(binding->value)}); - } else if (auto* group = attr->As()) { - push_annot(spv::Op::OpDecorate, {Operand::Int(var_id), - Operand::Int(SpvDecorationDescriptorSet), - Operand::Int(group->value)}); - } else if (attr->Is()) { - // Spec constants are handled elsewhere - } else if (!attr->Is()) { - error_ = "unknown attribute"; + bool ok = Switch( + attr, + [&](const ast::BuiltinAttribute* builtin) { + push_annot(spv::Op::OpDecorate, + {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn), + Operand::Int(ConvertBuiltin(builtin->builtin, + sem->StorageClass()))}); + return true; + }, + [&](const ast::LocationAttribute* location) { + push_annot(spv::Op::OpDecorate, + {Operand::Int(var_id), Operand::Int(SpvDecorationLocation), + Operand::Int(location->value)}); + return true; + }, + [&](const ast::InterpolateAttribute* interpolate) { + AddInterpolationDecorations(var_id, interpolate->type, + interpolate->sampling); + return true; + }, + [&](const ast::InvariantAttribute*) { + push_annot( + spv::Op::OpDecorate, + {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)}); + return true; + }, + [&](const ast::BindingAttribute* binding) { + push_annot(spv::Op::OpDecorate, + {Operand::Int(var_id), Operand::Int(SpvDecorationBinding), + Operand::Int(binding->value)}); + return true; + }, + [&](const ast::GroupAttribute* group) { + push_annot( + spv::Op::OpDecorate, + {Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet), + Operand::Int(group->value)}); + return true; + }, + [&](const ast::OverrideAttribute*) { + return true; // Spec constants are handled elsewhere + }, + [&](const ast::InternalAttribute*) { + return true; // ignored + }, + [&](Default) { + error_ = "unknown attribute"; + return false; + }); + if (!ok) { return false; } } @@ -1123,19 +1150,21 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) { // promoted to storage with the VarForDynamicIndex transform. for (auto* accessor : accessors) { - if (auto* array = accessor->As()) { - if (!GenerateIndexAccessor(array, &info)) { - return 0; - } - } else if (auto* member = accessor->As()) { - if (!GenerateMemberAccessor(member, &info)) { - return 0; - } - - } else { - error_ = - "invalid accessor in list: " + std::string(accessor->TypeInfo().name); - return 0; + bool ok = Switch( + accessor, + [&](const ast::IndexAccessorExpression* array) { + return GenerateIndexAccessor(array, &info); + }, + [&](const ast::MemberAccessorExpression* member) { + return GenerateMemberAccessor(member, &info); + }, + [&](Default) { + error_ = "invalid accessor in list: " + + std::string(accessor->TypeInfo().name); + return false; + }); + if (!ok) { + return false; } } @@ -1653,21 +1682,28 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var, constant.constant_id = global->ConstantId(); } - if (auto* l = lit->As()) { - constant.kind = ScalarConstant::Kind::kBool; - constant.value.b = l->value; - } else if (auto* sl = lit->As()) { - constant.kind = ScalarConstant::Kind::kI32; - constant.value.i32 = sl->value; - } else if (auto* ul = lit->As()) { - constant.kind = ScalarConstant::Kind::kU32; - constant.value.u32 = ul->value; - } else if (auto* fl = lit->As()) { - constant.kind = ScalarConstant::Kind::kF32; - constant.value.f32 = fl->value; - } else { - error_ = "unknown literal type"; - return 0; + Switch( + lit, + [&](const ast::BoolLiteralExpression* l) { + constant.kind = ScalarConstant::Kind::kBool; + constant.value.b = l->value; + }, + [&](const ast::SintLiteralExpression* sl) { + constant.kind = ScalarConstant::Kind::kI32; + constant.value.i32 = sl->value; + }, + [&](const ast::UintLiteralExpression* ul) { + constant.kind = ScalarConstant::Kind::kU32; + constant.value.u32 = ul->value; + }, + [&](const ast::FloatLiteralExpression* fl) { + constant.kind = ScalarConstant::Kind::kF32; + constant.value.f32 = fl->value; + }, + [&](Default) { error_ = "unknown literal type"; }); + + if (!error_.empty()) { + return false; } return GenerateConstantIfNeeded(constant); @@ -2209,19 +2245,25 @@ bool Builder::GenerateBlockStatementWithoutScoping( uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) { auto* call = builder_.Sem().Get(expr); auto* target = call->Target(); - - if (auto* func = target->As()) { - return GenerateFunctionCall(call, func); - } - if (auto* builtin = target->As()) { - return GenerateBuiltinCall(call, builtin); - } - if (target->IsAnyOf()) { - return GenerateTypeConstructorOrConversion(call, nullptr); - } - TINT_ICE(Writer, builder_.Diagnostics()) - << "unhandled call target: " << target->TypeInfo().name; - return false; + return Switch( + target, + [&](const sem::Function* func) { + return GenerateFunctionCall(call, func); + }, + [&](const sem::Builtin* builtin) { + return GenerateBuiltinCall(call, builtin); + }, + [&](const sem::TypeConversion*) { + return GenerateTypeConstructorOrConversion(call, nullptr); + }, + [&](const sem::TypeConstructor*) { + return GenerateTypeConstructorOrConversion(call, nullptr); + }, + [&](Default) -> uint32_t { + TINT_ICE(Writer, builder_.Diagnostics()) + << "unhandled call target: " << target->TypeInfo().name; + return 0; + }); } uint32_t Builder::GenerateFunctionCall(const sem::Call* call, @@ -3790,46 +3832,49 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) { } bool Builder::GenerateStatement(const ast::Statement* stmt) { - if (auto* a = stmt->As()) { - return GenerateAssignStatement(a); - } - if (auto* b = stmt->As()) { - return GenerateBlockStatement(b); - } - if (auto* b = stmt->As()) { - return GenerateBreakStatement(b); - } - if (auto* c = stmt->As()) { - return GenerateCallExpression(c->expr) != 0; - } - if (auto* c = stmt->As()) { - return GenerateContinueStatement(c); - } - if (auto* d = stmt->As()) { - return GenerateDiscardStatement(d); - } - if (stmt->Is()) { - // Do nothing here, the fallthrough gets handled by the switch code. - return true; - } - if (auto* i = stmt->As()) { - return GenerateIfStatement(i); - } - if (auto* l = stmt->As()) { - return GenerateLoopStatement(l); - } - if (auto* r = stmt->As()) { - return GenerateReturnStatement(r); - } - if (auto* s = stmt->As()) { - return GenerateSwitchStatement(s); - } - if (auto* v = stmt->As()) { - return GenerateVariableDeclStatement(v); - } - - error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name); - return false; + return Switch( + stmt, + [&](const ast::AssignmentStatement* a) { + return GenerateAssignStatement(a); + }, + [&](const ast::BlockStatement* b) { // + return GenerateBlockStatement(b); + }, + [&](const ast::BreakStatement* b) { // + return GenerateBreakStatement(b); + }, + [&](const ast::CallStatement* c) { + return GenerateCallExpression(c->expr) != 0; + }, + [&](const ast::ContinueStatement* c) { + return GenerateContinueStatement(c); + }, + [&](const ast::DiscardStatement* d) { + return GenerateDiscardStatement(d); + }, + [&](const ast::FallthroughStatement*) { + // Do nothing here, the fallthrough gets handled by the switch code. + return true; + }, + [&](const ast::IfStatement* i) { // + return GenerateIfStatement(i); + }, + [&](const ast::LoopStatement* l) { // + return GenerateLoopStatement(l); + }, + [&](const ast::ReturnStatement* r) { // + return GenerateReturnStatement(r); + }, + [&](const ast::SwitchStatement* s) { // + return GenerateSwitchStatement(s); + }, + [&](const ast::VariableDeclStatement* v) { + return GenerateVariableDeclStatement(v); + }, + [&](Default) { + error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name); + return false; + }); } bool Builder::GenerateVariableDeclStatement( @@ -3872,78 +3917,91 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) { return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t { auto result = result_op(); auto id = result.to_i(); - if (auto* arr = type->As()) { - if (!GenerateArrayType(arr, result)) { - return 0; - } - } else if (type->Is()) { - push_type(spv::Op::OpTypeBool, {result}); - } else if (type->Is()) { - push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)}); - } else if (type->Is()) { - push_type(spv::Op::OpTypeInt, - {result, Operand::Int(32), Operand::Int(1)}); - } else if (auto* mat = type->As()) { - if (!GenerateMatrixType(mat, result)) { - return 0; - } - } else if (auto* ptr = type->As()) { - if (!GeneratePointerType(ptr, result)) { - return 0; - } - } else if (auto* ref = type->As()) { - if (!GenerateReferenceType(ref, result)) { - return 0; - } - } else if (auto* str = type->As()) { - if (!GenerateStructType(str, result)) { - return 0; - } - } else if (type->Is()) { - push_type(spv::Op::OpTypeInt, - {result, Operand::Int(32), Operand::Int(0)}); - } else if (auto* vec = type->As()) { - if (!GenerateVectorType(vec, result)) { - return 0; - } - } else if (type->Is()) { - push_type(spv::Op::OpTypeVoid, {result}); - } else if (auto* tex = type->As()) { - if (!GenerateTextureType(tex, result)) { - return 0; - } + bool ok = Switch( + type, + [&](const sem::Array* arr) { // + return GenerateArrayType(arr, result); + }, + [&](const sem::Bool*) { + push_type(spv::Op::OpTypeBool, {result}); + return true; + }, + [&](const sem::F32*) { + push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)}); + return true; + }, + [&](const sem::I32*) { + push_type(spv::Op::OpTypeInt, + {result, Operand::Int(32), Operand::Int(1)}); + return true; + }, + [&](const sem::Matrix* mat) { // + return GenerateMatrixType(mat, result); + }, + [&](const sem::Pointer* ptr) { // + return GeneratePointerType(ptr, result); + }, + [&](const sem::Reference* ref) { // + return GenerateReferenceType(ref, result); + }, + [&](const sem::Struct* str) { // + return GenerateStructType(str, result); + }, + [&](const sem::U32*) { + push_type(spv::Op::OpTypeInt, + {result, Operand::Int(32), Operand::Int(0)}); + return true; + }, + [&](const sem::Vector* vec) { // + return GenerateVectorType(vec, result); + }, + [&](const sem::Void*) { + push_type(spv::Op::OpTypeVoid, {result}); + return true; + }, + [&](const sem::StorageTexture* tex) { + if (!GenerateTextureType(tex, result)) { + return false; + } - if (auto* st = tex->As()) { - // Register all three access types of StorageTexture names. In SPIR-V, - // we must output a single type, while the variable is annotated with - // the access type. Doing this ensures we de-dupe. - type_name_to_id_[builder_ - .create( - st->dim(), st->texel_format(), - ast::Access::kRead, st->type()) - ->type_name()] = id; - type_name_to_id_[builder_ - .create( - st->dim(), st->texel_format(), - ast::Access::kWrite, st->type()) - ->type_name()] = id; - type_name_to_id_[builder_ - .create( - st->dim(), st->texel_format(), - ast::Access::kReadWrite, st->type()) - ->type_name()] = id; - } + // Register all three access types of StorageTexture names. In + // SPIR-V, we must output a single type, while the variable is + // annotated with the access type. Doing this ensures we de-dupe. + type_name_to_id_[builder_ + .create( + tex->dim(), tex->texel_format(), + ast::Access::kRead, tex->type()) + ->type_name()] = id; + type_name_to_id_[builder_ + .create( + tex->dim(), tex->texel_format(), + ast::Access::kWrite, tex->type()) + ->type_name()] = id; + type_name_to_id_[builder_ + .create( + tex->dim(), tex->texel_format(), + ast::Access::kReadWrite, tex->type()) + ->type_name()] = id; + return true; + }, + [&](const sem::Texture* tex) { + return GenerateTextureType(tex, result); + }, + [&](const sem::Sampler*) { + push_type(spv::Op::OpTypeSampler, {result}); - } else if (type->Is()) { - push_type(spv::Op::OpTypeSampler, {result}); + // Register both of the sampler type names. In SPIR-V they're the same + // sampler type, so we need to match that when we do the dedup check. + type_name_to_id_["__sampler_sampler"] = id; + type_name_to_id_["__sampler_comparison"] = id; + return true; + }, + [&](Default) { + error_ = "unable to convert type: " + type->type_name(); + return false; + }); - // Register both of the sampler type names. In SPIR-V they're the same - // sampler type, so we need to match that when we do the dedup check. - type_name_to_id_["__sampler_sampler"] = id; - type_name_to_id_["__sampler_comparison"] = id; - - } else { - error_ = "unable to convert type: " + type->type_name(); + if (!ok) { return 0; } @@ -3995,22 +4053,31 @@ bool Builder::GenerateTextureType(const sem::Texture* texture, } if (dim == ast::TextureDimension::kCubeArray) { - if (texture->Is() || - texture->Is()) { + if (texture->IsAnyOf()) { push_capability(SpvCapabilitySampledCubeArray); } } - uint32_t type_id = 0u; - if (texture->IsAnyOf()) { - type_id = GenerateTypeIfNeeded(builder_.create()); - } else if (auto* s = texture->As()) { - type_id = GenerateTypeIfNeeded(s->type()); - } else if (auto* ms = texture->As()) { - type_id = GenerateTypeIfNeeded(ms->type()); - } else if (auto* st = texture->As()) { - type_id = GenerateTypeIfNeeded(st->type()); - } + uint32_t type_id = Switch( + texture, + [&](const sem::DepthTexture*) { + return GenerateTypeIfNeeded(builder_.create()); + }, + [&](const sem::DepthMultisampledTexture*) { + return GenerateTypeIfNeeded(builder_.create()); + }, + [&](const sem::SampledTexture* t) { + return GenerateTypeIfNeeded(t->type()); + }, + [&](const sem::MultisampledTexture* t) { + return GenerateTypeIfNeeded(t->type()); + }, + [&](const sem::StorageTexture* t) { + return GenerateTypeIfNeeded(t->type()); + }, + [&](Default) -> uint32_t { // + return 0u; + }); if (type_id == 0u) { return false; } diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 094e88947b..96a459ab96 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -68,23 +68,17 @@ GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate() { // Generate global declarations in the order they appear in the module. for (auto* decl : program_->AST().GlobalDeclarations()) { - if (auto* td = decl->As()) { - if (!EmitTypeDecl(td)) { - return false; - } - } else if (auto* func = decl->As()) { - if (!EmitFunction(func)) { - return false; - } - } else if (auto* var = decl->As()) { - if (!EmitVariable(line(), var)) { - return false; - } - } else { - TINT_UNREACHABLE(Writer, diagnostics_); + if (!Switch( + decl, // + [&](const ast::TypeDecl* td) { return EmitTypeDecl(td); }, + [&](const ast::Function* func) { return EmitFunction(func); }, + [&](const ast::Variable* var) { return EmitVariable(line(), var); }, + [&](Default) { + TINT_UNREACHABLE(Writer, diagnostics_); + return false; + })) { return false; } - if (decl != program_->AST().GlobalDeclarations().back()) { line(); } @@ -94,59 +88,64 @@ bool GeneratorImpl::Generate() { } bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) { - if (auto* alias = ty->As()) { - auto out = line(); - out << "type " << program_->Symbols().NameFor(alias->name) << " = "; - if (!EmitType(out, alias->type)) { - return false; - } - out << ";"; - } else if (auto* str = ty->As()) { - if (!EmitStructType(str)) { - return false; - } - } else { - diagnostics_.add_error( - diag::System::Writer, - "unknown declared type: " + std::string(ty->TypeInfo().name)); - return false; - } - return true; + return Switch( + ty, + [&](const ast::Alias* alias) { // + auto out = line(); + out << "type " << program_->Symbols().NameFor(alias->name) << " = "; + if (!EmitType(out, alias->type)) { + return false; + } + out << ";"; + return true; + }, + [&](const ast::Struct* str) { // + return EmitStructType(str); + }, + [&](Default) { // + diagnostics_.add_error( + diag::System::Writer, + "unknown declared type: " + std::string(ty->TypeInfo().name)); + return false; + }); } bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) { - if (auto* a = expr->As()) { - return EmitIndexAccessor(out, a); - } - if (auto* b = expr->As()) { - return EmitBinary(out, b); - } - if (auto* b = expr->As()) { - return EmitBitcast(out, b); - } - if (auto* c = expr->As()) { - return EmitCall(out, c); - } - if (auto* i = expr->As()) { - return EmitIdentifier(out, i); - } - if (auto* l = expr->As()) { - return EmitLiteral(out, l); - } - if (auto* m = expr->As()) { - return EmitMemberAccessor(out, m); - } - if (expr->Is()) { - out << "_"; - return true; - } - if (auto* u = expr->As()) { - return EmitUnaryOp(out, u); - } - - diagnostics_.add_error(diag::System::Writer, "unknown expression type"); - return false; + return Switch( + expr, + [&](const ast::IndexAccessorExpression* a) { // + return EmitIndexAccessor(out, a); + }, + [&](const ast::BinaryExpression* b) { // + return EmitBinary(out, b); + }, + [&](const ast::BitcastExpression* b) { // + return EmitBitcast(out, b); + }, + [&](const ast::CallExpression* c) { // + return EmitCall(out, c); + }, + [&](const ast::IdentifierExpression* i) { // + return EmitIdentifier(out, i); + }, + [&](const ast::LiteralExpression* l) { // + return EmitLiteral(out, l); + }, + [&](const ast::MemberAccessorExpression* m) { // + return EmitMemberAccessor(out, m); + }, + [&](const ast::PhonyExpression*) { // + out << "_"; + return true; + }, + [&](const ast::UnaryOpExpression* u) { // + return EmitUnaryOp(out, u); + }, + [&](Default) { + diagnostics_.add_error(diag::System::Writer, "unknown expression type"); + return false; + }); } bool GeneratorImpl::EmitIndexAccessor( @@ -250,19 +249,28 @@ bool GeneratorImpl::EmitCall(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* lit) { - if (auto* bl = lit->As()) { - out << (bl->value ? "true" : "false"); - } else if (auto* fl = lit->As()) { - out << FloatToBitPreservingString(fl->value); - } else if (auto* sl = lit->As()) { - out << sl->value; - } else if (auto* ul = lit->As()) { - out << ul->value << "u"; - } else { - diagnostics_.add_error(diag::System::Writer, "unknown literal type"); - return false; - } - return true; + return Switch( + lit, + [&](const ast::BoolLiteralExpression* bl) { // + out << (bl->value ? "true" : "false"); + return true; + }, + [&](const ast::FloatLiteralExpression* fl) { // + out << FloatToBitPreservingString(fl->value); + return true; + }, + [&](const ast::SintLiteralExpression* sl) { // + out << sl->value; + return true; + }, + [&](const ast::UintLiteralExpression* ul) { // + out << ul->value << "u"; + return true; + }, + [&](Default) { // + diagnostics_.add_error(diag::System::Writer, "unknown literal type"); + return false; + }); } bool GeneratorImpl::EmitIdentifier(std::ostream& out, @@ -366,155 +374,208 @@ bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) { } bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) { - if (auto* ary = ty->As()) { - for (auto* attr : ary->attributes) { - if (auto* stride = attr->As()) { - out << "@stride(" << stride->stride << ") "; - } - } + return Switch( + ty, + [&](const ast::Array* ary) { + for (auto* attr : ary->attributes) { + if (auto* stride = attr->As()) { + out << "@stride(" << stride->stride << ") "; + } + } - out << "array<"; - if (!EmitType(out, ary->type)) { - return false; - } + out << "array<"; + if (!EmitType(out, ary->type)) { + return false; + } - if (!ary->IsRuntimeArray()) { - out << ", "; - if (!EmitExpression(out, ary->count)) { - return false; - } - } + if (!ary->IsRuntimeArray()) { + out << ", "; + if (!EmitExpression(out, ary->count)) { + return false; + } + } - out << ">"; - } else if (ty->Is()) { - out << "bool"; - } else if (ty->Is()) { - out << "f32"; - } else if (ty->Is()) { - out << "i32"; - } else if (auto* mat = ty->As()) { - out << "mat" << mat->columns << "x" << mat->rows; - if (auto* el_ty = mat->type) { - out << "<"; - if (!EmitType(out, el_ty)) { - return false; - } - out << ">"; - } - } else if (auto* ptr = ty->As()) { - out << "ptr<" << ptr->storage_class << ", "; - if (!EmitType(out, ptr->type)) { - return false; - } - if (ptr->access != ast::Access::kUndefined) { - out << ", "; - if (!EmitAccess(out, ptr->access)) { - return false; - } - } - out << ">"; - } else if (auto* atomic = ty->As()) { - out << "atomic<"; - if (!EmitType(out, atomic->type)) { - return false; - } - out << ">"; - } else if (auto* sampler = ty->As()) { - out << "sampler"; + out << ">"; + return true; + }, + [&](const ast::Bool*) { + out << "bool"; + return true; + }, + [&](const ast::F32*) { + out << "f32"; + return true; + }, + [&](const ast::I32*) { + out << "i32"; + return true; + }, + [&](const ast::Matrix* mat) { + out << "mat" << mat->columns << "x" << mat->rows; + if (auto* el_ty = mat->type) { + out << "<"; + if (!EmitType(out, el_ty)) { + return false; + } + out << ">"; + } + return true; + }, + [&](const ast::Pointer* ptr) { + out << "ptr<" << ptr->storage_class << ", "; + if (!EmitType(out, ptr->type)) { + return false; + } + if (ptr->access != ast::Access::kUndefined) { + out << ", "; + if (!EmitAccess(out, ptr->access)) { + return false; + } + } + out << ">"; + return true; + }, + [&](const ast::Atomic* atomic) { + out << "atomic<"; + if (!EmitType(out, atomic->type)) { + return false; + } + out << ">"; + return true; + }, + [&](const ast::Sampler* sampler) { + out << "sampler"; - if (sampler->IsComparison()) { - out << "_comparison"; - } - } else if (ty->Is()) { - out << "texture_external"; - } else if (auto* texture = ty->As()) { - out << "texture_"; - if (texture->Is()) { - out << "depth_"; - } else if (texture->Is()) { - out << "depth_multisampled_"; - } else if (texture->Is()) { - /* nothing to emit */ - } else if (texture->Is()) { - out << "multisampled_"; - } else if (texture->Is()) { - out << "storage_"; - } else { - diagnostics_.add_error(diag::System::Writer, "unknown texture type"); - return false; - } + if (sampler->IsComparison()) { + out << "_comparison"; + } + return true; + }, + [&](const ast::ExternalTexture*) { + out << "texture_external"; + return true; + }, + [&](const ast::Texture* texture) { + out << "texture_"; + bool ok = Switch( + texture, + [&](const ast::DepthTexture*) { // + out << "depth_"; + return true; + }, + [&](const ast::DepthMultisampledTexture*) { // + out << "depth_multisampled_"; + return true; + }, + [&](const ast::SampledTexture*) { // + /* nothing to emit */ + return true; + }, + [&](const ast::MultisampledTexture*) { // + out << "multisampled_"; + return true; + }, + [&](const ast::StorageTexture*) { // + out << "storage_"; + return true; + }, + [&](Default) { // + diagnostics_.add_error(diag::System::Writer, + "unknown texture type"); + return false; + }); + if (!ok) { + return false; + } - switch (texture->dim) { - case ast::TextureDimension::k1d: - out << "1d"; - break; - case ast::TextureDimension::k2d: - out << "2d"; - break; - case ast::TextureDimension::k2dArray: - out << "2d_array"; - break; - case ast::TextureDimension::k3d: - out << "3d"; - break; - case ast::TextureDimension::kCube: - out << "cube"; - break; - case ast::TextureDimension::kCubeArray: - out << "cube_array"; - break; - default: - diagnostics_.add_error(diag::System::Writer, - "unknown texture dimension"); - return false; - } + switch (texture->dim) { + case ast::TextureDimension::k1d: + out << "1d"; + break; + case ast::TextureDimension::k2d: + out << "2d"; + break; + case ast::TextureDimension::k2dArray: + out << "2d_array"; + break; + case ast::TextureDimension::k3d: + out << "3d"; + break; + case ast::TextureDimension::kCube: + out << "cube"; + break; + case ast::TextureDimension::kCubeArray: + out << "cube_array"; + break; + default: + diagnostics_.add_error(diag::System::Writer, + "unknown texture dimension"); + return false; + } - if (auto* sampled = texture->As()) { - out << "<"; - if (!EmitType(out, sampled->type)) { + return Switch( + texture, + [&](const ast::SampledTexture* sampled) { // + out << "<"; + if (!EmitType(out, sampled->type)) { + return false; + } + out << ">"; + return true; + }, + [&](const ast::MultisampledTexture* ms) { // + out << "<"; + if (!EmitType(out, ms->type)) { + return false; + } + out << ">"; + return true; + }, + [&](const ast::StorageTexture* storage) { // + out << "<"; + if (!EmitImageFormat(out, storage->format)) { + return false; + } + out << ", "; + if (!EmitAccess(out, storage->access)) { + return false; + } + out << ">"; + return true; + }, + [&](Default) { // + return true; + }); + }, + [&](const ast::U32*) { + out << "u32"; + return true; + }, + [&](const ast::Vector* vec) { + out << "vec" << vec->width; + if (auto* el_ty = vec->type) { + out << "<"; + if (!EmitType(out, el_ty)) { + return false; + } + out << ">"; + } + return true; + }, + [&](const ast::Void*) { + out << "void"; + return true; + }, + [&](const ast::TypeName* tn) { + out << program_->Symbols().NameFor(tn->name); + return true; + }, + [&](Default) { + diagnostics_.add_error( + diag::System::Writer, + "unknown type in EmitType: " + std::string(ty->TypeInfo().name)); return false; - } - out << ">"; - } else if (auto* ms = texture->As()) { - out << "<"; - if (!EmitType(out, ms->type)) { - return false; - } - out << ">"; - } else if (auto* storage = texture->As()) { - out << "<"; - if (!EmitImageFormat(out, storage->format)) { - return false; - } - out << ", "; - if (!EmitAccess(out, storage->access)) { - return false; - } - out << ">"; - } - - } else if (ty->Is()) { - out << "u32"; - } else if (auto* vec = ty->As()) { - out << "vec" << vec->width; - if (auto* el_ty = vec->type) { - out << "<"; - if (!EmitType(out, el_ty)) { - return false; - } - out << ">"; - } - } else if (ty->Is()) { - out << "void"; - } else if (auto* tn = ty->As()) { - out << program_->Symbols().NameFor(tn->name); - } else { - diagnostics_.add_error( - diag::System::Writer, - "unknown type in EmitType: " + std::string(ty->TypeInfo().name)); - return false; - } - return true; + }); } bool GeneratorImpl::EmitStructType(const ast::Struct* str) { @@ -632,56 +693,90 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out, } first = false; out << "@"; - if (auto* workgroup = attr->As()) { - auto values = workgroup->Values(); - out << "workgroup_size("; - for (int i = 0; i < 3; i++) { - if (values[i]) { - if (i > 0) { - out << ", "; + bool ok = Switch( + attr, + [&](const ast::WorkgroupAttribute* workgroup) { + auto values = workgroup->Values(); + out << "workgroup_size("; + for (int i = 0; i < 3; i++) { + if (values[i]) { + if (i > 0) { + out << ", "; + } + if (!EmitExpression(out, values[i])) { + return false; + } + } } - if (!EmitExpression(out, values[i])) { - return false; + out << ")"; + return true; + }, + [&](const ast::StructBlockAttribute*) { // + out << "block"; + return true; + }, + [&](const ast::StageAttribute* stage) { + out << "stage(" << stage->stage << ")"; + return true; + }, + [&](const ast::BindingAttribute* binding) { + out << "binding(" << binding->value << ")"; + return true; + }, + [&](const ast::GroupAttribute* group) { + out << "group(" << group->value << ")"; + return true; + }, + [&](const ast::LocationAttribute* location) { + out << "location(" << location->value << ")"; + return true; + }, + [&](const ast::BuiltinAttribute* builtin) { + out << "builtin(" << builtin->builtin << ")"; + return true; + }, + [&](const ast::InterpolateAttribute* interpolate) { + out << "interpolate(" << interpolate->type; + if (interpolate->sampling != ast::InterpolationSampling::kNone) { + out << ", " << interpolate->sampling; } - } - } - out << ")"; - } else if (attr->Is()) { - out << "block"; - } else if (auto* stage = attr->As()) { - out << "stage(" << stage->stage << ")"; - } else if (auto* binding = attr->As()) { - out << "binding(" << binding->value << ")"; - } else if (auto* group = attr->As()) { - out << "group(" << group->value << ")"; - } else if (auto* location = attr->As()) { - out << "location(" << location->value << ")"; - } else if (auto* builtin = attr->As()) { - out << "builtin(" << builtin->builtin << ")"; - } else if (auto* interpolate = attr->As()) { - out << "interpolate(" << interpolate->type; - if (interpolate->sampling != ast::InterpolationSampling::kNone) { - out << ", " << interpolate->sampling; - } - out << ")"; - } else if (attr->Is()) { - out << "invariant"; - } else if (auto* override_attr = attr->As()) { - out << "override"; - if (override_attr->has_value) { - out << "(" << override_attr->value << ")"; - } - } else if (auto* size = attr->As()) { - out << "size(" << size->size << ")"; - } else if (auto* align = attr->As()) { - out << "align(" << align->align << ")"; - } else if (auto* stride = attr->As()) { - out << "stride(" << stride->stride << ")"; - } else if (auto* internal = attr->As()) { - out << "internal(" << internal->InternalName() << ")"; - } else { - TINT_ICE(Writer, diagnostics_) - << "Unsupported attribute '" << attr->TypeInfo().name << "'"; + out << ")"; + return true; + }, + [&](const ast::InvariantAttribute*) { + out << "invariant"; + return true; + }, + [&](const ast::OverrideAttribute* override_deco) { + out << "override"; + if (override_deco->has_value) { + out << "(" << override_deco->value << ")"; + } + return true; + }, + [&](const ast::StructMemberSizeAttribute* size) { + out << "size(" << size->size << ")"; + return true; + }, + [&](const ast::StructMemberAlignAttribute* align) { + out << "align(" << align->align << ")"; + return true; + }, + [&](const ast::StrideAttribute* stride) { + out << "stride(" << stride->stride << ")"; + return true; + }, + [&](const ast::InternalAttribute* internal) { + out << "internal(" << internal->InternalName() << ")"; + return true; + }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "Unsupported attribute '" << attr->TypeInfo().name << "'"; + return false; + }); + + if (!ok) { return false; } } @@ -809,55 +904,36 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) { } bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { - if (auto* a = stmt->As()) { - return EmitAssign(a); - } - if (auto* b = stmt->As()) { - return EmitBlock(b); - } - if (auto* b = stmt->As()) { - return EmitBreak(b); - } - if (auto* c = stmt->As()) { - auto out = line(); - if (!EmitCall(out, c->expr)) { - return false; - } - out << ";"; - return true; - } - if (auto* c = stmt->As()) { - return EmitContinue(c); - } - if (auto* d = stmt->As()) { - return EmitDiscard(d); - } - if (auto* f = stmt->As()) { - return EmitFallthrough(f); - } - if (auto* i = stmt->As()) { - return EmitIf(i); - } - if (auto* l = stmt->As()) { - return EmitLoop(l); - } - if (auto* l = stmt->As()) { - return EmitForLoop(l); - } - if (auto* r = stmt->As()) { - return EmitReturn(r); - } - if (auto* s = stmt->As()) { - return EmitSwitch(s); - } - if (auto* v = stmt->As()) { - return EmitVariable(line(), v->variable); - } - - diagnostics_.add_error( - diag::System::Writer, - "unknown statement type: " + std::string(stmt->TypeInfo().name)); - return false; + return Switch( + stmt, // + [&](const ast::AssignmentStatement* a) { return EmitAssign(a); }, + [&](const ast::BlockStatement* b) { return EmitBlock(b); }, + [&](const ast::BreakStatement* b) { return EmitBreak(b); }, + [&](const ast::CallStatement* c) { + auto out = line(); + if (!EmitCall(out, c->expr)) { + return false; + } + out << ";"; + return true; + }, + [&](const ast::ContinueStatement* c) { return EmitContinue(c); }, + [&](const ast::DiscardStatement* d) { return EmitDiscard(d); }, + [&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); }, + [&](const ast::IfStatement* i) { return EmitIf(i); }, + [&](const ast::LoopStatement* l) { return EmitLoop(l); }, + [&](const ast::ForLoopStatement* l) { return EmitForLoop(l); }, + [&](const ast::ReturnStatement* r) { return EmitReturn(r); }, + [&](const ast::SwitchStatement* s) { return EmitSwitch(s); }, + [&](const ast::VariableDeclStatement* v) { + return EmitVariable(line(), v->variable); + }, + [&](Default) { + diagnostics_.add_error( + diag::System::Writer, + "unknown statement type: " + std::string(stmt->TypeInfo().name)); + return false; + }); } bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {