Add tint::Switch()

A type dispatch helper with replaces chains of:

  if (auto* a = obj->As<A>()) {
    ...
  } else if (auto* b = obj->As<B>()) {
    ...
  } else {
    ...
  }

with:

  Switch(obj,
    [&](A* a) { ... },
    [&](B* b) { ... },
    [&](Default) { ... });

This new helper provides greater opportunities for optimizations, avoids
scoping issues with if-else blocks, and is slightly cleaner (IMO).

Bug: tint:1383
Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78543
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-02-04 15:38:23 +00:00 committed by Tint LUCI CQ
parent fa0d64b76d
commit de857e1c58
11 changed files with 2413 additions and 1570 deletions

View File

@ -1160,6 +1160,7 @@ if(TINT_BUILD_BENCHMARKS)
endif() endif()
set(TINT_BENCHMARK_SRC set(TINT_BENCHMARK_SRC
"castable_bench.cc"
"bench/benchmark.cc" "bench/benchmark.cc"
"reader/wgsl/parser_bench.cc" "reader/wgsl/parser_bench.cc"
) )

View File

@ -35,16 +35,15 @@ Module::Module(ProgramID pid,
continue; continue;
} }
if (auto* ty = decl->As<ast::TypeDecl>()) { Switch(
type_decls_.push_back(ty); decl, //
} else if (auto* func = decl->As<Function>()) { [&](const ast::TypeDecl* type) { type_decls_.push_back(type); },
functions_.push_back(func); [&](const Function* func) { functions_.push_back(func); },
} else if (auto* var = decl->As<Variable>()) { [&](const Variable* var) { global_variables_.push_back(var); },
global_variables_.push_back(var); [&](Default) {
} else { diag::List diagnostics;
diag::List diagnostics; TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
TINT_ICE(AST, diagnostics) << "Unknown global declaration type"; });
}
} }
} }
@ -101,19 +100,24 @@ void Module::Copy(CloneContext* ctx, const Module* src) {
<< "src global declaration was nullptr"; << "src global declaration was nullptr";
continue; continue;
} }
if (auto* type = decl->As<ast::TypeDecl>()) { Switch(
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id); decl,
type_decls_.push_back(type); [&](const ast::TypeDecl* type) {
} else if (auto* func = decl->As<Function>()) { TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id); type_decls_.push_back(type);
functions_.push_back(func); },
} else if (auto* var = decl->As<Variable>()) { [&](const Function* func) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
global_variables_.push_back(var); functions_.push_back(func);
} else { },
TINT_ICE(AST, ctx->dst->Diagnostics()) [&](const Variable* var) {
<< "Unknown global declaration type"; TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
} global_variables_.push_back(var);
},
[&](Default) {
TINT_ICE(AST, ctx->dst->Diagnostics())
<< "Unknown global declaration type";
});
} }
} }

View File

@ -101,30 +101,47 @@ bool TraverseExpressions(const ast::Expression* root,
} }
} }
if (auto* idx = expr->As<IndexAccessorExpression>()) { bool ok = Switch(
push_pair(idx->object, idx->index); expr,
} else if (auto* bin_op = expr->As<BinaryExpression>()) { [&](const IndexAccessorExpression* idx) {
push_pair(bin_op->lhs, bin_op->rhs); push_pair(idx->object, idx->index);
} else if (auto* bitcast = expr->As<BitcastExpression>()) { return true;
to_visit.push_back(bitcast->expr); },
} else if (auto* call = expr->As<CallExpression>()) { [&](const BinaryExpression* bin_op) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the push_pair(bin_op->lhs, bin_op->rhs);
// function name in the traversal. return true;
// to_visit.push_back(call->func); },
push_list(call->args); [&](const BitcastExpression* bitcast) {
} else if (auto* member = expr->As<MemberAccessorExpression>()) { to_visit.push_back(bitcast->expr);
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the return true;
// member name in the traversal. },
// push_pair(member->structure, member->member); [&](const CallExpression* call) {
to_visit.push_back(member->structure); // TODO(crbug.com/tint/1257): Resolver breaks if we actually include
} else if (auto* unary = expr->As<UnaryOpExpression>()) { // the function name in the traversal. to_visit.push_back(call->func);
to_visit.push_back(unary->expr); push_list(call->args);
} else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression, return true;
PhonyExpression>()) { },
// Leaf expression [&](const MemberAccessorExpression* member) {
} else { // TODO(crbug.com/tint/1257): Resolver breaks if we actually include
TINT_ICE(AST, diags) << "unhandled expression type: " // the member name in the traversal. push_pair(member->structure,
<< expr->TypeInfo().name; // member->member);
to_visit.push_back(member->structure);
return true;
},
[&](const UnaryOpExpression* unary) {
to_visit.push_back(unary->expr);
return true;
},
[&](Default) {
if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
PhonyExpression>()) {
return true; // Leaf expression
}
TINT_ICE(AST, diags)
<< "unhandled expression type: " << expr->TypeInfo().name;
return false;
});
if (!ok) {
return false; return false;
} }
} }

View File

@ -453,6 +453,105 @@ class Castable : public BASE {
} }
}; };
/// Default can be used as the default case for a Switch(), when all previous
/// cases failed to match.
///
/// Example:
/// ```
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* If not TypeA or TypeB */ });
/// ```
struct Default {};
/// Switch is used to dispatch one of the provided callback case handler
/// functions based on the type of `object` and the parameter type of the case
/// handlers. Switch will sequentially check the type of `object` against each
/// of the switch case handler functions, and will invoke the first case handler
/// function which has a parameter type that matches the object type. When a
/// case handler is matched, it will be called with the single argument of
/// `object` cast to the case handler's parameter type. Switch will invoke at
/// most one case handler. Each of the case functions must have the signature
/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
/// is the return type, consistent across all case handlers.
///
/// An optional default case function with the signature `R(Default)` can be
/// used as the last case. This default case will be called if all previous
/// cases failed to match.
///
/// Example:
/// ```
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ });
///
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* Called if object is not TypeA or TypeB */ });
/// ```
///
/// @param object the object who's type is used to
/// @param first_case the first switch case
/// @param other_cases additional switch cases (optional)
/// @return the value returned by the called case. If no cases matched, then the
/// zero value for the consistent case type.
template <typename T, typename FIRST_CASE, typename... OTHER_CASES>
traits::ReturnType<FIRST_CASE> //
Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) {
using ReturnType = traits::ReturnType<FIRST_CASE>;
using CaseType = std::remove_pointer_t<traits::ParameterType<FIRST_CASE, 0>>;
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
static_assert(traits::SignatureOfT<FIRST_CASE>::parameter_count == 1,
"Switch case must have a single parameter");
if constexpr (std::is_same_v<CaseType, Default>) {
// Default case. Must be last.
(void)object; // 'object' is not used by the Default case.
static_assert(sizeof...(other_cases) == 0,
"Switch Default case must come last");
if constexpr (kHasReturnType) {
return first_case({});
} else {
first_case({});
return;
}
} else {
// Regular case.
static_assert(traits::IsTypeOrDerived<CaseType, CastableBase>::value,
"Switch case parameter is not a Castable pointer");
// Does the case match?
if (auto* ptr = As<CaseType>(object)) {
if constexpr (kHasReturnType) {
return first_case(ptr);
} else {
first_case(ptr);
return;
}
}
// Case did not match. Got any more cases to try?
if constexpr (sizeof...(other_cases) > 0) {
// Try the next cases...
if constexpr (kHasReturnType) {
auto res = Switch(object, std::forward<OTHER_CASES>(other_cases)...);
static_assert(std::is_same_v<decltype(res), ReturnType>,
"Switch case types do not have consistent return type");
return res;
} else {
Switch(object, std::forward<OTHER_CASES>(other_cases)...);
return;
}
} else {
// That was the last case. No cases matched.
if constexpr (kHasReturnType) {
return {};
} else {
return;
}
}
}
}
} // namespace tint } // namespace tint
TINT_CASTABLE_POP_DISABLE_WARNINGS(); TINT_CASTABLE_POP_DISABLE_WARNINGS();

270
src/castable_bench.cc Normal file
View File

@ -0,0 +1,270 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bench/benchmark.h"
namespace tint {
namespace {
struct Base : public tint::Castable<Base> {};
struct A : public tint::Castable<A, Base> {};
struct AA : public tint::Castable<AA, A> {};
struct AAA : public tint::Castable<AAA, AA> {};
struct AAB : public tint::Castable<AAB, AA> {};
struct AAC : public tint::Castable<AAC, AA> {};
struct AB : public tint::Castable<AB, A> {};
struct ABA : public tint::Castable<ABA, AB> {};
struct ABB : public tint::Castable<ABB, AB> {};
struct ABC : public tint::Castable<ABC, AB> {};
struct AC : public tint::Castable<AC, A> {};
struct ACA : public tint::Castable<ACA, AC> {};
struct ACB : public tint::Castable<ACB, AC> {};
struct ACC : public tint::Castable<ACC, AC> {};
struct B : public tint::Castable<B, Base> {};
struct BA : public tint::Castable<BA, B> {};
struct BAA : public tint::Castable<BAA, BA> {};
struct BAB : public tint::Castable<BAB, BA> {};
struct BAC : public tint::Castable<BAC, BA> {};
struct BB : public tint::Castable<BB, B> {};
struct BBA : public tint::Castable<BBA, BB> {};
struct BBB : public tint::Castable<BBB, BB> {};
struct BBC : public tint::Castable<BBC, BB> {};
struct BC : public tint::Castable<BC, B> {};
struct BCA : public tint::Castable<BCA, BC> {};
struct BCB : public tint::Castable<BCB, BC> {};
struct BCC : public tint::Castable<BCC, BC> {};
struct C : public tint::Castable<C, Base> {};
struct CA : public tint::Castable<CA, C> {};
struct CAA : public tint::Castable<CAA, CA> {};
struct CAB : public tint::Castable<CAB, CA> {};
struct CAC : public tint::Castable<CAC, CA> {};
struct CB : public tint::Castable<CB, C> {};
struct CBA : public tint::Castable<CBA, CB> {};
struct CBB : public tint::Castable<CBB, CB> {};
struct CBC : public tint::Castable<CBC, CB> {};
struct CC : public tint::Castable<CC, C> {};
struct CCA : public tint::Castable<CCA, CC> {};
struct CCB : public tint::Castable<CCB, CC> {};
struct CCC : public tint::Castable<CCC, CC> {};
using AllTypes = std::tuple<Base,
A,
AA,
AAA,
AAB,
AAC,
AB,
ABA,
ABB,
ABC,
AC,
ACA,
ACB,
ACC,
B,
BA,
BAA,
BAB,
BAC,
BB,
BBA,
BBB,
BBC,
BC,
BCA,
BCB,
BCC,
C,
CA,
CAA,
CAB,
CAC,
CB,
CBA,
CBB,
CBC,
CC,
CCA,
CCB,
CCC>;
std::vector<std::unique_ptr<Base>> MakeObjects() {
std::vector<std::unique_ptr<Base>> out;
out.emplace_back(std::make_unique<Base>());
out.emplace_back(std::make_unique<A>());
out.emplace_back(std::make_unique<AA>());
out.emplace_back(std::make_unique<AAA>());
out.emplace_back(std::make_unique<AAB>());
out.emplace_back(std::make_unique<AAC>());
out.emplace_back(std::make_unique<AB>());
out.emplace_back(std::make_unique<ABA>());
out.emplace_back(std::make_unique<ABB>());
out.emplace_back(std::make_unique<ABC>());
out.emplace_back(std::make_unique<AC>());
out.emplace_back(std::make_unique<ACA>());
out.emplace_back(std::make_unique<ACB>());
out.emplace_back(std::make_unique<ACC>());
out.emplace_back(std::make_unique<B>());
out.emplace_back(std::make_unique<BA>());
out.emplace_back(std::make_unique<BAA>());
out.emplace_back(std::make_unique<BAB>());
out.emplace_back(std::make_unique<BAC>());
out.emplace_back(std::make_unique<BB>());
out.emplace_back(std::make_unique<BBA>());
out.emplace_back(std::make_unique<BBB>());
out.emplace_back(std::make_unique<BBC>());
out.emplace_back(std::make_unique<BC>());
out.emplace_back(std::make_unique<BCA>());
out.emplace_back(std::make_unique<BCB>());
out.emplace_back(std::make_unique<BCC>());
out.emplace_back(std::make_unique<C>());
out.emplace_back(std::make_unique<CA>());
out.emplace_back(std::make_unique<CAA>());
out.emplace_back(std::make_unique<CAB>());
out.emplace_back(std::make_unique<CAC>());
out.emplace_back(std::make_unique<CB>());
out.emplace_back(std::make_unique<CBA>());
out.emplace_back(std::make_unique<CBB>());
out.emplace_back(std::make_unique<CBC>());
out.emplace_back(std::make_unique<CC>());
out.emplace_back(std::make_unique<CCA>());
out.emplace_back(std::make_unique<CCB>());
out.emplace_back(std::make_unique<CCC>());
return out;
}
void CastableLargeSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); },
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); },
[&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); },
[&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); },
[&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); },
[&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); },
[&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); },
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
[&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); },
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
[&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); },
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
[&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); },
[&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); },
[&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); },
[&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); },
[&](const CA*) { ::benchmark::DoNotOptimize(i += 290); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
[&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); },
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); },
[&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); },
[&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); },
[&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); },
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableLargeSwitch);
void CastableMediumSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableMediumSwitch);
void CastableSmallSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); },
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableSmallSwitch);
} // namespace
} // namespace tint
TINT_INSTANTIATE_TYPEINFO(tint::Base);
TINT_INSTANTIATE_TYPEINFO(tint::A);
TINT_INSTANTIATE_TYPEINFO(tint::AA);
TINT_INSTANTIATE_TYPEINFO(tint::AAA);
TINT_INSTANTIATE_TYPEINFO(tint::AAB);
TINT_INSTANTIATE_TYPEINFO(tint::AAC);
TINT_INSTANTIATE_TYPEINFO(tint::AB);
TINT_INSTANTIATE_TYPEINFO(tint::ABA);
TINT_INSTANTIATE_TYPEINFO(tint::ABB);
TINT_INSTANTIATE_TYPEINFO(tint::ABC);
TINT_INSTANTIATE_TYPEINFO(tint::AC);
TINT_INSTANTIATE_TYPEINFO(tint::ACA);
TINT_INSTANTIATE_TYPEINFO(tint::ACB);
TINT_INSTANTIATE_TYPEINFO(tint::ACC);
TINT_INSTANTIATE_TYPEINFO(tint::B);
TINT_INSTANTIATE_TYPEINFO(tint::BA);
TINT_INSTANTIATE_TYPEINFO(tint::BAA);
TINT_INSTANTIATE_TYPEINFO(tint::BAB);
TINT_INSTANTIATE_TYPEINFO(tint::BAC);
TINT_INSTANTIATE_TYPEINFO(tint::BB);
TINT_INSTANTIATE_TYPEINFO(tint::BBA);
TINT_INSTANTIATE_TYPEINFO(tint::BBB);
TINT_INSTANTIATE_TYPEINFO(tint::BBC);
TINT_INSTANTIATE_TYPEINFO(tint::BC);
TINT_INSTANTIATE_TYPEINFO(tint::BCA);
TINT_INSTANTIATE_TYPEINFO(tint::BCB);
TINT_INSTANTIATE_TYPEINFO(tint::BCC);
TINT_INSTANTIATE_TYPEINFO(tint::C);
TINT_INSTANTIATE_TYPEINFO(tint::CA);
TINT_INSTANTIATE_TYPEINFO(tint::CAA);
TINT_INSTANTIATE_TYPEINFO(tint::CAB);
TINT_INSTANTIATE_TYPEINFO(tint::CAC);
TINT_INSTANTIATE_TYPEINFO(tint::CB);
TINT_INSTANTIATE_TYPEINFO(tint::CBA);
TINT_INSTANTIATE_TYPEINFO(tint::CBB);
TINT_INSTANTIATE_TYPEINFO(tint::CBC);
TINT_INSTANTIATE_TYPEINFO(tint::CC);
TINT_INSTANTIATE_TYPEINFO(tint::CCA);
TINT_INSTANTIATE_TYPEINFO(tint::CCB);
TINT_INSTANTIATE_TYPEINFO(tint::CCC);

View File

@ -252,6 +252,151 @@ TEST(Castable, As) {
ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get())); ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
} }
TEST(Castable, SwitchNoDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_amphibian = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Amphibian* amphibian) {
EXPECT_EQ(amphibian, frog.get());
frog_matched_amphibian = true;
});
EXPECT_TRUE(frog_matched_amphibian);
}
{
bool bear_matched_mammal = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Mammal* mammal) {
EXPECT_EQ(mammal, bear.get());
bear_matched_mammal = true;
});
EXPECT_TRUE(bear_matched_mammal);
}
{
bool gecko_matched_reptile = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Reptile* reptile) {
EXPECT_EQ(reptile, gecko.get());
gecko_matched_reptile = true;
});
EXPECT_TRUE(gecko_matched_reptile);
}
}
TEST(Castable, SwitchWithUnusedDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_amphibian = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Amphibian* amphibian) {
EXPECT_EQ(amphibian, frog.get());
frog_matched_amphibian = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(frog_matched_amphibian);
}
{
bool bear_matched_mammal = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Mammal* mammal) {
EXPECT_EQ(mammal, bear.get());
bear_matched_mammal = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(bear_matched_mammal);
}
{
bool gecko_matched_reptile = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Reptile* reptile) {
EXPECT_EQ(reptile, gecko.get());
gecko_matched_reptile = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(gecko_matched_reptile);
}
}
TEST(Castable, SwitchDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_default = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Default) { frog_matched_default = true; });
EXPECT_TRUE(frog_matched_default);
}
{
bool bear_matched_default = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Default) { bear_matched_default = true; });
EXPECT_TRUE(bear_matched_default);
}
{
bool gecko_matched_default = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Default) { gecko_matched_default = true; });
EXPECT_TRUE(gecko_matched_default);
}
}
TEST(Castable, SwitchMatchFirst) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
{
bool frog_matched_animal = false;
Switch(
frog.get(),
[&](Animal* animal) {
EXPECT_EQ(animal, frog.get());
frog_matched_animal = true;
},
[&](Amphibian*) { FAIL() << "animal should have been matched first"; });
EXPECT_TRUE(frog_matched_animal);
}
{
bool frog_matched_amphibian = false;
Switch(
frog.get(),
[&](Amphibian* amphibain) {
EXPECT_EQ(amphibain, frog.get());
frog_matched_amphibian = true;
},
[&](Animal*) { FAIL() << "amphibian should have been matched first"; });
EXPECT_TRUE(frog_matched_amphibian);
}
}
} // namespace } // namespace
TINT_INSTANTIATE_TYPEINFO(Animal); TINT_INSTANTIATE_TYPEINFO(Animal);

View File

@ -953,7 +953,7 @@ const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
bool FunctionEmitter::EmitPipelineInput(std::string var_name, bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* var_type, const Type* var_type,
ast::AttributeList* decos, ast::AttributeList* attrs,
std::vector<int> index_prefix, std::vector<int> index_prefix,
const Type* tip_type, const Type* tip_type,
const Type* forced_param_type, const Type* forced_param_type,
@ -966,105 +966,121 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name,
} }
// Recursively flatten matrices, arrays, and structures. // Recursively flatten matrices, arrays, and structures.
if (auto* matrix_type = tip_type->As<Matrix>()) { return Switch(
index_prefix.push_back(0); tip_type,
const auto num_columns = static_cast<int>(matrix_type->columns); [&](const Matrix* matrix_type) -> bool {
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); index_prefix.push_back(0);
for (int col = 0; col < num_columns; col++) { const auto num_columns = static_cast<int>(matrix_type->columns);
index_prefix.back() = col; const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty, for (int col = 0; col < num_columns; col++) {
forced_param_type, params, statements)) { index_prefix.back() = col;
return false; if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
} vec_ty, forced_param_type, params,
} statements)) {
return success(); return false;
} else if (auto* array_type = tip_type->As<Array>()) { }
if (array_type->size == 0) { }
return Fail() << "runtime-size array not allowed on pipeline IO"; return success();
} },
index_prefix.push_back(0); [&](const Array* array_type) -> bool {
const Type* elem_ty = array_type->type; if (array_type->size == 0) {
for (int i = 0; i < static_cast<int>(array_type->size); i++) { return Fail() << "runtime-size array not allowed on pipeline IO";
index_prefix.back() = i; }
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty, index_prefix.push_back(0);
forced_param_type, params, statements)) { const Type* elem_ty = array_type->type;
return false; for (int i = 0; i < static_cast<int>(array_type->size); i++) {
} index_prefix.back() = i;
} if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
return success(); elem_ty, forced_param_type, params,
} else if (auto* struct_type = tip_type->As<Struct>()) { statements)) {
const auto& members = struct_type->members; return false;
index_prefix.push_back(0); }
for (int i = 0; i < static_cast<int>(members.size()); ++i) { }
index_prefix.back() = i; return success();
ast::AttributeList member_decos(*decos); },
if (!parser_impl_.ConvertPipelineDecorations( [&](const Struct* struct_type) -> bool {
struct_type, const auto& members = struct_type->members;
parser_impl_.GetMemberPipelineDecorations(*struct_type, i), index_prefix.push_back(0);
&member_decos)) { for (int i = 0; i < static_cast<int>(members.size()); ++i) {
return false; index_prefix.back() = i;
} ast::AttributeList member_attrs(*attrs);
if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix, if (!parser_impl_.ConvertPipelineDecorations(
members[i], forced_param_type, params, struct_type,
statements)) { parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
return false; &member_attrs)) {
} return false;
// Copy the location as updated by nested expansion of the member. }
parser_impl_.SetLocation(decos, GetLocation(member_decos)); if (!EmitPipelineInput(var_name, var_type, &member_attrs,
} index_prefix, members[i], forced_param_type,
return success(); params, statements)) {
} return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
const bool is_builtin =
ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos); const Type* param_type = is_builtin ? forced_param_type : tip_type;
const Type* param_type = is_builtin ? forced_param_type : tip_type; const auto param_name = namer_.MakeDerivedName(var_name + "_param");
// Create the parameter.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
params->push_back(
builder_.Param(param_name, param_type->Build(builder_), *attrs));
const auto param_name = namer_.MakeDerivedName(var_name + "_param"); // Add a body statement to copy the parameter to the corresponding
// Create the parameter. // private variable.
// TODO(dneto): Note: If the parameter has non-location decorations, const ast::Expression* param_value = builder_.Expr(param_name);
// then those decoration AST nodes will be reused between multiple elements const ast::Expression* store_dest = builder_.Expr(var_name);
// of a matrix, array, or structure. Normally that's disallowed but currently
// the SPIR-V reader will make duplicates when the entire AST is cloned
// at the top level of the SPIR-V reader flow. Consider rewriting this
// to avoid this node-sharing.
params->push_back(
builder_.Param(param_name, param_type->Build(builder_), *decos));
// Add a body statement to copy the parameter to the corresponding private // Index into the LHS as needed.
// variable. auto* current_type =
const ast::Expression* param_value = builder_.Expr(param_name); var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
const ast::Expression* store_dest = builder_.Expr(var_name); for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
store_dest =
builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
store_dest =
builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
store_dest = builder_.MemberAccessor(
store_dest, builder_.Expr(parser_impl_.GetMemberName(
*struct_type, index)));
current_type = struct_type->members[index];
});
}
// Index into the LHS as needed. if (is_builtin && (tip_type != forced_param_type)) {
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); // The parameter will have the WGSL type, but we need bitcast to
for (auto index : index_prefix) { // the variable store type.
if (auto* matrix_type = current_type->As<Matrix>()) { param_value = create<ast::BitcastExpression>(
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index)); tip_type->Build(builder_), param_value);
current_type = ty_.Vector(matrix_type->type, matrix_type->rows); }
} else if (auto* array_type = current_type->As<Array>()) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
} else if (auto* struct_type = current_type->As<Struct>()) {
store_dest = builder_.MemberAccessor(
store_dest,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[index];
}
}
if (is_builtin && (tip_type != forced_param_type)) { statements->push_back(builder_.Assign(store_dest, param_value));
// The parameter will have the WGSL type, but we need bitcast to
// the variable store type.
param_value =
create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
}
statements->push_back(builder_.Assign(store_dest, param_value)); // Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(attrs);
// Increment the location attribute, in case more parameters will follow. return success();
IncrementLocation(decos); });
return success();
} }
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) { void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
@ -1102,106 +1118,120 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
} }
// Recursively flatten matrices, arrays, and structures. // Recursively flatten matrices, arrays, and structures.
if (auto* matrix_type = tip_type->As<Matrix>()) { return Switch(
index_prefix.push_back(0); tip_type,
const auto num_columns = static_cast<int>(matrix_type->columns); [&](const Matrix* matrix_type) -> bool {
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); index_prefix.push_back(0);
for (int col = 0; col < num_columns; col++) { const auto num_columns = static_cast<int>(matrix_type->columns);
index_prefix.back() = col; const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty, for (int col = 0; col < num_columns; col++) {
forced_member_type, return_members, index_prefix.back() = col;
return_exprs)) { if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
return false; vec_ty, forced_member_type, return_members,
} return_exprs)) {
} return false;
return success(); }
} else if (auto* array_type = tip_type->As<Array>()) { }
if (array_type->size == 0) { return success();
return Fail() << "runtime-size array not allowed on pipeline IO"; },
} [&](const Array* array_type) -> bool {
index_prefix.push_back(0); if (array_type->size == 0) {
const Type* elem_ty = array_type->type; return Fail() << "runtime-size array not allowed on pipeline IO";
for (int i = 0; i < static_cast<int>(array_type->size); i++) { }
index_prefix.back() = i; index_prefix.push_back(0);
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty, const Type* elem_ty = array_type->type;
forced_member_type, return_members, for (int i = 0; i < static_cast<int>(array_type->size); i++) {
return_exprs)) { index_prefix.back() = i;
return false; if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
} elem_ty, forced_member_type, return_members,
} return_exprs)) {
return success(); return false;
} else if (auto* struct_type = tip_type->As<Struct>()) { }
const auto& members = struct_type->members; }
index_prefix.push_back(0); return success();
for (int i = 0; i < static_cast<int>(members.size()); ++i) { },
index_prefix.back() = i; [&](const Struct* struct_type) -> bool {
ast::AttributeList member_decos(*decos); const auto& members = struct_type->members;
if (!parser_impl_.ConvertPipelineDecorations( index_prefix.push_back(0);
struct_type, for (int i = 0; i < static_cast<int>(members.size()); ++i) {
parser_impl_.GetMemberPipelineDecorations(*struct_type, i), index_prefix.back() = i;
&member_decos)) { ast::AttributeList member_attrs(*decos);
return false; if (!parser_impl_.ConvertPipelineDecorations(
} struct_type,
if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix, parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
members[i], forced_member_type, return_members, &member_attrs)) {
return_exprs)) { return false;
return false; }
} if (!EmitPipelineOutput(var_name, var_type, &member_attrs,
// Copy the location as updated by nested expansion of the member. index_prefix, members[i], forced_member_type,
parser_impl_.SetLocation(decos, GetLocation(member_decos)); return_members, return_exprs)) {
} return false;
return success(); }
} // Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
const bool is_builtin =
ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos); const Type* member_type = is_builtin ? forced_member_type : tip_type;
// Derive the member name directly from the variable name. They can't
// collide.
const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos));
const Type* member_type = is_builtin ? forced_member_type : tip_type; // Create an expression to evaluate the part of the variable indexed by
// Derive the member name directly from the variable name. They can't // the index_prefix.
// collide. const ast::Expression* load_source = builder_.Expr(var_name);
const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple elements
// of a matrix, array, or structure. Normally that's disallowed but currently
// the SPIR-V reader will make duplicates when the entire AST is cloned
// at the top level of the SPIR-V reader flow. Consider rewriting this
// to avoid this node-sharing.
return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos));
// Create an expression to evaluate the part of the variable indexed by // Index into the variable as needed to pick out the flattened member.
// the index_prefix. auto* current_type =
const ast::Expression* load_source = builder_.Expr(var_name); var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
load_source = builder_.MemberAccessor(
load_source, builder_.Expr(parser_impl_.GetMemberName(
*struct_type, index)));
current_type = struct_type->members[index];
});
}
// Index into the variable as needed to pick out the flattened member. if (is_builtin && (tip_type != forced_member_type)) {
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); // The member will have the WGSL type, but we need bitcast to
for (auto index : index_prefix) { // the variable store type.
if (auto* matrix_type = current_type->As<Matrix>()) { load_source = create<ast::BitcastExpression>(
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index)); forced_member_type->Build(builder_), load_source);
current_type = ty_.Vector(matrix_type->type, matrix_type->rows); }
} else if (auto* array_type = current_type->As<Array>()) { return_exprs->push_back(load_source);
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
} else if (auto* struct_type = current_type->As<Struct>()) {
load_source = builder_.MemberAccessor(
load_source,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[index];
}
}
if (is_builtin && (tip_type != forced_member_type)) { // Increment the location attribute, in case more parameters will
// The member will have the WGSL type, but we need bitcast to // follow.
// the variable store type. IncrementLocation(decos);
load_source = create<ast::BitcastExpression>(
forced_member_type->Build(builder_), load_source);
}
return_exprs->push_back(load_source);
// Increment the location attribute, in case more parameters will follow. return success();
IncrementLocation(decos); });
return success();
} }
bool FunctionEmitter::EmitEntryPointAsWrapper() { bool FunctionEmitter::EmitEntryPointAsWrapper() {

View File

@ -239,39 +239,41 @@ bool GeneratorImpl::Generate() {
} }
last_kind = kind; last_kind = kind;
if (auto* global = decl->As<ast::Variable>()) { bool ok = Switch(
if (!EmitGlobalVariable(global)) { decl,
return false; [&](const ast::Variable* global) { //
} return EmitGlobalVariable(global);
} else if (auto* str = decl->As<ast::Struct>()) { },
auto* ty = builder_.Sem().Get(str); [&](const ast::Struct* str) {
auto storage_class_uses = ty->StorageClassUsage(); auto* ty = builder_.Sem().Get(str);
if (storage_class_uses.size() != auto storage_class_uses = ty->StorageClassUsage();
(storage_class_uses.count(ast::StorageClass::kStorage) + if (storage_class_uses.size() !=
storage_class_uses.count(ast::StorageClass::kUniform))) { (storage_class_uses.count(ast::StorageClass::kStorage) +
// The structure is used as something other than a storage buffer or storage_class_uses.count(ast::StorageClass::kUniform))) {
// uniform buffer, so it needs to be emitted. // The structure is used as something other than a storage buffer or
// Storage buffer are read and written to via a ByteAddressBuffer // uniform buffer, so it needs to be emitted.
// instead of true structure. // Storage buffer are read and written to via a ByteAddressBuffer
// Structures used as uniform buffer are read from an array of vectors // instead of true structure.
// instead of true structure. // Structures used as uniform buffer are read from an array of
if (!EmitStructType(current_buffer_, ty)) { // vectors instead of true structure.
return EmitStructType(current_buffer_, ty);
}
return true;
},
[&](const ast::Function* func) {
if (func->IsEntryPoint()) {
return EmitEntryPointFunction(func);
}
return EmitFunction(func);
},
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled module-scope declaration: "
<< decl->TypeInfo().name;
return false; return false;
} });
}
} else if (auto* func = decl->As<ast::Function>()) { if (!ok) {
if (func->IsEntryPoint()) {
if (!EmitEntryPointFunction(func)) {
return false;
}
} else {
if (!EmitFunction(func)) {
return false;
}
}
} else {
TINT_ICE(Writer, diagnostics_)
<< "unhandled module-scope declaration: " << decl->TypeInfo().name;
return false; return false;
} }
} }
@ -929,22 +931,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
return Switch(
if (auto* func = target->As<sem::Function>()) { target,
return EmitFunctionCall(out, call, func); [&](const sem::Function* func) {
} return EmitFunctionCall(out, call, func);
if (auto* builtin = target->As<sem::Builtin>()) { },
return EmitBuiltinCall(out, call, builtin); [&](const sem::Builtin* builtin) {
} return EmitBuiltinCall(out, call, builtin);
if (auto* conv = target->As<sem::TypeConversion>()) { },
return EmitTypeConversion(out, call, conv); [&](const sem::TypeConversion* conv) {
} return EmitTypeConversion(out, call, conv);
if (auto* ctor = target->As<sem::TypeConstructor>()) { },
return EmitTypeConstructor(out, call, ctor); [&](const sem::TypeConstructor* ctor) {
} return EmitTypeConstructor(out, call, ctor);
TINT_ICE(Writer, diagnostics_) },
<< "unhandled call target: " << target->TypeInfo().name; [&](Default) {
return false; TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name;
return false;
});
} }
bool GeneratorImpl::EmitFunctionCall(std::ostream& out, bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@ -2639,35 +2644,38 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
bool GeneratorImpl::EmitExpression(std::ostream& out, bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) { const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
return EmitIndexAccessor(out, a); expr,
} [&](const ast::IndexAccessorExpression* a) { //
if (auto* b = expr->As<ast::BinaryExpression>()) { return EmitIndexAccessor(out, a);
return EmitBinary(out, b); },
} [&](const ast::BinaryExpression* b) { //
if (auto* b = expr->As<ast::BitcastExpression>()) { return EmitBinary(out, b);
return EmitBitcast(out, b); },
} [&](const ast::BitcastExpression* b) { //
if (auto* c = expr->As<ast::CallExpression>()) { return EmitBitcast(out, b);
return EmitCall(out, c); },
} [&](const ast::CallExpression* c) { //
if (auto* i = expr->As<ast::IdentifierExpression>()) { return EmitCall(out, c);
return EmitIdentifier(out, i); },
} [&](const ast::IdentifierExpression* i) { //
if (auto* l = expr->As<ast::LiteralExpression>()) { return EmitIdentifier(out, i);
return EmitLiteral(out, l); },
} [&](const ast::LiteralExpression* l) { //
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { return EmitLiteral(out, l);
return EmitMemberAccessor(out, m); },
} [&](const ast::MemberAccessorExpression* m) { //
if (auto* u = expr->As<ast::UnaryOpExpression>()) { return EmitMemberAccessor(out, m);
return EmitUnaryOp(out, u); },
} [&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u);
diagnostics_.add_error( },
diag::System::Writer, [&](Default) { //
"unknown expression type: " + std::string(expr->TypeInfo().name)); diagnostics_.add_error(
return false; diag::System::Writer,
"unknown expression type: " + std::string(expr->TypeInfo().name));
return false;
});
} }
bool GeneratorImpl::EmitIdentifier(std::ostream& out, bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@ -3127,80 +3135,108 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
bool GeneratorImpl::EmitLiteral(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) { const ast::LiteralExpression* lit) {
if (auto* l = lit->As<ast::BoolLiteralExpression>()) { return Switch(
out << (l->value ? "true" : "false"); lit,
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { [&](const ast::BoolLiteralExpression* l) {
if (std::isinf(fl->value)) { out << (l->value ? "true" : "false");
out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)"); return true;
} else if (std::isnan(fl->value)) { },
out << "asfloat(0x7fc00000u)"; [&](const ast::FloatLiteralExpression* fl) {
} else { if (std::isinf(fl->value)) {
out << FloatToString(fl->value) << "f"; out << (fl->value >= 0 ? "asfloat(0x7f800000u)"
} : "asfloat(0xff800000u)");
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) { } else if (std::isnan(fl->value)) {
out << sl->value; out << "asfloat(0x7fc00000u)";
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) { } else {
out << ul->value << "u"; out << FloatToString(fl->value) << "f";
} else { }
diagnostics_.add_error(diag::System::Writer, "unknown literal type"); return true;
return false; },
} [&](const ast::SintLiteralExpression* sl) {
return true; out << sl->value;
return true;
},
[&](const ast::UintLiteralExpression* ul) {
out << ul->value << "u";
return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
});
} }
bool GeneratorImpl::EmitValue(std::ostream& out, bool GeneratorImpl::EmitValue(std::ostream& out,
const sem::Type* type, const sem::Type* type,
int value) { int value) {
if (type->Is<sem::Bool>()) { return Switch(
out << (value == 0 ? "false" : "true"); type,
} else if (type->Is<sem::F32>()) { [&](const sem::Bool*) {
out << value << ".0f"; out << (value == 0 ? "false" : "true");
} else if (type->Is<sem::I32>()) { return true;
out << value; },
} else if (type->Is<sem::U32>()) { [&](const sem::F32*) {
out << value << "u"; out << value << ".0f";
} else if (auto* vec = type->As<sem::Vector>()) { return true;
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, },
"")) { [&](const sem::I32*) {
return false; out << value;
} return true;
ScopedParen sp(out); },
for (uint32_t i = 0; i < vec->Width(); i++) { [&](const sem::U32*) {
if (i != 0) { out << value << "u";
out << ", "; return true;
} },
if (!EmitValue(out, vec->type(), value)) { [&](const sem::Vector* vec) {
if (!EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < vec->Width(); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, vec->type(), value)) {
return false;
}
}
return true;
},
[&](const sem::Matrix* mat) {
if (!EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, mat->type(), value)) {
return false;
}
}
return true;
},
[&](const sem::Struct*) {
out << "(";
TINT_DEFER(out << ")" << value);
return EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kUndefined, "");
},
[&](const sem::Array*) {
out << "(";
TINT_DEFER(out << ")" << value);
return EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kUndefined, "");
},
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
"Invalid type for value emission: " + type->type_name());
return false; return false;
} });
}
} else if (auto* mat = type->As<sem::Matrix>()) {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, mat->type(), value)) {
return false;
}
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
out << "(";
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
"")) {
return false;
}
out << ")" << value;
} else {
diagnostics_.add_error(
diag::System::Writer,
"Invalid type for value emission: " + type->type_name());
return false;
}
return true;
} }
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
@ -3375,56 +3411,59 @@ bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
} }
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
return EmitAssign(a); stmt,
} [&](const ast::AssignmentStatement* a) { //
if (auto* b = stmt->As<ast::BlockStatement>()) { return EmitAssign(a);
return EmitBlock(b); },
} [&](const ast::BlockStatement* b) { //
if (auto* b = stmt->As<ast::BreakStatement>()) { return EmitBlock(b);
return EmitBreak(b); },
} [&](const ast::BreakStatement* b) { //
if (auto* c = stmt->As<ast::CallStatement>()) { return EmitBreak(b);
auto out = line(); },
if (!EmitCall(out, c->expr)) { [&](const ast::CallStatement* c) { //
return false; auto out = line();
} if (!EmitCall(out, c->expr)) {
out << ";"; return false;
return true; }
} out << ";";
if (auto* c = stmt->As<ast::ContinueStatement>()) { return true;
return EmitContinue(c); },
} [&](const ast::ContinueStatement* c) { //
if (auto* d = stmt->As<ast::DiscardStatement>()) { return EmitContinue(c);
return EmitDiscard(d); },
} [&](const ast::DiscardStatement* d) { //
if (stmt->As<ast::FallthroughStatement>()) { return EmitDiscard(d);
line() << "/* fallthrough */"; },
return true; [&](const ast::FallthroughStatement*) { //
} line() << "/* fallthrough */";
if (auto* i = stmt->As<ast::IfStatement>()) { return true;
return EmitIf(i); },
} [&](const ast::IfStatement* i) { //
if (auto* l = stmt->As<ast::LoopStatement>()) { return EmitIf(i);
return EmitLoop(l); },
} [&](const ast::LoopStatement* l) { //
if (auto* l = stmt->As<ast::ForLoopStatement>()) { return EmitLoop(l);
return EmitForLoop(l); },
} [&](const ast::ForLoopStatement* l) { //
if (auto* r = stmt->As<ast::ReturnStatement>()) { return EmitForLoop(l);
return EmitReturn(r); },
} [&](const ast::ReturnStatement* r) { //
if (auto* s = stmt->As<ast::SwitchStatement>()) { return EmitReturn(r);
return EmitSwitch(s); },
} [&](const ast::SwitchStatement* s) { //
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { return EmitSwitch(s);
return EmitVariable(v->variable); },
} [&](const ast::VariableDeclStatement* v) { //
return EmitVariable(v->variable);
diagnostics_.add_error( },
diag::System::Writer, [&](Default) { //
"unknown statement type: " + std::string(stmt->TypeInfo().name)); diagnostics_.add_error(
return false; diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name));
return false;
});
} }
bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) { bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
@ -3516,156 +3555,181 @@ bool GeneratorImpl::EmitType(std::ostream& out,
break; break;
} }
if (auto* ary = type->As<sem::Array>()) { return Switch(
const sem::Type* base_type = ary; type,
std::vector<uint32_t> sizes; [&](const sem::Array* ary) {
while (auto* arr = base_type->As<sem::Array>()) { const sem::Type* base_type = ary;
if (arr->IsRuntimeSized()) { std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which "
"should "
"have been transformed into a ByteAddressBuffer";
return false;
}
sizes.push_back(arr->Count());
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {
return false;
}
if (!name.empty()) {
out << " " << name;
if (name_printed) {
*name_printed = true;
}
}
for (uint32_t size : sizes) {
out << "[" << size << "]";
}
return true;
},
[&](const sem::Bool*) {
out << "bool";
return true;
},
[&](const sem::F32*) {
out << "float";
return true;
},
[&](const sem::I32*) {
out << "int";
return true;
},
[&](const sem::Matrix* mat) {
if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false;
}
// Note: HLSL's matrices are declared as <type>NxM, where N is the
// number of rows and M is the number of columns. Despite HLSL's
// matrices being column-major by default, the index operator and
// constructors actually operate on row-vectors, where as WGSL operates
// on column vectors. To simplify everything we use the transpose of the
// matrices. See:
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
out << mat->columns() << "x" << mat->rows();
return true;
},
[&](const sem::Pointer*) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which should " << "Attempting to emit pointer type. These should have been "
"have been transformed into a ByteAddressBuffer"; "removed with the InlinePointerLets transform";
return false; return false;
} },
sizes.push_back(arr->Count()); [&](const sem::Sampler* sampler) {
base_type = arr->ElemType(); out << "Sampler";
} if (sampler->IsComparison()) {
if (!EmitType(out, base_type, storage_class, access, "")) { out << "Comparison";
return false; }
} out << "State";
if (!name.empty()) { return true;
out << " " << name; },
if (name_printed) { [&](const sem::Struct* str) {
*name_printed = true; out << StructName(str);
} return true;
} },
for (uint32_t size : sizes) { [&](const sem::Texture* tex) {
out << "[" << size << "]"; auto* storage = tex->As<sem::StorageTexture>();
} auto* ms = tex->As<sem::MultisampledTexture>();
} else if (type->Is<sem::Bool>()) { auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
out << "bool"; auto* sampled = tex->As<sem::SampledTexture>();
} else if (type->Is<sem::F32>()) {
out << "float";
} else if (type->Is<sem::I32>()) {
out << "int";
} else if (auto* mat = type->As<sem::Matrix>()) {
if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false;
}
// Note: HLSL's matrices are declared as <type>NxM, where N is the number of
// rows and M is the number of columns. Despite HLSL's matrices being
// column-major by default, the index operator and constructors actually
// operate on row-vectors, where as WGSL operates on column vectors.
// To simplify everything we use the transpose of the matrices.
// See:
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
out << mat->columns() << "x" << mat->rows();
} else if (type->Is<sem::Pointer>()) {
TINT_ICE(Writer, diagnostics_)
<< "Attempting to emit pointer type. These should have been removed "
"with the InlinePointerLets transform";
return false;
} else if (auto* sampler = type->As<sem::Sampler>()) {
out << "Sampler";
if (sampler->IsComparison()) {
out << "Comparison";
}
out << "State";
} else if (auto* str = type->As<sem::Struct>()) {
out << StructName(str);
} else if (auto* tex = type->As<sem::Texture>()) {
auto* storage = tex->As<sem::StorageTexture>();
auto* ms = tex->As<sem::MultisampledTexture>();
auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
auto* sampled = tex->As<sem::SampledTexture>();
if (storage && storage->access() != ast::Access::kRead) { if (storage && storage->access() != ast::Access::kRead) {
out << "RW"; out << "RW";
} }
out << "Texture"; out << "Texture";
switch (tex->dim()) { switch (tex->dim()) {
case ast::TextureDimension::k1d: case ast::TextureDimension::k1d:
out << "1D"; out << "1D";
break; break;
case ast::TextureDimension::k2d: case ast::TextureDimension::k2d:
out << ((ms || depth_ms) ? "2DMS" : "2D"); out << ((ms || depth_ms) ? "2DMS" : "2D");
break; break;
case ast::TextureDimension::k2dArray: case ast::TextureDimension::k2dArray:
out << ((ms || depth_ms) ? "2DMSArray" : "2DArray"); out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
break; break;
case ast::TextureDimension::k3d: case ast::TextureDimension::k3d:
out << "3D"; out << "3D";
break; break;
case ast::TextureDimension::kCube: case ast::TextureDimension::kCube:
out << "Cube"; out << "Cube";
break; break;
case ast::TextureDimension::kCubeArray: case ast::TextureDimension::kCubeArray:
out << "CubeArray"; out << "CubeArray";
break; break;
default: default:
TINT_UNREACHABLE(Writer, diagnostics_) TINT_UNREACHABLE(Writer, diagnostics_)
<< "unexpected TextureDimension " << tex->dim(); << "unexpected TextureDimension " << tex->dim();
return false; return false;
} }
if (storage) { if (storage) {
auto* component = image_format_to_rwtexture_type(storage->texel_format()); auto* component =
if (component == nullptr) { image_format_to_rwtexture_type(storage->texel_format());
TINT_ICE(Writer, diagnostics_) if (component == nullptr) {
<< "Unsupported StorageTexture TexelFormat: " TINT_ICE(Writer, diagnostics_)
<< static_cast<int>(storage->texel_format()); << "Unsupported StorageTexture TexelFormat: "
<< static_cast<int>(storage->texel_format());
return false;
}
out << "<" << component << ">";
} else if (depth_ms) {
out << "<float4>";
} else if (sampled || ms) {
auto* subtype = sampled ? sampled->type() : ms->type();
out << "<";
if (subtype->Is<sem::F32>()) {
out << "float4";
} else if (subtype->Is<sem::I32>()) {
out << "int4";
} else if (subtype->Is<sem::U32>()) {
out << "uint4";
} else {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported multisampled texture type";
return false;
}
out << ">";
}
return true;
},
[&](const sem::U32*) {
out << "uint";
return true;
},
[&](const sem::Vector* vec) {
auto width = vec->Width();
if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
out << "float" << width;
} else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
out << "int" << width;
} else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
out << "uint" << width;
} else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
out << "bool" << width;
} else {
out << "vector<";
if (!EmitType(out, vec->type(), storage_class, access, "")) {
return false;
}
out << ", " << width << ">";
}
return true;
},
[&](const sem::Atomic* atomic) {
return EmitType(out, atomic->Type(), storage_class, access, name);
},
[&](const sem::Void*) {
out << "void";
return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer,
"unknown type in EmitType");
return false; return false;
} });
out << "<" << component << ">";
} else if (depth_ms) {
out << "<float4>";
} else if (sampled || ms) {
auto* subtype = sampled ? sampled->type() : ms->type();
out << "<";
if (subtype->Is<sem::F32>()) {
out << "float4";
} else if (subtype->Is<sem::I32>()) {
out << "int4";
} else if (subtype->Is<sem::U32>()) {
out << "uint4";
} else {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported multisampled texture type";
return false;
}
out << ">";
}
} else if (type->Is<sem::U32>()) {
out << "uint";
} else if (auto* vec = type->As<sem::Vector>()) {
auto width = vec->Width();
if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
out << "float" << width;
} else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
out << "int" << width;
} else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
out << "uint" << width;
} else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
out << "bool" << width;
} else {
out << "vector<";
if (!EmitType(out, vec->type(), storage_class, access, "")) {
return false;
}
out << ", " << width << ">";
}
} else if (auto* atomic = type->As<sem::Atomic>()) {
if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
return false;
}
} else if (type->Is<sem::Void>()) {
out << "void";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
return false;
}
return true;
} }
bool GeneratorImpl::EmitTypeAndName(std::ostream& out, bool GeneratorImpl::EmitTypeAndName(std::ostream& out,

File diff suppressed because it is too large Load Diff

View File

@ -560,33 +560,37 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
} }
uint32_t Builder::GenerateExpression(const ast::Expression* expr) { uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
return GenerateAccessorExpression(a); expr,
} [&](const ast::IndexAccessorExpression* a) { //
if (auto* b = expr->As<ast::BinaryExpression>()) { return GenerateAccessorExpression(a);
return GenerateBinaryExpression(b); },
} [&](const ast::BinaryExpression* b) { //
if (auto* b = expr->As<ast::BitcastExpression>()) { return GenerateBinaryExpression(b);
return GenerateBitcastExpression(b); },
} [&](const ast::BitcastExpression* b) { //
if (auto* c = expr->As<ast::CallExpression>()) { return GenerateBitcastExpression(b);
return GenerateCallExpression(c); },
} [&](const ast::CallExpression* c) { //
if (auto* i = expr->As<ast::IdentifierExpression>()) { return GenerateCallExpression(c);
return GenerateIdentifierExpression(i); },
} [&](const ast::IdentifierExpression* i) { //
if (auto* l = expr->As<ast::LiteralExpression>()) { return GenerateIdentifierExpression(i);
return GenerateLiteralIfNeeded(nullptr, l); },
} [&](const ast::LiteralExpression* l) { //
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { return GenerateLiteralIfNeeded(nullptr, l);
return GenerateAccessorExpression(m); },
} [&](const ast::MemberAccessorExpression* m) { //
if (auto* u = expr->As<ast::UnaryOpExpression>()) { return GenerateAccessorExpression(m);
return GenerateUnaryOpExpression(u); },
} [&](const ast::UnaryOpExpression* u) { //
return GenerateUnaryOpExpression(u);
error_ = "unknown expression type: " + std::string(expr->TypeInfo().name); },
return 0; [&](Default) -> uint32_t {
error_ =
"unknown expression type: " + std::string(expr->TypeInfo().name);
return 0;
});
} }
bool Builder::GenerateFunction(const ast::Function* func_ast) { bool Builder::GenerateFunction(const ast::Function* func_ast) {
@ -861,33 +865,56 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
push_type(spv::Op::OpVariable, std::move(ops)); push_type(spv::Op::OpVariable, std::move(ops));
for (auto* attr : var->attributes) { for (auto* attr : var->attributes) {
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { bool ok = Switch(
push_annot(spv::Op::OpDecorate, attr,
{Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn), [&](const ast::BuiltinAttribute* builtin) {
Operand::Int( push_annot(spv::Op::OpDecorate,
ConvertBuiltin(builtin->builtin, sem->StorageClass()))}); {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
} else if (auto* location = attr->As<ast::LocationAttribute>()) { Operand::Int(ConvertBuiltin(builtin->builtin,
push_annot(spv::Op::OpDecorate, sem->StorageClass()))});
{Operand::Int(var_id), Operand::Int(SpvDecorationLocation), return true;
Operand::Int(location->value)}); },
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { [&](const ast::LocationAttribute* location) {
AddInterpolationDecorations(var_id, interpolate->type, push_annot(spv::Op::OpDecorate,
interpolate->sampling); {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
} else if (attr->Is<ast::InvariantAttribute>()) { Operand::Int(location->value)});
push_annot(spv::Op::OpDecorate, return true;
{Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)}); },
} else if (auto* binding = attr->As<ast::BindingAttribute>()) { [&](const ast::InterpolateAttribute* interpolate) {
push_annot(spv::Op::OpDecorate, AddInterpolationDecorations(var_id, interpolate->type,
{Operand::Int(var_id), Operand::Int(SpvDecorationBinding), interpolate->sampling);
Operand::Int(binding->value)}); return true;
} else if (auto* group = attr->As<ast::GroupAttribute>()) { },
push_annot(spv::Op::OpDecorate, {Operand::Int(var_id), [&](const ast::InvariantAttribute*) {
Operand::Int(SpvDecorationDescriptorSet), push_annot(
Operand::Int(group->value)}); spv::Op::OpDecorate,
} else if (attr->Is<ast::OverrideAttribute>()) { {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
// Spec constants are handled elsewhere return true;
} else if (!attr->Is<ast::InternalAttribute>()) { },
error_ = "unknown attribute"; [&](const ast::BindingAttribute* binding) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
Operand::Int(binding->value)});
return true;
},
[&](const ast::GroupAttribute* group) {
push_annot(
spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
Operand::Int(group->value)});
return true;
},
[&](const ast::OverrideAttribute*) {
return true; // Spec constants are handled elsewhere
},
[&](const ast::InternalAttribute*) {
return true; // ignored
},
[&](Default) {
error_ = "unknown attribute";
return false;
});
if (!ok) {
return false; return false;
} }
} }
@ -1123,19 +1150,21 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
// promoted to storage with the VarForDynamicIndex transform. // promoted to storage with the VarForDynamicIndex transform.
for (auto* accessor : accessors) { for (auto* accessor : accessors) {
if (auto* array = accessor->As<ast::IndexAccessorExpression>()) { bool ok = Switch(
if (!GenerateIndexAccessor(array, &info)) { accessor,
return 0; [&](const ast::IndexAccessorExpression* array) {
} return GenerateIndexAccessor(array, &info);
} else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) { },
if (!GenerateMemberAccessor(member, &info)) { [&](const ast::MemberAccessorExpression* member) {
return 0; return GenerateMemberAccessor(member, &info);
} },
[&](Default) {
} else { error_ = "invalid accessor in list: " +
error_ = std::string(accessor->TypeInfo().name);
"invalid accessor in list: " + std::string(accessor->TypeInfo().name); return false;
return 0; });
if (!ok) {
return false;
} }
} }
@ -1653,21 +1682,28 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
constant.constant_id = global->ConstantId(); constant.constant_id = global->ConstantId();
} }
if (auto* l = lit->As<ast::BoolLiteralExpression>()) { Switch(
constant.kind = ScalarConstant::Kind::kBool; lit,
constant.value.b = l->value; [&](const ast::BoolLiteralExpression* l) {
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) { constant.kind = ScalarConstant::Kind::kBool;
constant.kind = ScalarConstant::Kind::kI32; constant.value.b = l->value;
constant.value.i32 = sl->value; },
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) { [&](const ast::SintLiteralExpression* sl) {
constant.kind = ScalarConstant::Kind::kU32; constant.kind = ScalarConstant::Kind::kI32;
constant.value.u32 = ul->value; constant.value.i32 = sl->value;
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { },
constant.kind = ScalarConstant::Kind::kF32; [&](const ast::UintLiteralExpression* ul) {
constant.value.f32 = fl->value; constant.kind = ScalarConstant::Kind::kU32;
} else { constant.value.u32 = ul->value;
error_ = "unknown literal type"; },
return 0; [&](const ast::FloatLiteralExpression* fl) {
constant.kind = ScalarConstant::Kind::kF32;
constant.value.f32 = fl->value;
},
[&](Default) { error_ = "unknown literal type"; });
if (!error_.empty()) {
return false;
} }
return GenerateConstantIfNeeded(constant); return GenerateConstantIfNeeded(constant);
@ -2209,19 +2245,25 @@ bool Builder::GenerateBlockStatementWithoutScoping(
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) { uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
return Switch(
if (auto* func = target->As<sem::Function>()) { target,
return GenerateFunctionCall(call, func); [&](const sem::Function* func) {
} return GenerateFunctionCall(call, func);
if (auto* builtin = target->As<sem::Builtin>()) { },
return GenerateBuiltinCall(call, builtin); [&](const sem::Builtin* builtin) {
} return GenerateBuiltinCall(call, builtin);
if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) { },
return GenerateTypeConstructorOrConversion(call, nullptr); [&](const sem::TypeConversion*) {
} return GenerateTypeConstructorOrConversion(call, nullptr);
TINT_ICE(Writer, builder_.Diagnostics()) },
<< "unhandled call target: " << target->TypeInfo().name; [&](const sem::TypeConstructor*) {
return false; return GenerateTypeConstructorOrConversion(call, nullptr);
},
[&](Default) -> uint32_t {
TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name;
return 0;
});
} }
uint32_t Builder::GenerateFunctionCall(const sem::Call* call, uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
@ -3790,46 +3832,49 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
} }
bool Builder::GenerateStatement(const ast::Statement* stmt) { bool Builder::GenerateStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
return GenerateAssignStatement(a); stmt,
} [&](const ast::AssignmentStatement* a) {
if (auto* b = stmt->As<ast::BlockStatement>()) { return GenerateAssignStatement(a);
return GenerateBlockStatement(b); },
} [&](const ast::BlockStatement* b) { //
if (auto* b = stmt->As<ast::BreakStatement>()) { return GenerateBlockStatement(b);
return GenerateBreakStatement(b); },
} [&](const ast::BreakStatement* b) { //
if (auto* c = stmt->As<ast::CallStatement>()) { return GenerateBreakStatement(b);
return GenerateCallExpression(c->expr) != 0; },
} [&](const ast::CallStatement* c) {
if (auto* c = stmt->As<ast::ContinueStatement>()) { return GenerateCallExpression(c->expr) != 0;
return GenerateContinueStatement(c); },
} [&](const ast::ContinueStatement* c) {
if (auto* d = stmt->As<ast::DiscardStatement>()) { return GenerateContinueStatement(c);
return GenerateDiscardStatement(d); },
} [&](const ast::DiscardStatement* d) {
if (stmt->Is<ast::FallthroughStatement>()) { return GenerateDiscardStatement(d);
// Do nothing here, the fallthrough gets handled by the switch code. },
return true; [&](const ast::FallthroughStatement*) {
} // Do nothing here, the fallthrough gets handled by the switch code.
if (auto* i = stmt->As<ast::IfStatement>()) { return true;
return GenerateIfStatement(i); },
} [&](const ast::IfStatement* i) { //
if (auto* l = stmt->As<ast::LoopStatement>()) { return GenerateIfStatement(i);
return GenerateLoopStatement(l); },
} [&](const ast::LoopStatement* l) { //
if (auto* r = stmt->As<ast::ReturnStatement>()) { return GenerateLoopStatement(l);
return GenerateReturnStatement(r); },
} [&](const ast::ReturnStatement* r) { //
if (auto* s = stmt->As<ast::SwitchStatement>()) { return GenerateReturnStatement(r);
return GenerateSwitchStatement(s); },
} [&](const ast::SwitchStatement* s) { //
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { return GenerateSwitchStatement(s);
return GenerateVariableDeclStatement(v); },
} [&](const ast::VariableDeclStatement* v) {
return GenerateVariableDeclStatement(v);
error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name); },
return false; [&](Default) {
error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
return false;
});
} }
bool Builder::GenerateVariableDeclStatement( bool Builder::GenerateVariableDeclStatement(
@ -3872,78 +3917,91 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t { return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
auto result = result_op(); auto result = result_op();
auto id = result.to_i(); auto id = result.to_i();
if (auto* arr = type->As<sem::Array>()) { bool ok = Switch(
if (!GenerateArrayType(arr, result)) { type,
return 0; [&](const sem::Array* arr) { //
} return GenerateArrayType(arr, result);
} else if (type->Is<sem::Bool>()) { },
push_type(spv::Op::OpTypeBool, {result}); [&](const sem::Bool*) {
} else if (type->Is<sem::F32>()) { push_type(spv::Op::OpTypeBool, {result});
push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)}); return true;
} else if (type->Is<sem::I32>()) { },
push_type(spv::Op::OpTypeInt, [&](const sem::F32*) {
{result, Operand::Int(32), Operand::Int(1)}); push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
} else if (auto* mat = type->As<sem::Matrix>()) { return true;
if (!GenerateMatrixType(mat, result)) { },
return 0; [&](const sem::I32*) {
} push_type(spv::Op::OpTypeInt,
} else if (auto* ptr = type->As<sem::Pointer>()) { {result, Operand::Int(32), Operand::Int(1)});
if (!GeneratePointerType(ptr, result)) { return true;
return 0; },
} [&](const sem::Matrix* mat) { //
} else if (auto* ref = type->As<sem::Reference>()) { return GenerateMatrixType(mat, result);
if (!GenerateReferenceType(ref, result)) { },
return 0; [&](const sem::Pointer* ptr) { //
} return GeneratePointerType(ptr, result);
} else if (auto* str = type->As<sem::Struct>()) { },
if (!GenerateStructType(str, result)) { [&](const sem::Reference* ref) { //
return 0; return GenerateReferenceType(ref, result);
} },
} else if (type->Is<sem::U32>()) { [&](const sem::Struct* str) { //
push_type(spv::Op::OpTypeInt, return GenerateStructType(str, result);
{result, Operand::Int(32), Operand::Int(0)}); },
} else if (auto* vec = type->As<sem::Vector>()) { [&](const sem::U32*) {
if (!GenerateVectorType(vec, result)) { push_type(spv::Op::OpTypeInt,
return 0; {result, Operand::Int(32), Operand::Int(0)});
} return true;
} else if (type->Is<sem::Void>()) { },
push_type(spv::Op::OpTypeVoid, {result}); [&](const sem::Vector* vec) { //
} else if (auto* tex = type->As<sem::Texture>()) { return GenerateVectorType(vec, result);
if (!GenerateTextureType(tex, result)) { },
return 0; [&](const sem::Void*) {
} push_type(spv::Op::OpTypeVoid, {result});
return true;
},
[&](const sem::StorageTexture* tex) {
if (!GenerateTextureType(tex, result)) {
return false;
}
if (auto* st = tex->As<sem::StorageTexture>()) { // Register all three access types of StorageTexture names. In
// Register all three access types of StorageTexture names. In SPIR-V, // SPIR-V, we must output a single type, while the variable is
// we must output a single type, while the variable is annotated with // annotated with the access type. Doing this ensures we de-dupe.
// the access type. Doing this ensures we de-dupe. type_name_to_id_[builder_
type_name_to_id_[builder_ .create<sem::StorageTexture>(
.create<sem::StorageTexture>( tex->dim(), tex->texel_format(),
st->dim(), st->texel_format(), ast::Access::kRead, tex->type())
ast::Access::kRead, st->type()) ->type_name()] = id;
->type_name()] = id; type_name_to_id_[builder_
type_name_to_id_[builder_ .create<sem::StorageTexture>(
.create<sem::StorageTexture>( tex->dim(), tex->texel_format(),
st->dim(), st->texel_format(), ast::Access::kWrite, tex->type())
ast::Access::kWrite, st->type()) ->type_name()] = id;
->type_name()] = id; type_name_to_id_[builder_
type_name_to_id_[builder_ .create<sem::StorageTexture>(
.create<sem::StorageTexture>( tex->dim(), tex->texel_format(),
st->dim(), st->texel_format(), ast::Access::kReadWrite, tex->type())
ast::Access::kReadWrite, st->type()) ->type_name()] = id;
->type_name()] = id; return true;
} },
[&](const sem::Texture* tex) {
return GenerateTextureType(tex, result);
},
[&](const sem::Sampler*) {
push_type(spv::Op::OpTypeSampler, {result});
} else if (type->Is<sem::Sampler>()) { // Register both of the sampler type names. In SPIR-V they're the same
push_type(spv::Op::OpTypeSampler, {result}); // sampler type, so we need to match that when we do the dedup check.
type_name_to_id_["__sampler_sampler"] = id;
type_name_to_id_["__sampler_comparison"] = id;
return true;
},
[&](Default) {
error_ = "unable to convert type: " + type->type_name();
return false;
});
// Register both of the sampler type names. In SPIR-V they're the same if (!ok) {
// sampler type, so we need to match that when we do the dedup check.
type_name_to_id_["__sampler_sampler"] = id;
type_name_to_id_["__sampler_comparison"] = id;
} else {
error_ = "unable to convert type: " + type->type_name();
return 0; return 0;
} }
@ -3995,22 +4053,31 @@ bool Builder::GenerateTextureType(const sem::Texture* texture,
} }
if (dim == ast::TextureDimension::kCubeArray) { if (dim == ast::TextureDimension::kCubeArray) {
if (texture->Is<sem::SampledTexture>() || if (texture->IsAnyOf<sem::SampledTexture, sem::DepthTexture>()) {
texture->Is<sem::DepthTexture>()) {
push_capability(SpvCapabilitySampledCubeArray); push_capability(SpvCapabilitySampledCubeArray);
} }
} }
uint32_t type_id = 0u; uint32_t type_id = Switch(
if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) { texture,
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>()); [&](const sem::DepthTexture*) {
} else if (auto* s = texture->As<sem::SampledTexture>()) { return GenerateTypeIfNeeded(builder_.create<sem::F32>());
type_id = GenerateTypeIfNeeded(s->type()); },
} else if (auto* ms = texture->As<sem::MultisampledTexture>()) { [&](const sem::DepthMultisampledTexture*) {
type_id = GenerateTypeIfNeeded(ms->type()); return GenerateTypeIfNeeded(builder_.create<sem::F32>());
} else if (auto* st = texture->As<sem::StorageTexture>()) { },
type_id = GenerateTypeIfNeeded(st->type()); [&](const sem::SampledTexture* t) {
} return GenerateTypeIfNeeded(t->type());
},
[&](const sem::MultisampledTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](const sem::StorageTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](Default) -> uint32_t { //
return 0u;
});
if (type_id == 0u) { if (type_id == 0u) {
return false; return false;
} }

View File

@ -68,23 +68,17 @@ GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() { bool GeneratorImpl::Generate() {
// Generate global declarations in the order they appear in the module. // Generate global declarations in the order they appear in the module.
for (auto* decl : program_->AST().GlobalDeclarations()) { for (auto* decl : program_->AST().GlobalDeclarations()) {
if (auto* td = decl->As<ast::TypeDecl>()) { if (!Switch(
if (!EmitTypeDecl(td)) { decl, //
return false; [&](const ast::TypeDecl* td) { return EmitTypeDecl(td); },
} [&](const ast::Function* func) { return EmitFunction(func); },
} else if (auto* func = decl->As<ast::Function>()) { [&](const ast::Variable* var) { return EmitVariable(line(), var); },
if (!EmitFunction(func)) { [&](Default) {
return false; TINT_UNREACHABLE(Writer, diagnostics_);
} return false;
} else if (auto* var = decl->As<ast::Variable>()) { })) {
if (!EmitVariable(line(), var)) {
return false;
}
} else {
TINT_UNREACHABLE(Writer, diagnostics_);
return false; return false;
} }
if (decl != program_->AST().GlobalDeclarations().back()) { if (decl != program_->AST().GlobalDeclarations().back()) {
line(); line();
} }
@ -94,59 +88,64 @@ bool GeneratorImpl::Generate() {
} }
bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) { bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
if (auto* alias = ty->As<ast::Alias>()) { return Switch(
auto out = line(); ty,
out << "type " << program_->Symbols().NameFor(alias->name) << " = "; [&](const ast::Alias* alias) { //
if (!EmitType(out, alias->type)) { auto out = line();
return false; out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
} if (!EmitType(out, alias->type)) {
out << ";"; return false;
} else if (auto* str = ty->As<ast::Struct>()) { }
if (!EmitStructType(str)) { out << ";";
return false; return true;
} },
} else { [&](const ast::Struct* str) { //
diagnostics_.add_error( return EmitStructType(str);
diag::System::Writer, },
"unknown declared type: " + std::string(ty->TypeInfo().name)); [&](Default) { //
return false; diagnostics_.add_error(
} diag::System::Writer,
return true; "unknown declared type: " + std::string(ty->TypeInfo().name));
return false;
});
} }
bool GeneratorImpl::EmitExpression(std::ostream& out, bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) { const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
return EmitIndexAccessor(out, a); expr,
} [&](const ast::IndexAccessorExpression* a) { //
if (auto* b = expr->As<ast::BinaryExpression>()) { return EmitIndexAccessor(out, a);
return EmitBinary(out, b); },
} [&](const ast::BinaryExpression* b) { //
if (auto* b = expr->As<ast::BitcastExpression>()) { return EmitBinary(out, b);
return EmitBitcast(out, b); },
} [&](const ast::BitcastExpression* b) { //
if (auto* c = expr->As<ast::CallExpression>()) { return EmitBitcast(out, b);
return EmitCall(out, c); },
} [&](const ast::CallExpression* c) { //
if (auto* i = expr->As<ast::IdentifierExpression>()) { return EmitCall(out, c);
return EmitIdentifier(out, i); },
} [&](const ast::IdentifierExpression* i) { //
if (auto* l = expr->As<ast::LiteralExpression>()) { return EmitIdentifier(out, i);
return EmitLiteral(out, l); },
} [&](const ast::LiteralExpression* l) { //
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { return EmitLiteral(out, l);
return EmitMemberAccessor(out, m); },
} [&](const ast::MemberAccessorExpression* m) { //
if (expr->Is<ast::PhonyExpression>()) { return EmitMemberAccessor(out, m);
out << "_"; },
return true; [&](const ast::PhonyExpression*) { //
} out << "_";
if (auto* u = expr->As<ast::UnaryOpExpression>()) { return true;
return EmitUnaryOp(out, u); },
} [&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u);
diagnostics_.add_error(diag::System::Writer, "unknown expression type"); },
return false; [&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown expression type");
return false;
});
} }
bool GeneratorImpl::EmitIndexAccessor( bool GeneratorImpl::EmitIndexAccessor(
@ -250,19 +249,28 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
bool GeneratorImpl::EmitLiteral(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) { const ast::LiteralExpression* lit) {
if (auto* bl = lit->As<ast::BoolLiteralExpression>()) { return Switch(
out << (bl->value ? "true" : "false"); lit,
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { [&](const ast::BoolLiteralExpression* bl) { //
out << FloatToBitPreservingString(fl->value); out << (bl->value ? "true" : "false");
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) { return true;
out << sl->value; },
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) { [&](const ast::FloatLiteralExpression* fl) { //
out << ul->value << "u"; out << FloatToBitPreservingString(fl->value);
} else { return true;
diagnostics_.add_error(diag::System::Writer, "unknown literal type"); },
return false; [&](const ast::SintLiteralExpression* sl) { //
} out << sl->value;
return true; return true;
},
[&](const ast::UintLiteralExpression* ul) { //
out << ul->value << "u";
return true;
},
[&](Default) { //
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
});
} }
bool GeneratorImpl::EmitIdentifier(std::ostream& out, bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@ -366,155 +374,208 @@ bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) {
} }
bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) { bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
if (auto* ary = ty->As<ast::Array>()) { return Switch(
for (auto* attr : ary->attributes) { ty,
if (auto* stride = attr->As<ast::StrideAttribute>()) { [&](const ast::Array* ary) {
out << "@stride(" << stride->stride << ") "; for (auto* attr : ary->attributes) {
} if (auto* stride = attr->As<ast::StrideAttribute>()) {
} out << "@stride(" << stride->stride << ") ";
}
}
out << "array<"; out << "array<";
if (!EmitType(out, ary->type)) { if (!EmitType(out, ary->type)) {
return false; return false;
} }
if (!ary->IsRuntimeArray()) { if (!ary->IsRuntimeArray()) {
out << ", "; out << ", ";
if (!EmitExpression(out, ary->count)) { if (!EmitExpression(out, ary->count)) {
return false; return false;
} }
} }
out << ">"; out << ">";
} else if (ty->Is<ast::Bool>()) { return true;
out << "bool"; },
} else if (ty->Is<ast::F32>()) { [&](const ast::Bool*) {
out << "f32"; out << "bool";
} else if (ty->Is<ast::I32>()) { return true;
out << "i32"; },
} else if (auto* mat = ty->As<ast::Matrix>()) { [&](const ast::F32*) {
out << "mat" << mat->columns << "x" << mat->rows; out << "f32";
if (auto* el_ty = mat->type) { return true;
out << "<"; },
if (!EmitType(out, el_ty)) { [&](const ast::I32*) {
return false; out << "i32";
} return true;
out << ">"; },
} [&](const ast::Matrix* mat) {
} else if (auto* ptr = ty->As<ast::Pointer>()) { out << "mat" << mat->columns << "x" << mat->rows;
out << "ptr<" << ptr->storage_class << ", "; if (auto* el_ty = mat->type) {
if (!EmitType(out, ptr->type)) { out << "<";
return false; if (!EmitType(out, el_ty)) {
} return false;
if (ptr->access != ast::Access::kUndefined) { }
out << ", "; out << ">";
if (!EmitAccess(out, ptr->access)) { }
return false; return true;
} },
} [&](const ast::Pointer* ptr) {
out << ">"; out << "ptr<" << ptr->storage_class << ", ";
} else if (auto* atomic = ty->As<ast::Atomic>()) { if (!EmitType(out, ptr->type)) {
out << "atomic<"; return false;
if (!EmitType(out, atomic->type)) { }
return false; if (ptr->access != ast::Access::kUndefined) {
} out << ", ";
out << ">"; if (!EmitAccess(out, ptr->access)) {
} else if (auto* sampler = ty->As<ast::Sampler>()) { return false;
out << "sampler"; }
}
out << ">";
return true;
},
[&](const ast::Atomic* atomic) {
out << "atomic<";
if (!EmitType(out, atomic->type)) {
return false;
}
out << ">";
return true;
},
[&](const ast::Sampler* sampler) {
out << "sampler";
if (sampler->IsComparison()) { if (sampler->IsComparison()) {
out << "_comparison"; out << "_comparison";
} }
} else if (ty->Is<ast::ExternalTexture>()) { return true;
out << "texture_external"; },
} else if (auto* texture = ty->As<ast::Texture>()) { [&](const ast::ExternalTexture*) {
out << "texture_"; out << "texture_external";
if (texture->Is<ast::DepthTexture>()) { return true;
out << "depth_"; },
} else if (texture->Is<ast::DepthMultisampledTexture>()) { [&](const ast::Texture* texture) {
out << "depth_multisampled_"; out << "texture_";
} else if (texture->Is<ast::SampledTexture>()) { bool ok = Switch(
/* nothing to emit */ texture,
} else if (texture->Is<ast::MultisampledTexture>()) { [&](const ast::DepthTexture*) { //
out << "multisampled_"; out << "depth_";
} else if (texture->Is<ast::StorageTexture>()) { return true;
out << "storage_"; },
} else { [&](const ast::DepthMultisampledTexture*) { //
diagnostics_.add_error(diag::System::Writer, "unknown texture type"); out << "depth_multisampled_";
return false; return true;
} },
[&](const ast::SampledTexture*) { //
/* nothing to emit */
return true;
},
[&](const ast::MultisampledTexture*) { //
out << "multisampled_";
return true;
},
[&](const ast::StorageTexture*) { //
out << "storage_";
return true;
},
[&](Default) { //
diagnostics_.add_error(diag::System::Writer,
"unknown texture type");
return false;
});
if (!ok) {
return false;
}
switch (texture->dim) { switch (texture->dim) {
case ast::TextureDimension::k1d: case ast::TextureDimension::k1d:
out << "1d"; out << "1d";
break; break;
case ast::TextureDimension::k2d: case ast::TextureDimension::k2d:
out << "2d"; out << "2d";
break; break;
case ast::TextureDimension::k2dArray: case ast::TextureDimension::k2dArray:
out << "2d_array"; out << "2d_array";
break; break;
case ast::TextureDimension::k3d: case ast::TextureDimension::k3d:
out << "3d"; out << "3d";
break; break;
case ast::TextureDimension::kCube: case ast::TextureDimension::kCube:
out << "cube"; out << "cube";
break; break;
case ast::TextureDimension::kCubeArray: case ast::TextureDimension::kCubeArray:
out << "cube_array"; out << "cube_array";
break; break;
default: default:
diagnostics_.add_error(diag::System::Writer, diagnostics_.add_error(diag::System::Writer,
"unknown texture dimension"); "unknown texture dimension");
return false; return false;
} }
if (auto* sampled = texture->As<ast::SampledTexture>()) { return Switch(
out << "<"; texture,
if (!EmitType(out, sampled->type)) { [&](const ast::SampledTexture* sampled) { //
out << "<";
if (!EmitType(out, sampled->type)) {
return false;
}
out << ">";
return true;
},
[&](const ast::MultisampledTexture* ms) { //
out << "<";
if (!EmitType(out, ms->type)) {
return false;
}
out << ">";
return true;
},
[&](const ast::StorageTexture* storage) { //
out << "<";
if (!EmitImageFormat(out, storage->format)) {
return false;
}
out << ", ";
if (!EmitAccess(out, storage->access)) {
return false;
}
out << ">";
return true;
},
[&](Default) { //
return true;
});
},
[&](const ast::U32*) {
out << "u32";
return true;
},
[&](const ast::Vector* vec) {
out << "vec" << vec->width;
if (auto* el_ty = vec->type) {
out << "<";
if (!EmitType(out, el_ty)) {
return false;
}
out << ">";
}
return true;
},
[&](const ast::Void*) {
out << "void";
return true;
},
[&](const ast::TypeName* tn) {
out << program_->Symbols().NameFor(tn->name);
return true;
},
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
"unknown type in EmitType: " + std::string(ty->TypeInfo().name));
return false; return false;
} });
out << ">";
} else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
out << "<";
if (!EmitType(out, ms->type)) {
return false;
}
out << ">";
} else if (auto* storage = texture->As<ast::StorageTexture>()) {
out << "<";
if (!EmitImageFormat(out, storage->format)) {
return false;
}
out << ", ";
if (!EmitAccess(out, storage->access)) {
return false;
}
out << ">";
}
} else if (ty->Is<ast::U32>()) {
out << "u32";
} else if (auto* vec = ty->As<ast::Vector>()) {
out << "vec" << vec->width;
if (auto* el_ty = vec->type) {
out << "<";
if (!EmitType(out, el_ty)) {
return false;
}
out << ">";
}
} else if (ty->Is<ast::Void>()) {
out << "void";
} else if (auto* tn = ty->As<ast::TypeName>()) {
out << program_->Symbols().NameFor(tn->name);
} else {
diagnostics_.add_error(
diag::System::Writer,
"unknown type in EmitType: " + std::string(ty->TypeInfo().name));
return false;
}
return true;
} }
bool GeneratorImpl::EmitStructType(const ast::Struct* str) { bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
@ -632,56 +693,90 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
} }
first = false; first = false;
out << "@"; out << "@";
if (auto* workgroup = attr->As<ast::WorkgroupAttribute>()) { bool ok = Switch(
auto values = workgroup->Values(); attr,
out << "workgroup_size("; [&](const ast::WorkgroupAttribute* workgroup) {
for (int i = 0; i < 3; i++) { auto values = workgroup->Values();
if (values[i]) { out << "workgroup_size(";
if (i > 0) { for (int i = 0; i < 3; i++) {
out << ", "; if (values[i]) {
if (i > 0) {
out << ", ";
}
if (!EmitExpression(out, values[i])) {
return false;
}
}
} }
if (!EmitExpression(out, values[i])) { out << ")";
return false; return true;
},
[&](const ast::StructBlockAttribute*) { //
out << "block";
return true;
},
[&](const ast::StageAttribute* stage) {
out << "stage(" << stage->stage << ")";
return true;
},
[&](const ast::BindingAttribute* binding) {
out << "binding(" << binding->value << ")";
return true;
},
[&](const ast::GroupAttribute* group) {
out << "group(" << group->value << ")";
return true;
},
[&](const ast::LocationAttribute* location) {
out << "location(" << location->value << ")";
return true;
},
[&](const ast::BuiltinAttribute* builtin) {
out << "builtin(" << builtin->builtin << ")";
return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
out << "interpolate(" << interpolate->type;
if (interpolate->sampling != ast::InterpolationSampling::kNone) {
out << ", " << interpolate->sampling;
} }
} out << ")";
} return true;
out << ")"; },
} else if (attr->Is<ast::StructBlockAttribute>()) { [&](const ast::InvariantAttribute*) {
out << "block"; out << "invariant";
} else if (auto* stage = attr->As<ast::StageAttribute>()) { return true;
out << "stage(" << stage->stage << ")"; },
} else if (auto* binding = attr->As<ast::BindingAttribute>()) { [&](const ast::OverrideAttribute* override_deco) {
out << "binding(" << binding->value << ")"; out << "override";
} else if (auto* group = attr->As<ast::GroupAttribute>()) { if (override_deco->has_value) {
out << "group(" << group->value << ")"; out << "(" << override_deco->value << ")";
} else if (auto* location = attr->As<ast::LocationAttribute>()) { }
out << "location(" << location->value << ")"; return true;
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { },
out << "builtin(" << builtin->builtin << ")"; [&](const ast::StructMemberSizeAttribute* size) {
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { out << "size(" << size->size << ")";
out << "interpolate(" << interpolate->type; return true;
if (interpolate->sampling != ast::InterpolationSampling::kNone) { },
out << ", " << interpolate->sampling; [&](const ast::StructMemberAlignAttribute* align) {
} out << "align(" << align->align << ")";
out << ")"; return true;
} else if (attr->Is<ast::InvariantAttribute>()) { },
out << "invariant"; [&](const ast::StrideAttribute* stride) {
} else if (auto* override_attr = attr->As<ast::OverrideAttribute>()) { out << "stride(" << stride->stride << ")";
out << "override"; return true;
if (override_attr->has_value) { },
out << "(" << override_attr->value << ")"; [&](const ast::InternalAttribute* internal) {
} out << "internal(" << internal->InternalName() << ")";
} else if (auto* size = attr->As<ast::StructMemberSizeAttribute>()) { return true;
out << "size(" << size->size << ")"; },
} else if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) { [&](Default) {
out << "align(" << align->align << ")"; TINT_ICE(Writer, diagnostics_)
} else if (auto* stride = attr->As<ast::StrideAttribute>()) { << "Unsupported attribute '" << attr->TypeInfo().name << "'";
out << "stride(" << stride->stride << ")"; return false;
} else if (auto* internal = attr->As<ast::InternalAttribute>()) { });
out << "internal(" << internal->InternalName() << ")";
} else { if (!ok) {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported attribute '" << attr->TypeInfo().name << "'";
return false; return false;
} }
} }
@ -809,55 +904,36 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
} }
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
return EmitAssign(a); stmt, //
} [&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
if (auto* b = stmt->As<ast::BlockStatement>()) { [&](const ast::BlockStatement* b) { return EmitBlock(b); },
return EmitBlock(b); [&](const ast::BreakStatement* b) { return EmitBreak(b); },
} [&](const ast::CallStatement* c) {
if (auto* b = stmt->As<ast::BreakStatement>()) { auto out = line();
return EmitBreak(b); if (!EmitCall(out, c->expr)) {
} return false;
if (auto* c = stmt->As<ast::CallStatement>()) { }
auto out = line(); out << ";";
if (!EmitCall(out, c->expr)) { return true;
return false; },
} [&](const ast::ContinueStatement* c) { return EmitContinue(c); },
out << ";"; [&](const ast::DiscardStatement* d) { return EmitDiscard(d); },
return true; [&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); },
} [&](const ast::IfStatement* i) { return EmitIf(i); },
if (auto* c = stmt->As<ast::ContinueStatement>()) { [&](const ast::LoopStatement* l) { return EmitLoop(l); },
return EmitContinue(c); [&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
} [&](const ast::ReturnStatement* r) { return EmitReturn(r); },
if (auto* d = stmt->As<ast::DiscardStatement>()) { [&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
return EmitDiscard(d); [&](const ast::VariableDeclStatement* v) {
} return EmitVariable(line(), v->variable);
if (auto* f = stmt->As<ast::FallthroughStatement>()) { },
return EmitFallthrough(f); [&](Default) {
} diagnostics_.add_error(
if (auto* i = stmt->As<ast::IfStatement>()) { diag::System::Writer,
return EmitIf(i); "unknown statement type: " + std::string(stmt->TypeInfo().name));
} return false;
if (auto* l = stmt->As<ast::LoopStatement>()) { });
return EmitLoop(l);
}
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
return EmitForLoop(l);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(s);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(line(), v->variable);
}
diagnostics_.add_error(
diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name));
return false;
} }
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) { bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {