Add tint::Switch()

A type dispatch helper with replaces chains of:

  if (auto* a = obj->As<A>()) {
    ...
  } else if (auto* b = obj->As<B>()) {
    ...
  } else {
    ...
  }

with:

  Switch(obj,
    [&](A* a) { ... },
    [&](B* b) { ... },
    [&](Default) { ... });

This new helper provides greater opportunities for optimizations, avoids
scoping issues with if-else blocks, and is slightly cleaner (IMO).

Bug: tint:1383
Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78543
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-02-04 15:38:23 +00:00 committed by Tint LUCI CQ
parent fa0d64b76d
commit de857e1c58
11 changed files with 2413 additions and 1570 deletions

View File

@ -1160,6 +1160,7 @@ if(TINT_BUILD_BENCHMARKS)
endif() endif()
set(TINT_BENCHMARK_SRC set(TINT_BENCHMARK_SRC
"castable_bench.cc"
"bench/benchmark.cc" "bench/benchmark.cc"
"reader/wgsl/parser_bench.cc" "reader/wgsl/parser_bench.cc"
) )

View File

@ -35,16 +35,15 @@ Module::Module(ProgramID pid,
continue; continue;
} }
if (auto* ty = decl->As<ast::TypeDecl>()) { Switch(
type_decls_.push_back(ty); decl, //
} else if (auto* func = decl->As<Function>()) { [&](const ast::TypeDecl* type) { type_decls_.push_back(type); },
functions_.push_back(func); [&](const Function* func) { functions_.push_back(func); },
} else if (auto* var = decl->As<Variable>()) { [&](const Variable* var) { global_variables_.push_back(var); },
global_variables_.push_back(var); [&](Default) {
} else {
diag::List diagnostics; diag::List diagnostics;
TINT_ICE(AST, diagnostics) << "Unknown global declaration type"; TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
} });
} }
} }
@ -101,19 +100,24 @@ void Module::Copy(CloneContext* ctx, const Module* src) {
<< "src global declaration was nullptr"; << "src global declaration was nullptr";
continue; continue;
} }
if (auto* type = decl->As<ast::TypeDecl>()) { Switch(
decl,
[&](const ast::TypeDecl* type) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
type_decls_.push_back(type); type_decls_.push_back(type);
} else if (auto* func = decl->As<Function>()) { },
[&](const Function* func) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
functions_.push_back(func); functions_.push_back(func);
} else if (auto* var = decl->As<Variable>()) { },
[&](const Variable* var) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
global_variables_.push_back(var); global_variables_.push_back(var);
} else { },
[&](Default) {
TINT_ICE(AST, ctx->dst->Diagnostics()) TINT_ICE(AST, ctx->dst->Diagnostics())
<< "Unknown global declaration type"; << "Unknown global declaration type";
} });
} }
} }

View File

@ -101,30 +101,47 @@ bool TraverseExpressions(const ast::Expression* root,
} }
} }
if (auto* idx = expr->As<IndexAccessorExpression>()) { bool ok = Switch(
expr,
[&](const IndexAccessorExpression* idx) {
push_pair(idx->object, idx->index); push_pair(idx->object, idx->index);
} else if (auto* bin_op = expr->As<BinaryExpression>()) { return true;
},
[&](const BinaryExpression* bin_op) {
push_pair(bin_op->lhs, bin_op->rhs); push_pair(bin_op->lhs, bin_op->rhs);
} else if (auto* bitcast = expr->As<BitcastExpression>()) { return true;
},
[&](const BitcastExpression* bitcast) {
to_visit.push_back(bitcast->expr); to_visit.push_back(bitcast->expr);
} else if (auto* call = expr->As<CallExpression>()) { return true;
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the },
// function name in the traversal. [&](const CallExpression* call) {
// to_visit.push_back(call->func); // TODO(crbug.com/tint/1257): Resolver breaks if we actually include
// the function name in the traversal. to_visit.push_back(call->func);
push_list(call->args); push_list(call->args);
} else if (auto* member = expr->As<MemberAccessorExpression>()) { return true;
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the },
// member name in the traversal. [&](const MemberAccessorExpression* member) {
// push_pair(member->structure, member->member); // TODO(crbug.com/tint/1257): Resolver breaks if we actually include
// the member name in the traversal. push_pair(member->structure,
// member->member);
to_visit.push_back(member->structure); to_visit.push_back(member->structure);
} else if (auto* unary = expr->As<UnaryOpExpression>()) { return true;
},
[&](const UnaryOpExpression* unary) {
to_visit.push_back(unary->expr); to_visit.push_back(unary->expr);
} else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression, return true;
},
[&](Default) {
if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
PhonyExpression>()) { PhonyExpression>()) {
// Leaf expression return true; // Leaf expression
} else { }
TINT_ICE(AST, diags) << "unhandled expression type: " TINT_ICE(AST, diags)
<< expr->TypeInfo().name; << "unhandled expression type: " << expr->TypeInfo().name;
return false;
});
if (!ok) {
return false; return false;
} }
} }

View File

@ -453,6 +453,105 @@ class Castable : public BASE {
} }
}; };
/// Default can be used as the default case for a Switch(), when all previous
/// cases failed to match.
///
/// Example:
/// ```
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* If not TypeA or TypeB */ });
/// ```
struct Default {};
/// Switch is used to dispatch one of the provided callback case handler
/// functions based on the type of `object` and the parameter type of the case
/// handlers. Switch will sequentially check the type of `object` against each
/// of the switch case handler functions, and will invoke the first case handler
/// function which has a parameter type that matches the object type. When a
/// case handler is matched, it will be called with the single argument of
/// `object` cast to the case handler's parameter type. Switch will invoke at
/// most one case handler. Each of the case functions must have the signature
/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
/// is the return type, consistent across all case handlers.
///
/// An optional default case function with the signature `R(Default)` can be
/// used as the last case. This default case will be called if all previous
/// cases failed to match.
///
/// Example:
/// ```
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ });
///
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* Called if object is not TypeA or TypeB */ });
/// ```
///
/// @param object the object who's type is used to
/// @param first_case the first switch case
/// @param other_cases additional switch cases (optional)
/// @return the value returned by the called case. If no cases matched, then the
/// zero value for the consistent case type.
template <typename T, typename FIRST_CASE, typename... OTHER_CASES>
traits::ReturnType<FIRST_CASE> //
Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) {
using ReturnType = traits::ReturnType<FIRST_CASE>;
using CaseType = std::remove_pointer_t<traits::ParameterType<FIRST_CASE, 0>>;
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
static_assert(traits::SignatureOfT<FIRST_CASE>::parameter_count == 1,
"Switch case must have a single parameter");
if constexpr (std::is_same_v<CaseType, Default>) {
// Default case. Must be last.
(void)object; // 'object' is not used by the Default case.
static_assert(sizeof...(other_cases) == 0,
"Switch Default case must come last");
if constexpr (kHasReturnType) {
return first_case({});
} else {
first_case({});
return;
}
} else {
// Regular case.
static_assert(traits::IsTypeOrDerived<CaseType, CastableBase>::value,
"Switch case parameter is not a Castable pointer");
// Does the case match?
if (auto* ptr = As<CaseType>(object)) {
if constexpr (kHasReturnType) {
return first_case(ptr);
} else {
first_case(ptr);
return;
}
}
// Case did not match. Got any more cases to try?
if constexpr (sizeof...(other_cases) > 0) {
// Try the next cases...
if constexpr (kHasReturnType) {
auto res = Switch(object, std::forward<OTHER_CASES>(other_cases)...);
static_assert(std::is_same_v<decltype(res), ReturnType>,
"Switch case types do not have consistent return type");
return res;
} else {
Switch(object, std::forward<OTHER_CASES>(other_cases)...);
return;
}
} else {
// That was the last case. No cases matched.
if constexpr (kHasReturnType) {
return {};
} else {
return;
}
}
}
}
} // namespace tint } // namespace tint
TINT_CASTABLE_POP_DISABLE_WARNINGS(); TINT_CASTABLE_POP_DISABLE_WARNINGS();

270
src/castable_bench.cc Normal file
View File

@ -0,0 +1,270 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bench/benchmark.h"
namespace tint {
namespace {
struct Base : public tint::Castable<Base> {};
struct A : public tint::Castable<A, Base> {};
struct AA : public tint::Castable<AA, A> {};
struct AAA : public tint::Castable<AAA, AA> {};
struct AAB : public tint::Castable<AAB, AA> {};
struct AAC : public tint::Castable<AAC, AA> {};
struct AB : public tint::Castable<AB, A> {};
struct ABA : public tint::Castable<ABA, AB> {};
struct ABB : public tint::Castable<ABB, AB> {};
struct ABC : public tint::Castable<ABC, AB> {};
struct AC : public tint::Castable<AC, A> {};
struct ACA : public tint::Castable<ACA, AC> {};
struct ACB : public tint::Castable<ACB, AC> {};
struct ACC : public tint::Castable<ACC, AC> {};
struct B : public tint::Castable<B, Base> {};
struct BA : public tint::Castable<BA, B> {};
struct BAA : public tint::Castable<BAA, BA> {};
struct BAB : public tint::Castable<BAB, BA> {};
struct BAC : public tint::Castable<BAC, BA> {};
struct BB : public tint::Castable<BB, B> {};
struct BBA : public tint::Castable<BBA, BB> {};
struct BBB : public tint::Castable<BBB, BB> {};
struct BBC : public tint::Castable<BBC, BB> {};
struct BC : public tint::Castable<BC, B> {};
struct BCA : public tint::Castable<BCA, BC> {};
struct BCB : public tint::Castable<BCB, BC> {};
struct BCC : public tint::Castable<BCC, BC> {};
struct C : public tint::Castable<C, Base> {};
struct CA : public tint::Castable<CA, C> {};
struct CAA : public tint::Castable<CAA, CA> {};
struct CAB : public tint::Castable<CAB, CA> {};
struct CAC : public tint::Castable<CAC, CA> {};
struct CB : public tint::Castable<CB, C> {};
struct CBA : public tint::Castable<CBA, CB> {};
struct CBB : public tint::Castable<CBB, CB> {};
struct CBC : public tint::Castable<CBC, CB> {};
struct CC : public tint::Castable<CC, C> {};
struct CCA : public tint::Castable<CCA, CC> {};
struct CCB : public tint::Castable<CCB, CC> {};
struct CCC : public tint::Castable<CCC, CC> {};
using AllTypes = std::tuple<Base,
A,
AA,
AAA,
AAB,
AAC,
AB,
ABA,
ABB,
ABC,
AC,
ACA,
ACB,
ACC,
B,
BA,
BAA,
BAB,
BAC,
BB,
BBA,
BBB,
BBC,
BC,
BCA,
BCB,
BCC,
C,
CA,
CAA,
CAB,
CAC,
CB,
CBA,
CBB,
CBC,
CC,
CCA,
CCB,
CCC>;
std::vector<std::unique_ptr<Base>> MakeObjects() {
std::vector<std::unique_ptr<Base>> out;
out.emplace_back(std::make_unique<Base>());
out.emplace_back(std::make_unique<A>());
out.emplace_back(std::make_unique<AA>());
out.emplace_back(std::make_unique<AAA>());
out.emplace_back(std::make_unique<AAB>());
out.emplace_back(std::make_unique<AAC>());
out.emplace_back(std::make_unique<AB>());
out.emplace_back(std::make_unique<ABA>());
out.emplace_back(std::make_unique<ABB>());
out.emplace_back(std::make_unique<ABC>());
out.emplace_back(std::make_unique<AC>());
out.emplace_back(std::make_unique<ACA>());
out.emplace_back(std::make_unique<ACB>());
out.emplace_back(std::make_unique<ACC>());
out.emplace_back(std::make_unique<B>());
out.emplace_back(std::make_unique<BA>());
out.emplace_back(std::make_unique<BAA>());
out.emplace_back(std::make_unique<BAB>());
out.emplace_back(std::make_unique<BAC>());
out.emplace_back(std::make_unique<BB>());
out.emplace_back(std::make_unique<BBA>());
out.emplace_back(std::make_unique<BBB>());
out.emplace_back(std::make_unique<BBC>());
out.emplace_back(std::make_unique<BC>());
out.emplace_back(std::make_unique<BCA>());
out.emplace_back(std::make_unique<BCB>());
out.emplace_back(std::make_unique<BCC>());
out.emplace_back(std::make_unique<C>());
out.emplace_back(std::make_unique<CA>());
out.emplace_back(std::make_unique<CAA>());
out.emplace_back(std::make_unique<CAB>());
out.emplace_back(std::make_unique<CAC>());
out.emplace_back(std::make_unique<CB>());
out.emplace_back(std::make_unique<CBA>());
out.emplace_back(std::make_unique<CBB>());
out.emplace_back(std::make_unique<CBC>());
out.emplace_back(std::make_unique<CC>());
out.emplace_back(std::make_unique<CCA>());
out.emplace_back(std::make_unique<CCB>());
out.emplace_back(std::make_unique<CCC>());
return out;
}
void CastableLargeSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); },
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); },
[&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); },
[&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); },
[&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); },
[&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); },
[&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); },
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
[&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); },
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
[&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); },
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
[&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); },
[&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); },
[&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); },
[&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); },
[&](const CA*) { ::benchmark::DoNotOptimize(i += 290); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
[&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); },
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); },
[&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); },
[&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); },
[&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); },
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableLargeSwitch);
void CastableMediumSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableMediumSwitch);
void CastableSmallSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); },
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableSmallSwitch);
} // namespace
} // namespace tint
TINT_INSTANTIATE_TYPEINFO(tint::Base);
TINT_INSTANTIATE_TYPEINFO(tint::A);
TINT_INSTANTIATE_TYPEINFO(tint::AA);
TINT_INSTANTIATE_TYPEINFO(tint::AAA);
TINT_INSTANTIATE_TYPEINFO(tint::AAB);
TINT_INSTANTIATE_TYPEINFO(tint::AAC);
TINT_INSTANTIATE_TYPEINFO(tint::AB);
TINT_INSTANTIATE_TYPEINFO(tint::ABA);
TINT_INSTANTIATE_TYPEINFO(tint::ABB);
TINT_INSTANTIATE_TYPEINFO(tint::ABC);
TINT_INSTANTIATE_TYPEINFO(tint::AC);
TINT_INSTANTIATE_TYPEINFO(tint::ACA);
TINT_INSTANTIATE_TYPEINFO(tint::ACB);
TINT_INSTANTIATE_TYPEINFO(tint::ACC);
TINT_INSTANTIATE_TYPEINFO(tint::B);
TINT_INSTANTIATE_TYPEINFO(tint::BA);
TINT_INSTANTIATE_TYPEINFO(tint::BAA);
TINT_INSTANTIATE_TYPEINFO(tint::BAB);
TINT_INSTANTIATE_TYPEINFO(tint::BAC);
TINT_INSTANTIATE_TYPEINFO(tint::BB);
TINT_INSTANTIATE_TYPEINFO(tint::BBA);
TINT_INSTANTIATE_TYPEINFO(tint::BBB);
TINT_INSTANTIATE_TYPEINFO(tint::BBC);
TINT_INSTANTIATE_TYPEINFO(tint::BC);
TINT_INSTANTIATE_TYPEINFO(tint::BCA);
TINT_INSTANTIATE_TYPEINFO(tint::BCB);
TINT_INSTANTIATE_TYPEINFO(tint::BCC);
TINT_INSTANTIATE_TYPEINFO(tint::C);
TINT_INSTANTIATE_TYPEINFO(tint::CA);
TINT_INSTANTIATE_TYPEINFO(tint::CAA);
TINT_INSTANTIATE_TYPEINFO(tint::CAB);
TINT_INSTANTIATE_TYPEINFO(tint::CAC);
TINT_INSTANTIATE_TYPEINFO(tint::CB);
TINT_INSTANTIATE_TYPEINFO(tint::CBA);
TINT_INSTANTIATE_TYPEINFO(tint::CBB);
TINT_INSTANTIATE_TYPEINFO(tint::CBC);
TINT_INSTANTIATE_TYPEINFO(tint::CC);
TINT_INSTANTIATE_TYPEINFO(tint::CCA);
TINT_INSTANTIATE_TYPEINFO(tint::CCB);
TINT_INSTANTIATE_TYPEINFO(tint::CCC);

View File

@ -252,6 +252,151 @@ TEST(Castable, As) {
ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get())); ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
} }
TEST(Castable, SwitchNoDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_amphibian = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Amphibian* amphibian) {
EXPECT_EQ(amphibian, frog.get());
frog_matched_amphibian = true;
});
EXPECT_TRUE(frog_matched_amphibian);
}
{
bool bear_matched_mammal = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Mammal* mammal) {
EXPECT_EQ(mammal, bear.get());
bear_matched_mammal = true;
});
EXPECT_TRUE(bear_matched_mammal);
}
{
bool gecko_matched_reptile = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Reptile* reptile) {
EXPECT_EQ(reptile, gecko.get());
gecko_matched_reptile = true;
});
EXPECT_TRUE(gecko_matched_reptile);
}
}
TEST(Castable, SwitchWithUnusedDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_amphibian = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Amphibian* amphibian) {
EXPECT_EQ(amphibian, frog.get());
frog_matched_amphibian = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(frog_matched_amphibian);
}
{
bool bear_matched_mammal = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Mammal* mammal) {
EXPECT_EQ(mammal, bear.get());
bear_matched_mammal = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(bear_matched_mammal);
}
{
bool gecko_matched_reptile = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Reptile* reptile) {
EXPECT_EQ(reptile, gecko.get());
gecko_matched_reptile = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(gecko_matched_reptile);
}
}
TEST(Castable, SwitchDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_default = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Default) { frog_matched_default = true; });
EXPECT_TRUE(frog_matched_default);
}
{
bool bear_matched_default = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Default) { bear_matched_default = true; });
EXPECT_TRUE(bear_matched_default);
}
{
bool gecko_matched_default = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Default) { gecko_matched_default = true; });
EXPECT_TRUE(gecko_matched_default);
}
}
TEST(Castable, SwitchMatchFirst) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
{
bool frog_matched_animal = false;
Switch(
frog.get(),
[&](Animal* animal) {
EXPECT_EQ(animal, frog.get());
frog_matched_animal = true;
},
[&](Amphibian*) { FAIL() << "animal should have been matched first"; });
EXPECT_TRUE(frog_matched_animal);
}
{
bool frog_matched_amphibian = false;
Switch(
frog.get(),
[&](Amphibian* amphibain) {
EXPECT_EQ(amphibain, frog.get());
frog_matched_amphibian = true;
},
[&](Animal*) { FAIL() << "amphibian should have been matched first"; });
EXPECT_TRUE(frog_matched_amphibian);
}
}
} // namespace } // namespace
TINT_INSTANTIATE_TYPEINFO(Animal); TINT_INSTANTIATE_TYPEINFO(Animal);

View File

@ -953,7 +953,7 @@ const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
bool FunctionEmitter::EmitPipelineInput(std::string var_name, bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* var_type, const Type* var_type,
ast::AttributeList* decos, ast::AttributeList* attrs,
std::vector<int> index_prefix, std::vector<int> index_prefix,
const Type* tip_type, const Type* tip_type,
const Type* forced_param_type, const Type* forced_param_type,
@ -966,19 +966,23 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name,
} }
// Recursively flatten matrices, arrays, and structures. // Recursively flatten matrices, arrays, and structures.
if (auto* matrix_type = tip_type->As<Matrix>()) { return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
index_prefix.push_back(0); index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns); const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) { for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col; index_prefix.back() = col;
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty, if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
forced_param_type, params, statements)) { vec_ty, forced_param_type, params,
statements)) {
return false; return false;
} }
} }
return success(); return success();
} else if (auto* array_type = tip_type->As<Array>()) { },
[&](const Array* array_type) -> bool {
if (array_type->size == 0) { if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO"; return Fail() << "runtime-size array not allowed on pipeline IO";
} }
@ -986,85 +990,97 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* elem_ty = array_type->type; const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) { for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i; index_prefix.back() = i;
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty, if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
forced_param_type, params, statements)) { elem_ty, forced_param_type, params,
statements)) {
return false; return false;
} }
} }
return success(); return success();
} else if (auto* struct_type = tip_type->As<Struct>()) { },
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members; const auto& members = struct_type->members;
index_prefix.push_back(0); index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) { for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i; index_prefix.back() = i;
ast::AttributeList member_decos(*decos); ast::AttributeList member_attrs(*attrs);
if (!parser_impl_.ConvertPipelineDecorations( if (!parser_impl_.ConvertPipelineDecorations(
struct_type, struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type, i), parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_decos)) { &member_attrs)) {
return false; return false;
} }
if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix, if (!EmitPipelineInput(var_name, var_type, &member_attrs,
members[i], forced_param_type, params, index_prefix, members[i], forced_param_type,
statements)) { params, statements)) {
return false; return false;
} }
// Copy the location as updated by nested expansion of the member. // Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_decos)); parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
} }
return success(); return success();
} },
[&](Default) {
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos); const bool is_builtin =
ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
const Type* param_type = is_builtin ? forced_param_type : tip_type; const Type* param_type = is_builtin ? forced_param_type : tip_type;
const auto param_name = namer_.MakeDerivedName(var_name + "_param"); const auto param_name = namer_.MakeDerivedName(var_name + "_param");
// Create the parameter. // Create the parameter.
// TODO(dneto): Note: If the parameter has non-location decorations, // TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple elements // then those decoration AST nodes will be reused between multiple
// of a matrix, array, or structure. Normally that's disallowed but currently // elements of a matrix, array, or structure. Normally that's
// the SPIR-V reader will make duplicates when the entire AST is cloned // disallowed but currently the SPIR-V reader will make duplicates when
// at the top level of the SPIR-V reader flow. Consider rewriting this // the entire AST is cloned at the top level of the SPIR-V reader flow.
// to avoid this node-sharing. // Consider rewriting this to avoid this node-sharing.
params->push_back( params->push_back(
builder_.Param(param_name, param_type->Build(builder_), *decos)); builder_.Param(param_name, param_type->Build(builder_), *attrs));
// Add a body statement to copy the parameter to the corresponding private // Add a body statement to copy the parameter to the corresponding
// variable. // private variable.
const ast::Expression* param_value = builder_.Expr(param_name); const ast::Expression* param_value = builder_.Expr(param_name);
const ast::Expression* store_dest = builder_.Expr(var_name); const ast::Expression* store_dest = builder_.Expr(var_name);
// Index into the LHS as needed. // Index into the LHS as needed.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); auto* current_type =
var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) { for (auto index : index_prefix) {
if (auto* matrix_type = current_type->As<Matrix>()) { Switch(
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index)); current_type,
[&](const Matrix* matrix_type) {
store_dest =
builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows); current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
} else if (auto* array_type = current_type->As<Array>()) { },
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index)); [&](const Array* array_type) {
store_dest =
builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias(); current_type = array_type->type->UnwrapAlias();
} else if (auto* struct_type = current_type->As<Struct>()) { },
[&](const Struct* struct_type) {
store_dest = builder_.MemberAccessor( store_dest = builder_.MemberAccessor(
store_dest, store_dest, builder_.Expr(parser_impl_.GetMemberName(
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index))); *struct_type, index)));
current_type = struct_type->members[index]; current_type = struct_type->members[index];
} });
} }
if (is_builtin && (tip_type != forced_param_type)) { if (is_builtin && (tip_type != forced_param_type)) {
// The parameter will have the WGSL type, but we need bitcast to // The parameter will have the WGSL type, but we need bitcast to
// the variable store type. // the variable store type.
param_value = param_value = create<ast::BitcastExpression>(
create<ast::BitcastExpression>(tip_type->Build(builder_), param_value); tip_type->Build(builder_), param_value);
} }
statements->push_back(builder_.Assign(store_dest, param_value)); statements->push_back(builder_.Assign(store_dest, param_value));
// Increment the location attribute, in case more parameters will follow. // Increment the location attribute, in case more parameters will
IncrementLocation(decos); // follow.
IncrementLocation(attrs);
return success(); return success();
});
} }
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) { void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
@ -1102,20 +1118,23 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
} }
// Recursively flatten matrices, arrays, and structures. // Recursively flatten matrices, arrays, and structures.
if (auto* matrix_type = tip_type->As<Matrix>()) { return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
index_prefix.push_back(0); index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns); const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows); const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) { for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col; index_prefix.back() = col;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty, if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
forced_member_type, return_members, vec_ty, forced_member_type, return_members,
return_exprs)) { return_exprs)) {
return false; return false;
} }
} }
return success(); return success();
} else if (auto* array_type = tip_type->As<Array>()) { },
[&](const Array* array_type) -> bool {
if (array_type->size == 0) { if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO"; return Fail() << "runtime-size array not allowed on pipeline IO";
} }
@ -1123,37 +1142,39 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
const Type* elem_ty = array_type->type; const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) { for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i; index_prefix.back() = i;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty, if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
forced_member_type, return_members, elem_ty, forced_member_type, return_members,
return_exprs)) { return_exprs)) {
return false; return false;
} }
} }
return success(); return success();
} else if (auto* struct_type = tip_type->As<Struct>()) { },
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members; const auto& members = struct_type->members;
index_prefix.push_back(0); index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) { for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i; index_prefix.back() = i;
ast::AttributeList member_decos(*decos); ast::AttributeList member_attrs(*decos);
if (!parser_impl_.ConvertPipelineDecorations( if (!parser_impl_.ConvertPipelineDecorations(
struct_type, struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type, i), parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_decos)) { &member_attrs)) {
return false; return false;
} }
if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix, if (!EmitPipelineOutput(var_name, var_type, &member_attrs,
members[i], forced_member_type, return_members, index_prefix, members[i], forced_member_type,
return_exprs)) { return_members, return_exprs)) {
return false; return false;
} }
// Copy the location as updated by nested expansion of the member. // Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_decos)); parser_impl_.SetLocation(decos, GetLocation(member_attrs));
} }
return success(); return success();
} },
[&](Default) {
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos); const bool is_builtin =
ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const Type* member_type = is_builtin ? forced_member_type : tip_type; const Type* member_type = is_builtin ? forced_member_type : tip_type;
// Derive the member name directly from the variable name. They can't // Derive the member name directly from the variable name. They can't
@ -1161,11 +1182,11 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
const auto member_name = namer_.MakeDerivedName(var_name); const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member. // Create the member.
// TODO(dneto): Note: If the parameter has non-location decorations, // TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple elements // then those decoration AST nodes will be reused between multiple
// of a matrix, array, or structure. Normally that's disallowed but currently // elements of a matrix, array, or structure. Normally that's
// the SPIR-V reader will make duplicates when the entire AST is cloned // disallowed but currently the SPIR-V reader will make duplicates when
// at the top level of the SPIR-V reader flow. Consider rewriting this // the entire AST is cloned at the top level of the SPIR-V reader flow.
// to avoid this node-sharing. // Consider rewriting this to avoid this node-sharing.
return_members->push_back( return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos)); builder_.Member(member_name, member_type->Build(builder_), *decos));
@ -1174,20 +1195,27 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
const ast::Expression* load_source = builder_.Expr(var_name); const ast::Expression* load_source = builder_.Expr(var_name);
// Index into the variable as needed to pick out the flattened member. // Index into the variable as needed to pick out the flattened member.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias(); auto* current_type =
var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) { for (auto index : index_prefix) {
if (auto* matrix_type = current_type->As<Matrix>()) { Switch(
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index)); current_type,
[&](const Matrix* matrix_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows); current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
} else if (auto* array_type = current_type->As<Array>()) { },
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index)); [&](const Array* array_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias(); current_type = array_type->type->UnwrapAlias();
} else if (auto* struct_type = current_type->As<Struct>()) { },
[&](const Struct* struct_type) {
load_source = builder_.MemberAccessor( load_source = builder_.MemberAccessor(
load_source, load_source, builder_.Expr(parser_impl_.GetMemberName(
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index))); *struct_type, index)));
current_type = struct_type->members[index]; current_type = struct_type->members[index];
} });
} }
if (is_builtin && (tip_type != forced_member_type)) { if (is_builtin && (tip_type != forced_member_type)) {
@ -1198,10 +1226,12 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
} }
return_exprs->push_back(load_source); return_exprs->push_back(load_source);
// Increment the location attribute, in case more parameters will follow. // Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(decos); IncrementLocation(decos);
return success(); return success();
});
} }
bool FunctionEmitter::EmitEntryPointAsWrapper() { bool FunctionEmitter::EmitEntryPointAsWrapper() {

View File

@ -239,11 +239,12 @@ bool GeneratorImpl::Generate() {
} }
last_kind = kind; last_kind = kind;
if (auto* global = decl->As<ast::Variable>()) { bool ok = Switch(
if (!EmitGlobalVariable(global)) { decl,
return false; [&](const ast::Variable* global) { //
} return EmitGlobalVariable(global);
} else if (auto* str = decl->As<ast::Struct>()) { },
[&](const ast::Struct* str) {
auto* ty = builder_.Sem().Get(str); auto* ty = builder_.Sem().Get(str);
auto storage_class_uses = ty->StorageClassUsage(); auto storage_class_uses = ty->StorageClassUsage();
if (storage_class_uses.size() != if (storage_class_uses.size() !=
@ -253,25 +254,26 @@ bool GeneratorImpl::Generate() {
// uniform buffer, so it needs to be emitted. // uniform buffer, so it needs to be emitted.
// Storage buffer are read and written to via a ByteAddressBuffer // Storage buffer are read and written to via a ByteAddressBuffer
// instead of true structure. // instead of true structure.
// Structures used as uniform buffer are read from an array of vectors // Structures used as uniform buffer are read from an array of
// instead of true structure. // vectors instead of true structure.
if (!EmitStructType(current_buffer_, ty)) { return EmitStructType(current_buffer_, ty);
return false;
} }
} return true;
} else if (auto* func = decl->As<ast::Function>()) { },
[&](const ast::Function* func) {
if (func->IsEntryPoint()) { if (func->IsEntryPoint()) {
if (!EmitEntryPointFunction(func)) { return EmitEntryPointFunction(func);
return false;
} }
} else { return EmitFunction(func);
if (!EmitFunction(func)) { },
return false; [&](Default) {
}
}
} else {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "unhandled module-scope declaration: " << decl->TypeInfo().name; << "unhandled module-scope declaration: "
<< decl->TypeInfo().name;
return false;
});
if (!ok) {
return false; return false;
} }
} }
@ -929,22 +931,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
return Switch(
if (auto* func = target->As<sem::Function>()) { target,
[&](const sem::Function* func) {
return EmitFunctionCall(out, call, func); return EmitFunctionCall(out, call, func);
} },
if (auto* builtin = target->As<sem::Builtin>()) { [&](const sem::Builtin* builtin) {
return EmitBuiltinCall(out, call, builtin); return EmitBuiltinCall(out, call, builtin);
} },
if (auto* conv = target->As<sem::TypeConversion>()) { [&](const sem::TypeConversion* conv) {
return EmitTypeConversion(out, call, conv); return EmitTypeConversion(out, call, conv);
} },
if (auto* ctor = target->As<sem::TypeConstructor>()) { [&](const sem::TypeConstructor* ctor) {
return EmitTypeConstructor(out, call, ctor); return EmitTypeConstructor(out, call, ctor);
} },
[&](Default) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name; << "unhandled call target: " << target->TypeInfo().name;
return false; return false;
});
} }
bool GeneratorImpl::EmitFunctionCall(std::ostream& out, bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@ -2639,35 +2644,38 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
bool GeneratorImpl::EmitExpression(std::ostream& out, bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) { const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return EmitIndexAccessor(out, a); return EmitIndexAccessor(out, a);
} },
if (auto* b = expr->As<ast::BinaryExpression>()) { [&](const ast::BinaryExpression* b) { //
return EmitBinary(out, b); return EmitBinary(out, b);
} },
if (auto* b = expr->As<ast::BitcastExpression>()) { [&](const ast::BitcastExpression* b) { //
return EmitBitcast(out, b); return EmitBitcast(out, b);
} },
if (auto* c = expr->As<ast::CallExpression>()) { [&](const ast::CallExpression* c) { //
return EmitCall(out, c); return EmitCall(out, c);
} },
if (auto* i = expr->As<ast::IdentifierExpression>()) { [&](const ast::IdentifierExpression* i) { //
return EmitIdentifier(out, i); return EmitIdentifier(out, i);
} },
if (auto* l = expr->As<ast::LiteralExpression>()) { [&](const ast::LiteralExpression* l) { //
return EmitLiteral(out, l); return EmitLiteral(out, l);
} },
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { [&](const ast::MemberAccessorExpression* m) { //
return EmitMemberAccessor(out, m); return EmitMemberAccessor(out, m);
} },
if (auto* u = expr->As<ast::UnaryOpExpression>()) { [&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u); return EmitUnaryOp(out, u);
} },
[&](Default) { //
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown expression type: " + std::string(expr->TypeInfo().name)); "unknown expression type: " + std::string(expr->TypeInfo().name));
return false; return false;
});
} }
bool GeneratorImpl::EmitIdentifier(std::ostream& out, bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@ -3127,41 +3135,61 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
bool GeneratorImpl::EmitLiteral(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) { const ast::LiteralExpression* lit) {
if (auto* l = lit->As<ast::BoolLiteralExpression>()) { return Switch(
lit,
[&](const ast::BoolLiteralExpression* l) {
out << (l->value ? "true" : "false"); out << (l->value ? "true" : "false");
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { return true;
},
[&](const ast::FloatLiteralExpression* fl) {
if (std::isinf(fl->value)) { if (std::isinf(fl->value)) {
out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)"); out << (fl->value >= 0 ? "asfloat(0x7f800000u)"
: "asfloat(0xff800000u)");
} else if (std::isnan(fl->value)) { } else if (std::isnan(fl->value)) {
out << "asfloat(0x7fc00000u)"; out << "asfloat(0x7fc00000u)";
} else { } else {
out << FloatToString(fl->value) << "f"; out << FloatToString(fl->value) << "f";
} }
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) { return true;
},
[&](const ast::SintLiteralExpression* sl) {
out << sl->value; out << sl->value;
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) { return true;
},
[&](const ast::UintLiteralExpression* ul) {
out << ul->value << "u"; out << ul->value << "u";
} else { return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown literal type"); diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false; return false;
} });
return true;
} }
bool GeneratorImpl::EmitValue(std::ostream& out, bool GeneratorImpl::EmitValue(std::ostream& out,
const sem::Type* type, const sem::Type* type,
int value) { int value) {
if (type->Is<sem::Bool>()) { return Switch(
type,
[&](const sem::Bool*) {
out << (value == 0 ? "false" : "true"); out << (value == 0 ? "false" : "true");
} else if (type->Is<sem::F32>()) { return true;
},
[&](const sem::F32*) {
out << value << ".0f"; out << value << ".0f";
} else if (type->Is<sem::I32>()) { return true;
},
[&](const sem::I32*) {
out << value; out << value;
} else if (type->Is<sem::U32>()) { return true;
},
[&](const sem::U32*) {
out << value << "u"; out << value << "u";
} else if (auto* vec = type->As<sem::Vector>()) { return true;
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, },
"")) { [&](const sem::Vector* vec) {
if (!EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false; return false;
} }
ScopedParen sp(out); ScopedParen sp(out);
@ -3173,9 +3201,11 @@ bool GeneratorImpl::EmitValue(std::ostream& out,
return false; return false;
} }
} }
} else if (auto* mat = type->As<sem::Matrix>()) { return true;
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite, },
"")) { [&](const sem::Matrix* mat) {
if (!EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false; return false;
} }
ScopedParen sp(out); ScopedParen sp(out);
@ -3187,20 +3217,26 @@ bool GeneratorImpl::EmitValue(std::ostream& out,
return false; return false;
} }
} }
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) { return true;
},
[&](const sem::Struct*) {
out << "("; out << "(";
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, TINT_DEFER(out << ")" << value);
"")) { return EmitType(out, type, ast::StorageClass::kNone,
return false; ast::Access::kUndefined, "");
} },
out << ")" << value; [&](const sem::Array*) {
} else { out << "(";
TINT_DEFER(out << ")" << value);
return EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kUndefined, "");
},
[&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"Invalid type for value emission: " + type->type_name()); "Invalid type for value emission: " + type->type_name());
return false; return false;
} });
return true;
} }
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
@ -3375,56 +3411,59 @@ bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
} }
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
stmt,
[&](const ast::AssignmentStatement* a) { //
return EmitAssign(a); return EmitAssign(a);
} },
if (auto* b = stmt->As<ast::BlockStatement>()) { [&](const ast::BlockStatement* b) { //
return EmitBlock(b); return EmitBlock(b);
} },
if (auto* b = stmt->As<ast::BreakStatement>()) { [&](const ast::BreakStatement* b) { //
return EmitBreak(b); return EmitBreak(b);
} },
if (auto* c = stmt->As<ast::CallStatement>()) { [&](const ast::CallStatement* c) { //
auto out = line(); auto out = line();
if (!EmitCall(out, c->expr)) { if (!EmitCall(out, c->expr)) {
return false; return false;
} }
out << ";"; out << ";";
return true; return true;
} },
if (auto* c = stmt->As<ast::ContinueStatement>()) { [&](const ast::ContinueStatement* c) { //
return EmitContinue(c); return EmitContinue(c);
} },
if (auto* d = stmt->As<ast::DiscardStatement>()) { [&](const ast::DiscardStatement* d) { //
return EmitDiscard(d); return EmitDiscard(d);
} },
if (stmt->As<ast::FallthroughStatement>()) { [&](const ast::FallthroughStatement*) { //
line() << "/* fallthrough */"; line() << "/* fallthrough */";
return true; return true;
} },
if (auto* i = stmt->As<ast::IfStatement>()) { [&](const ast::IfStatement* i) { //
return EmitIf(i); return EmitIf(i);
} },
if (auto* l = stmt->As<ast::LoopStatement>()) { [&](const ast::LoopStatement* l) { //
return EmitLoop(l); return EmitLoop(l);
} },
if (auto* l = stmt->As<ast::ForLoopStatement>()) { [&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l); return EmitForLoop(l);
} },
if (auto* r = stmt->As<ast::ReturnStatement>()) { [&](const ast::ReturnStatement* r) { //
return EmitReturn(r); return EmitReturn(r);
} },
if (auto* s = stmt->As<ast::SwitchStatement>()) { [&](const ast::SwitchStatement* s) { //
return EmitSwitch(s); return EmitSwitch(s);
} },
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { [&](const ast::VariableDeclStatement* v) { //
return EmitVariable(v->variable); return EmitVariable(v->variable);
} },
[&](Default) { //
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name)); "unknown statement type: " + std::string(stmt->TypeInfo().name));
return false; return false;
});
} }
bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) { bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
@ -3516,13 +3555,16 @@ bool GeneratorImpl::EmitType(std::ostream& out,
break; break;
} }
if (auto* ary = type->As<sem::Array>()) { return Switch(
type,
[&](const sem::Array* ary) {
const sem::Type* base_type = ary; const sem::Type* base_type = ary;
std::vector<uint32_t> sizes; std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) { while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) { if (arr->IsRuntimeSized()) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which should " << "Runtime arrays may only exist in storage buffers, which "
"should "
"have been transformed into a ByteAddressBuffer"; "have been transformed into a ByteAddressBuffer";
return false; return false;
} }
@ -3541,38 +3583,53 @@ bool GeneratorImpl::EmitType(std::ostream& out,
for (uint32_t size : sizes) { for (uint32_t size : sizes) {
out << "[" << size << "]"; out << "[" << size << "]";
} }
} else if (type->Is<sem::Bool>()) { return true;
},
[&](const sem::Bool*) {
out << "bool"; out << "bool";
} else if (type->Is<sem::F32>()) { return true;
},
[&](const sem::F32*) {
out << "float"; out << "float";
} else if (type->Is<sem::I32>()) { return true;
},
[&](const sem::I32*) {
out << "int"; out << "int";
} else if (auto* mat = type->As<sem::Matrix>()) { return true;
},
[&](const sem::Matrix* mat) {
if (!EmitType(out, mat->type(), storage_class, access, "")) { if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false; return false;
} }
// Note: HLSL's matrices are declared as <type>NxM, where N is the number of // Note: HLSL's matrices are declared as <type>NxM, where N is the
// rows and M is the number of columns. Despite HLSL's matrices being // number of rows and M is the number of columns. Despite HLSL's
// column-major by default, the index operator and constructors actually // matrices being column-major by default, the index operator and
// operate on row-vectors, where as WGSL operates on column vectors. // constructors actually operate on row-vectors, where as WGSL operates
// To simplify everything we use the transpose of the matrices. // on column vectors. To simplify everything we use the transpose of the
// See: // matrices. See:
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
out << mat->columns() << "x" << mat->rows(); out << mat->columns() << "x" << mat->rows();
} else if (type->Is<sem::Pointer>()) { return true;
},
[&](const sem::Pointer*) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "Attempting to emit pointer type. These should have been removed " << "Attempting to emit pointer type. These should have been "
"with the InlinePointerLets transform"; "removed with the InlinePointerLets transform";
return false; return false;
} else if (auto* sampler = type->As<sem::Sampler>()) { },
[&](const sem::Sampler* sampler) {
out << "Sampler"; out << "Sampler";
if (sampler->IsComparison()) { if (sampler->IsComparison()) {
out << "Comparison"; out << "Comparison";
} }
out << "State"; out << "State";
} else if (auto* str = type->As<sem::Struct>()) { return true;
},
[&](const sem::Struct* str) {
out << StructName(str); out << StructName(str);
} else if (auto* tex = type->As<sem::Texture>()) { return true;
},
[&](const sem::Texture* tex) {
auto* storage = tex->As<sem::StorageTexture>(); auto* storage = tex->As<sem::StorageTexture>();
auto* ms = tex->As<sem::MultisampledTexture>(); auto* ms = tex->As<sem::MultisampledTexture>();
auto* depth_ms = tex->As<sem::DepthMultisampledTexture>(); auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
@ -3609,7 +3666,8 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} }
if (storage) { if (storage) {
auto* component = image_format_to_rwtexture_type(storage->texel_format()); auto* component =
image_format_to_rwtexture_type(storage->texel_format());
if (component == nullptr) { if (component == nullptr) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "Unsupported StorageTexture TexelFormat: " << "Unsupported StorageTexture TexelFormat: "
@ -3635,9 +3693,13 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} }
out << ">"; out << ">";
} }
} else if (type->Is<sem::U32>()) { return true;
},
[&](const sem::U32*) {
out << "uint"; out << "uint";
} else if (auto* vec = type->As<sem::Vector>()) { return true;
},
[&](const sem::Vector* vec) {
auto width = vec->Width(); auto width = vec->Width();
if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) { if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
out << "float" << width; out << "float" << width;
@ -3654,18 +3716,20 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} }
out << ", " << width << ">"; out << ", " << width << ">";
} }
} else if (auto* atomic = type->As<sem::Atomic>()) {
if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
return false;
}
} else if (type->Is<sem::Void>()) {
out << "void";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
return false;
}
return true; return true;
},
[&](const sem::Atomic* atomic) {
return EmitType(out, atomic->Type(), storage_class, access, name);
},
[&](const sem::Void*) {
out << "void";
return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer,
"unknown type in EmitType");
return false;
});
} }
bool GeneratorImpl::EmitTypeAndName(std::ostream& out, bool GeneratorImpl::EmitTypeAndName(std::ostream& out,

View File

@ -538,23 +538,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) { const ast::CallExpression* expr) {
auto* call = program_->Sem().Get(expr); auto* call = program_->Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
return Switch(
if (auto* func = target->As<sem::Function>()) { target,
[&](const sem::Function* func) {
return EmitFunctionCall(out, call, func); return EmitFunctionCall(out, call, func);
} },
if (auto* builtin = target->As<sem::Builtin>()) { [&](const sem::Builtin* builtin) {
return EmitBuiltinCall(out, call, builtin); return EmitBuiltinCall(out, call, builtin);
} },
if (auto* conv = target->As<sem::TypeConversion>()) { [&](const sem::TypeConversion* conv) {
return EmitTypeConversion(out, call, conv); return EmitTypeConversion(out, call, conv);
} },
if (auto* ctor = target->As<sem::TypeConstructor>()) { [&](const sem::TypeConstructor* ctor) {
return EmitTypeConstructor(out, call, ctor); return EmitTypeConstructor(out, call, ctor);
} },
[&](Default) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name; << "unhandled call target: " << target->TypeInfo().name;
return false; return false;
});
} }
bool GeneratorImpl::EmitFunctionCall(std::ostream& out, bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@ -1476,106 +1478,128 @@ bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
} }
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
if (type->Is<sem::Bool>()) { return Switch(
type,
[&](const sem::Bool*) {
out << "false"; out << "false";
} else if (type->Is<sem::F32>()) { return true;
},
[&](const sem::F32*) {
out << "0.0f"; out << "0.0f";
} else if (type->Is<sem::I32>()) { return true;
},
[&](const sem::I32*) {
out << "0"; out << "0";
} else if (type->Is<sem::U32>()) { return true;
},
[&](const sem::U32*) {
out << "0u"; out << "0u";
} else if (auto* vec = type->As<sem::Vector>()) { return true;
},
[&](const sem::Vector* vec) { //
return EmitZeroValue(out, vec->type()); return EmitZeroValue(out, vec->type());
} else if (auto* mat = type->As<sem::Matrix>()) { },
[&](const sem::Matrix* mat) {
if (!EmitType(out, mat, "")) { if (!EmitType(out, mat, "")) {
return false; return false;
} }
out << "("; out << "(";
if (!EmitZeroValue(out, mat->type())) { TINT_DEFER(out << ")");
return false; return EmitZeroValue(out, mat->type());
} },
out << ")"; [&](const sem::Array* arr) {
} else if (auto* arr = type->As<sem::Array>()) {
out << "{"; out << "{";
if (!EmitZeroValue(out, arr->ElemType())) { TINT_DEFER(out << "}");
return false; return EmitZeroValue(out, arr->ElemType());
} },
out << "}"; [&](const sem::Struct*) {
} else if (type->As<sem::Struct>()) {
out << "{}"; out << "{}";
} else { return true;
},
[&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"Invalid type for zero emission: " + type->type_name()); "Invalid type for zero emission: " + type->type_name());
return false; return false;
} });
return true;
} }
bool GeneratorImpl::EmitLiteral(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) { const ast::LiteralExpression* lit) {
if (auto* l = lit->As<ast::BoolLiteralExpression>()) { return Switch(
lit,
[&](const ast::BoolLiteralExpression* l) {
out << (l->value ? "true" : "false"); out << (l->value ? "true" : "false");
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { return true;
if (std::isinf(fl->value)) { },
out << (fl->value >= 0 ? "INFINITY" : "-INFINITY"); [&](const ast::FloatLiteralExpression* l) {
} else if (std::isnan(fl->value)) { if (std::isinf(l->value)) {
out << (l->value >= 0 ? "INFINITY" : "-INFINITY");
} else if (std::isnan(l->value)) {
out << "NAN"; out << "NAN";
} else { } else {
out << FloatToString(fl->value) << "f"; out << FloatToString(l->value) << "f";
}
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary
// minus and `2147483648` as separate tokens, and the latter doesn't
// fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
// issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
// ensures the expression type is `int`.
const auto int_min = std::numeric_limits<int32_t>::min();
if (sl->ValueAsI32() == int_min) {
out << "(" << int_min + 1 << " - 1)";
} else {
out << sl->value;
}
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
out << ul->value << "u";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
} }
return true; return true;
},
[&](const ast::SintLiteralExpression* l) {
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary
// minus and `2147483648` as separate tokens, and the latter doesn't
// fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To
// avoid issues with `long` to `int` casts, emit `(2147483647 - 1)`
// instead, which ensures the expression type is `int`.
const auto int_min = std::numeric_limits<int32_t>::min();
if (l->ValueAsI32() == int_min) {
out << "(" << int_min + 1 << " - 1)";
} else {
out << l->value;
}
return true;
},
[&](const ast::UintLiteralExpression* l) {
out << l->value << "u";
return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
});
} }
bool GeneratorImpl::EmitExpression(std::ostream& out, bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) { const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return EmitIndexAccessor(out, a); return EmitIndexAccessor(out, a);
} },
if (auto* b = expr->As<ast::BinaryExpression>()) { [&](const ast::BinaryExpression* b) { //
return EmitBinary(out, b); return EmitBinary(out, b);
} },
if (auto* b = expr->As<ast::BitcastExpression>()) { [&](const ast::BitcastExpression* b) { //
return EmitBitcast(out, b); return EmitBitcast(out, b);
} },
if (auto* c = expr->As<ast::CallExpression>()) { [&](const ast::CallExpression* c) { //
return EmitCall(out, c); return EmitCall(out, c);
} },
if (auto* i = expr->As<ast::IdentifierExpression>()) { [&](const ast::IdentifierExpression* i) { //
return EmitIdentifier(out, i); return EmitIdentifier(out, i);
} },
if (auto* l = expr->As<ast::LiteralExpression>()) { [&](const ast::LiteralExpression* l) { //
return EmitLiteral(out, l); return EmitLiteral(out, l);
} },
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { [&](const ast::MemberAccessorExpression* m) { //
return EmitMemberAccessor(out, m); return EmitMemberAccessor(out, m);
} },
if (auto* u = expr->As<ast::UnaryOpExpression>()) { [&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u); return EmitUnaryOp(out, u);
} },
[&](Default) { //
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown expression type: " + std::string(expr->TypeInfo().name)); "unknown expression type: " + std::string(expr->TypeInfo().name));
return false; return false;
});
} }
void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) { void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) {
@ -2106,57 +2130,60 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
} }
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
stmt,
[&](const ast::AssignmentStatement* a) { //
return EmitAssign(a); return EmitAssign(a);
} },
if (auto* b = stmt->As<ast::BlockStatement>()) { [&](const ast::BlockStatement* b) { //
return EmitBlock(b); return EmitBlock(b);
} },
if (auto* b = stmt->As<ast::BreakStatement>()) { [&](const ast::BreakStatement* b) { //
return EmitBreak(b); return EmitBreak(b);
} },
if (auto* c = stmt->As<ast::CallStatement>()) { [&](const ast::CallStatement* c) { //
auto out = line(); auto out = line();
if (!EmitCall(out, c->expr)) { if (!EmitCall(out, c->expr)) { //
return false; return false;
} }
out << ";"; out << ";";
return true; return true;
} },
if (auto* c = stmt->As<ast::ContinueStatement>()) { [&](const ast::ContinueStatement* c) { //
return EmitContinue(c); return EmitContinue(c);
} },
if (auto* d = stmt->As<ast::DiscardStatement>()) { [&](const ast::DiscardStatement* d) { //
return EmitDiscard(d); return EmitDiscard(d);
} },
if (stmt->As<ast::FallthroughStatement>()) { [&](const ast::FallthroughStatement*) { //
line() << "/* fallthrough */"; line() << "/* fallthrough */";
return true; return true;
} },
if (auto* i = stmt->As<ast::IfStatement>()) { [&](const ast::IfStatement* i) { //
return EmitIf(i); return EmitIf(i);
} },
if (auto* l = stmt->As<ast::LoopStatement>()) { [&](const ast::LoopStatement* l) { //
return EmitLoop(l); return EmitLoop(l);
} },
if (auto* l = stmt->As<ast::ForLoopStatement>()) { [&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l); return EmitForLoop(l);
} },
if (auto* r = stmt->As<ast::ReturnStatement>()) { [&](const ast::ReturnStatement* r) { //
return EmitReturn(r); return EmitReturn(r);
} },
if (auto* s = stmt->As<ast::SwitchStatement>()) { [&](const ast::SwitchStatement* s) { //
return EmitSwitch(s); return EmitSwitch(s);
} },
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { [&](const ast::VariableDeclStatement* v) { //
auto* var = program_->Sem().Get(v->variable); auto* var = program_->Sem().Get(v->variable);
return EmitVariable(var); return EmitVariable(var);
} },
[&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name)); "unknown statement type: " + std::string(stmt->TypeInfo().name));
return false; return false;
});
} }
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) { bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
@ -2204,7 +2231,10 @@ bool GeneratorImpl::EmitType(std::ostream& out,
if (name_printed) { if (name_printed) {
*name_printed = false; *name_printed = false;
} }
if (auto* atomic = type->As<sem::Atomic>()) {
return Switch(
type,
[&](const sem::Atomic* atomic) {
if (atomic->Type()->Is<sem::I32>()) { if (atomic->Type()->Is<sem::I32>()) {
out << "atomic_int"; out << "atomic_int";
return true; return true;
@ -2216,9 +2246,8 @@ bool GeneratorImpl::EmitType(std::ostream& out,
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "unhandled atomic type " << atomic->Type()->type_name(); << "unhandled atomic type " << atomic->Type()->type_name();
return false; return false;
} },
[&](const sem::Array* ary) {
if (auto* ary = type->As<sem::Array>()) {
const sem::Type* base_type = ary; const sem::Type* base_type = ary;
std::vector<uint32_t> sizes; std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) { while (auto* arr = base_type->As<sem::Array>()) {
@ -2242,32 +2271,27 @@ bool GeneratorImpl::EmitType(std::ostream& out,
out << "[" << size << "]"; out << "[" << size << "]";
} }
return true; return true;
} },
[&](const sem::Bool*) {
if (type->Is<sem::Bool>()) {
out << "bool"; out << "bool";
return true; return true;
} },
[&](const sem::F32*) {
if (type->Is<sem::F32>()) {
out << "float"; out << "float";
return true; return true;
} },
[&](const sem::I32*) {
if (type->Is<sem::I32>()) {
out << "int"; out << "int";
return true; return true;
} },
[&](const sem::Matrix* mat) {
if (auto* mat = type->As<sem::Matrix>()) {
if (!EmitType(out, mat->type(), "")) { if (!EmitType(out, mat->type(), "")) {
return false; return false;
} }
out << mat->columns() << "x" << mat->rows(); out << mat->columns() << "x" << mat->rows();
return true; return true;
} },
[&](const sem::Pointer* ptr) {
if (auto* ptr = type->As<sem::Pointer>()) {
if (ptr->Access() == ast::Access::kRead) { if (ptr->Access() == ast::Access::kRead) {
out << "const "; out << "const ";
} }
@ -2293,21 +2317,18 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} }
} }
return true; return true;
} },
[&](const sem::Sampler*) {
if (type->Is<sem::Sampler>()) {
out << "sampler"; out << "sampler";
return true; return true;
} },
[&](const sem::Struct* str) {
if (auto* str = type->As<sem::Struct>()) { // The struct type emits as just the name. The declaration would be
// The struct type emits as just the name. The declaration would be emitted // emitted as part of emitting the declared types.
// as part of emitting the declared types.
out << StructName(str); out << StructName(str);
return true; return true;
} },
[&](const sem::Texture* tex) {
if (auto* tex = type->As<sem::Texture>()) {
if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) { if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
out << "depth"; out << "depth";
} else { } else {
@ -2343,11 +2364,19 @@ bool GeneratorImpl::EmitType(std::ostream& out,
out << "_ms"; out << "_ms";
} }
out << "<"; out << "<";
if (tex->Is<sem::DepthTexture>()) { TINT_DEFER(out << ">");
return Switch(
tex,
[&](const sem::DepthTexture*) {
out << "float, access::sample"; out << "float, access::sample";
} else if (tex->Is<sem::DepthMultisampledTexture>()) { return true;
},
[&](const sem::DepthMultisampledTexture*) {
out << "float, access::read"; out << "float, access::read";
} else if (auto* storage = tex->As<sem::StorageTexture>()) { return true;
},
[&](const sem::StorageTexture* storage) {
if (!EmitType(out, storage->type(), "")) { if (!EmitType(out, storage->type(), "")) {
return false; return false;
} }
@ -2358,49 +2387,54 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} else if (storage->access() == ast::Access::kWrite) { } else if (storage->access() == ast::Access::kWrite) {
out << ", access::write"; out << ", access::write";
} else { } else {
diagnostics_.add_error(diag::System::Writer, diagnostics_.add_error(
diag::System::Writer,
"Invalid access control for storage texture"); "Invalid access control for storage texture");
return false; return false;
} }
} else if (auto* ms = tex->As<sem::MultisampledTexture>()) { return true;
},
[&](const sem::MultisampledTexture* ms) {
if (!EmitType(out, ms->type(), "")) { if (!EmitType(out, ms->type(), "")) {
return false; return false;
} }
out << ", access::read"; out << ", access::read";
} else if (auto* sampled = tex->As<sem::SampledTexture>()) { return true;
},
[&](const sem::SampledTexture* sampled) {
if (!EmitType(out, sampled->type(), "")) { if (!EmitType(out, sampled->type(), "")) {
return false; return false;
} }
out << ", access::sample"; out << ", access::sample";
} else {
diagnostics_.add_error(diag::System::Writer, "invalid texture type");
return false;
}
out << ">";
return true; return true;
} },
[&](Default) {
if (type->Is<sem::U32>()) { diagnostics_.add_error(diag::System::Writer,
"invalid texture type");
return false;
});
},
[&](const sem::U32*) {
out << "uint"; out << "uint";
return true; return true;
} },
[&](const sem::Vector* vec) {
if (auto* vec = type->As<sem::Vector>()) {
if (!EmitType(out, vec->type(), "")) { if (!EmitType(out, vec->type(), "")) {
return false; return false;
} }
out << vec->Width(); out << vec->Width();
return true; return true;
} },
[&](const sem::Void*) {
if (type->Is<sem::Void>()) {
out << "void"; out << "void";
return true; return true;
} },
[&](Default) {
diagnostics_.add_error(diag::System::Writer, diagnostics_.add_error(
diag::System::Writer,
"unknown type in EmitType: " + type->type_name()); "unknown type in EmitType: " + type->type_name());
return false; return false;
});
} }
bool GeneratorImpl::EmitTypeAndName(std::ostream& out, bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
@ -2542,18 +2576,23 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
// Emit attributes // Emit attributes
if (auto* decl = mem->Declaration()) { if (auto* decl = mem->Declaration()) {
for (auto* attr : decl->attributes) { for (auto* attr : decl->attributes) {
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { bool ok = Switch(
attr,
[&](const ast::BuiltinAttribute* builtin) {
auto name = builtin_to_attribute(builtin->builtin); auto name = builtin_to_attribute(builtin->builtin);
if (name.empty()) { if (name.empty()) {
diagnostics_.add_error(diag::System::Writer, "unknown builtin"); diagnostics_.add_error(diag::System::Writer, "unknown builtin");
return false; return false;
} }
out << " [[" << name << "]]"; out << " [[" << name << "]]";
} else if (auto* loc = attr->As<ast::LocationAttribute>()) { return true;
},
[&](const ast::LocationAttribute* loc) {
auto& pipeline_stage_uses = str->PipelineStageUses(); auto& pipeline_stage_uses = str->PipelineStageUses();
if (pipeline_stage_uses.size() != 1) { if (pipeline_stage_uses.size() != 1) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "invalid entry point IO struct uses"; << "invalid entry point IO struct uses";
return false;
} }
if (pipeline_stage_uses.count( if (pipeline_stage_uses.count(
@ -2570,9 +2609,12 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
out << " [[color(" + std::to_string(loc->value) + ")]]"; out << " [[color(" + std::to_string(loc->value) + ")]]";
} else { } else {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "invalid use of location attribute"; << "invalid use of location decoration";
return false;
} }
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
auto name = interpolation_to_attribute(interpolate->type, auto name = interpolation_to_attribute(interpolate->type,
interpolate->sampling); interpolate->sampling);
if (name.empty()) { if (name.empty()) {
@ -2581,16 +2623,25 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
return false; return false;
} }
out << " [[" << name << "]]"; out << " [[" << name << "]]";
} else if (attr->Is<ast::InvariantAttribute>()) { return true;
},
[&](const ast::InvariantAttribute*) {
if (invariant_define_name_.empty()) { if (invariant_define_name_.empty()) {
invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT"); invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
} }
out << " " << invariant_define_name_; out << " " << invariant_define_name_;
} else if (!attr->IsAnyOf<ast::StructMemberOffsetAttribute, return true;
ast::StructMemberAlignAttribute, },
ast::StructMemberSizeAttribute>()) { [&](const ast::StructMemberOffsetAttribute*) { return true; },
[&](const ast::StructMemberAlignAttribute*) { return true; },
[&](const ast::StructMemberSizeAttribute*) { return true; },
[&](Default) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "unhandled struct member attribute: " << attr->Name(); << "unhandled struct member attribute: " << attr->Name();
return false;
});
if (!ok) {
return false;
} }
} }
} }
@ -2796,13 +2847,22 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign( GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
const sem::Type* ty) { const sem::Type* ty) {
if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) { return Switch(
ty,
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.1 Scalar Data Types // 2.1 Scalar Data Types
return {4, 4}; [&](const sem::U32*) {
} return SizeAndAlign{4, 4};
},
[&](const sem::I32*) {
return SizeAndAlign{4, 4};
},
[&](const sem::F32*) {
return SizeAndAlign{4, 4};
},
if (auto* vec = ty->As<sem::Vector>()) { [&](const sem::Vector* vec) {
auto num_els = vec->Width(); auto num_els = vec->Width();
auto* el_ty = vec->type(); auto* el_ty = vec->type();
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) { if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
@ -2817,9 +2877,12 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
return SizeAndAlign{num_els * 4, num_els * 4}; return SizeAndAlign{num_els * 4, num_els * 4};
} }
} }
} TINT_UNREACHABLE(Writer, diagnostics_)
<< "Unhandled vector element type " << el_ty->TypeInfo().name;
return SizeAndAlign{};
},
if (auto* mat = ty->As<sem::Matrix>()) { [&](const sem::Matrix* mat) {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.3 Matrix Data Types // 2.3 Matrix Data Types
auto cols = mat->columns(); auto cols = mat->columns();
@ -2841,32 +2904,39 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
return table[(3 * (cols - 2)) + (rows - 2)]; return table[(3 * (cols - 2)) + (rows - 2)];
} }
} }
}
if (auto* arr = ty->As<sem::Array>()) { TINT_UNREACHABLE(Writer, diagnostics_)
<< "Unhandled matrix element type " << el_ty->TypeInfo().name;
return SizeAndAlign{};
},
[&](const sem::Array* arr) {
if (!arr->IsStrideImplicit()) { if (!arr->IsStrideImplicit()) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "arrays with explicit strides should have " << "arrays with explicit strides should have "
"removed with the PadArrayElements transform"; "removed with the PadArrayElements transform";
return {}; return SizeAndAlign{};
} }
auto num_els = std::max<uint32_t>(arr->Count(), 1); auto num_els = std::max<uint32_t>(arr->Count(), 1);
return SizeAndAlign{arr->Stride() * num_els, arr->Align()}; return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
} },
if (auto* str = ty->As<sem::Struct>()) { [&](const sem::Struct* str) {
// TODO(crbug.com/tint/650): There's an assumption here that MSL's default // TODO(crbug.com/tint/650): There's an assumption here that MSL's
// structure size and alignment matches WGSL's. We need to confirm this. // default structure size and alignment matches WGSL's. We need to
// confirm this.
return SizeAndAlign{str->Size(), str->Align()}; return SizeAndAlign{str->Size(), str->Align()};
} },
if (auto* atomic = ty->As<sem::Atomic>()) { [&](const sem::Atomic* atomic) {
return MslPackedTypeSizeAndAlign(atomic->Type()); return MslPackedTypeSizeAndAlign(atomic->Type());
} },
[&](Default) {
TINT_UNREACHABLE(Writer, diagnostics_) TINT_UNREACHABLE(Writer, diagnostics_)
<< "Unhandled type " << ty->TypeInfo().name; << "Unhandled type " << ty->TypeInfo().name;
return {}; return SizeAndAlign{};
});
} }
template <typename F> template <typename F>

View File

@ -560,33 +560,37 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
} }
uint32_t Builder::GenerateExpression(const ast::Expression* expr) { uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return GenerateAccessorExpression(a); return GenerateAccessorExpression(a);
} },
if (auto* b = expr->As<ast::BinaryExpression>()) { [&](const ast::BinaryExpression* b) { //
return GenerateBinaryExpression(b); return GenerateBinaryExpression(b);
} },
if (auto* b = expr->As<ast::BitcastExpression>()) { [&](const ast::BitcastExpression* b) { //
return GenerateBitcastExpression(b); return GenerateBitcastExpression(b);
} },
if (auto* c = expr->As<ast::CallExpression>()) { [&](const ast::CallExpression* c) { //
return GenerateCallExpression(c); return GenerateCallExpression(c);
} },
if (auto* i = expr->As<ast::IdentifierExpression>()) { [&](const ast::IdentifierExpression* i) { //
return GenerateIdentifierExpression(i); return GenerateIdentifierExpression(i);
} },
if (auto* l = expr->As<ast::LiteralExpression>()) { [&](const ast::LiteralExpression* l) { //
return GenerateLiteralIfNeeded(nullptr, l); return GenerateLiteralIfNeeded(nullptr, l);
} },
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { [&](const ast::MemberAccessorExpression* m) { //
return GenerateAccessorExpression(m); return GenerateAccessorExpression(m);
} },
if (auto* u = expr->As<ast::UnaryOpExpression>()) { [&](const ast::UnaryOpExpression* u) { //
return GenerateUnaryOpExpression(u); return GenerateUnaryOpExpression(u);
} },
[&](Default) -> uint32_t {
error_ = "unknown expression type: " + std::string(expr->TypeInfo().name); error_ =
"unknown expression type: " + std::string(expr->TypeInfo().name);
return 0; return 0;
});
} }
bool Builder::GenerateFunction(const ast::Function* func_ast) { bool Builder::GenerateFunction(const ast::Function* func_ast) {
@ -861,34 +865,57 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
push_type(spv::Op::OpVariable, std::move(ops)); push_type(spv::Op::OpVariable, std::move(ops));
for (auto* attr : var->attributes) { for (auto* attr : var->attributes) {
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { bool ok = Switch(
attr,
[&](const ast::BuiltinAttribute* builtin) {
push_annot(spv::Op::OpDecorate, push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn), {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
Operand::Int( Operand::Int(ConvertBuiltin(builtin->builtin,
ConvertBuiltin(builtin->builtin, sem->StorageClass()))}); sem->StorageClass()))});
} else if (auto* location = attr->As<ast::LocationAttribute>()) { return true;
},
[&](const ast::LocationAttribute* location) {
push_annot(spv::Op::OpDecorate, push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationLocation), {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
Operand::Int(location->value)}); Operand::Int(location->value)});
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
AddInterpolationDecorations(var_id, interpolate->type, AddInterpolationDecorations(var_id, interpolate->type,
interpolate->sampling); interpolate->sampling);
} else if (attr->Is<ast::InvariantAttribute>()) { return true;
push_annot(spv::Op::OpDecorate, },
[&](const ast::InvariantAttribute*) {
push_annot(
spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)}); {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
} else if (auto* binding = attr->As<ast::BindingAttribute>()) { return true;
},
[&](const ast::BindingAttribute* binding) {
push_annot(spv::Op::OpDecorate, push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBinding), {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
Operand::Int(binding->value)}); Operand::Int(binding->value)});
} else if (auto* group = attr->As<ast::GroupAttribute>()) { return true;
push_annot(spv::Op::OpDecorate, {Operand::Int(var_id), },
Operand::Int(SpvDecorationDescriptorSet), [&](const ast::GroupAttribute* group) {
push_annot(
spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
Operand::Int(group->value)}); Operand::Int(group->value)});
} else if (attr->Is<ast::OverrideAttribute>()) { return true;
// Spec constants are handled elsewhere },
} else if (!attr->Is<ast::InternalAttribute>()) { [&](const ast::OverrideAttribute*) {
return true; // Spec constants are handled elsewhere
},
[&](const ast::InternalAttribute*) {
return true; // ignored
},
[&](Default) {
error_ = "unknown attribute"; error_ = "unknown attribute";
return false; return false;
});
if (!ok) {
return false;
} }
} }
@ -1123,19 +1150,21 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
// promoted to storage with the VarForDynamicIndex transform. // promoted to storage with the VarForDynamicIndex transform.
for (auto* accessor : accessors) { for (auto* accessor : accessors) {
if (auto* array = accessor->As<ast::IndexAccessorExpression>()) { bool ok = Switch(
if (!GenerateIndexAccessor(array, &info)) { accessor,
return 0; [&](const ast::IndexAccessorExpression* array) {
} return GenerateIndexAccessor(array, &info);
} else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) { },
if (!GenerateMemberAccessor(member, &info)) { [&](const ast::MemberAccessorExpression* member) {
return 0; return GenerateMemberAccessor(member, &info);
} },
[&](Default) {
} else { error_ = "invalid accessor in list: " +
error_ = std::string(accessor->TypeInfo().name);
"invalid accessor in list: " + std::string(accessor->TypeInfo().name); return false;
return 0; });
if (!ok) {
return false;
} }
} }
@ -1653,21 +1682,28 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
constant.constant_id = global->ConstantId(); constant.constant_id = global->ConstantId();
} }
if (auto* l = lit->As<ast::BoolLiteralExpression>()) { Switch(
lit,
[&](const ast::BoolLiteralExpression* l) {
constant.kind = ScalarConstant::Kind::kBool; constant.kind = ScalarConstant::Kind::kBool;
constant.value.b = l->value; constant.value.b = l->value;
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) { },
[&](const ast::SintLiteralExpression* sl) {
constant.kind = ScalarConstant::Kind::kI32; constant.kind = ScalarConstant::Kind::kI32;
constant.value.i32 = sl->value; constant.value.i32 = sl->value;
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) { },
[&](const ast::UintLiteralExpression* ul) {
constant.kind = ScalarConstant::Kind::kU32; constant.kind = ScalarConstant::Kind::kU32;
constant.value.u32 = ul->value; constant.value.u32 = ul->value;
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { },
[&](const ast::FloatLiteralExpression* fl) {
constant.kind = ScalarConstant::Kind::kF32; constant.kind = ScalarConstant::Kind::kF32;
constant.value.f32 = fl->value; constant.value.f32 = fl->value;
} else { },
error_ = "unknown literal type"; [&](Default) { error_ = "unknown literal type"; });
return 0;
if (!error_.empty()) {
return false;
} }
return GenerateConstantIfNeeded(constant); return GenerateConstantIfNeeded(constant);
@ -2209,19 +2245,25 @@ bool Builder::GenerateBlockStatementWithoutScoping(
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) { uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr); auto* call = builder_.Sem().Get(expr);
auto* target = call->Target(); auto* target = call->Target();
return Switch(
if (auto* func = target->As<sem::Function>()) { target,
[&](const sem::Function* func) {
return GenerateFunctionCall(call, func); return GenerateFunctionCall(call, func);
} },
if (auto* builtin = target->As<sem::Builtin>()) { [&](const sem::Builtin* builtin) {
return GenerateBuiltinCall(call, builtin); return GenerateBuiltinCall(call, builtin);
} },
if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) { [&](const sem::TypeConversion*) {
return GenerateTypeConstructorOrConversion(call, nullptr); return GenerateTypeConstructorOrConversion(call, nullptr);
} },
[&](const sem::TypeConstructor*) {
return GenerateTypeConstructorOrConversion(call, nullptr);
},
[&](Default) -> uint32_t {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name; << "unhandled call target: " << target->TypeInfo().name;
return false; return 0;
});
} }
uint32_t Builder::GenerateFunctionCall(const sem::Call* call, uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
@ -3790,46 +3832,49 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
} }
bool Builder::GenerateStatement(const ast::Statement* stmt) { bool Builder::GenerateStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
stmt,
[&](const ast::AssignmentStatement* a) {
return GenerateAssignStatement(a); return GenerateAssignStatement(a);
} },
if (auto* b = stmt->As<ast::BlockStatement>()) { [&](const ast::BlockStatement* b) { //
return GenerateBlockStatement(b); return GenerateBlockStatement(b);
} },
if (auto* b = stmt->As<ast::BreakStatement>()) { [&](const ast::BreakStatement* b) { //
return GenerateBreakStatement(b); return GenerateBreakStatement(b);
} },
if (auto* c = stmt->As<ast::CallStatement>()) { [&](const ast::CallStatement* c) {
return GenerateCallExpression(c->expr) != 0; return GenerateCallExpression(c->expr) != 0;
} },
if (auto* c = stmt->As<ast::ContinueStatement>()) { [&](const ast::ContinueStatement* c) {
return GenerateContinueStatement(c); return GenerateContinueStatement(c);
} },
if (auto* d = stmt->As<ast::DiscardStatement>()) { [&](const ast::DiscardStatement* d) {
return GenerateDiscardStatement(d); return GenerateDiscardStatement(d);
} },
if (stmt->Is<ast::FallthroughStatement>()) { [&](const ast::FallthroughStatement*) {
// Do nothing here, the fallthrough gets handled by the switch code. // Do nothing here, the fallthrough gets handled by the switch code.
return true; return true;
} },
if (auto* i = stmt->As<ast::IfStatement>()) { [&](const ast::IfStatement* i) { //
return GenerateIfStatement(i); return GenerateIfStatement(i);
} },
if (auto* l = stmt->As<ast::LoopStatement>()) { [&](const ast::LoopStatement* l) { //
return GenerateLoopStatement(l); return GenerateLoopStatement(l);
} },
if (auto* r = stmt->As<ast::ReturnStatement>()) { [&](const ast::ReturnStatement* r) { //
return GenerateReturnStatement(r); return GenerateReturnStatement(r);
} },
if (auto* s = stmt->As<ast::SwitchStatement>()) { [&](const ast::SwitchStatement* s) { //
return GenerateSwitchStatement(s); return GenerateSwitchStatement(s);
} },
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { [&](const ast::VariableDeclStatement* v) {
return GenerateVariableDeclStatement(v); return GenerateVariableDeclStatement(v);
} },
[&](Default) {
error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name); error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
return false; return false;
});
} }
bool Builder::GenerateVariableDeclStatement( bool Builder::GenerateVariableDeclStatement(
@ -3872,78 +3917,91 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t { return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
auto result = result_op(); auto result = result_op();
auto id = result.to_i(); auto id = result.to_i();
if (auto* arr = type->As<sem::Array>()) { bool ok = Switch(
if (!GenerateArrayType(arr, result)) { type,
return 0; [&](const sem::Array* arr) { //
} return GenerateArrayType(arr, result);
} else if (type->Is<sem::Bool>()) { },
[&](const sem::Bool*) {
push_type(spv::Op::OpTypeBool, {result}); push_type(spv::Op::OpTypeBool, {result});
} else if (type->Is<sem::F32>()) { return true;
},
[&](const sem::F32*) {
push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)}); push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
} else if (type->Is<sem::I32>()) { return true;
},
[&](const sem::I32*) {
push_type(spv::Op::OpTypeInt, push_type(spv::Op::OpTypeInt,
{result, Operand::Int(32), Operand::Int(1)}); {result, Operand::Int(32), Operand::Int(1)});
} else if (auto* mat = type->As<sem::Matrix>()) { return true;
if (!GenerateMatrixType(mat, result)) { },
return 0; [&](const sem::Matrix* mat) { //
} return GenerateMatrixType(mat, result);
} else if (auto* ptr = type->As<sem::Pointer>()) { },
if (!GeneratePointerType(ptr, result)) { [&](const sem::Pointer* ptr) { //
return 0; return GeneratePointerType(ptr, result);
} },
} else if (auto* ref = type->As<sem::Reference>()) { [&](const sem::Reference* ref) { //
if (!GenerateReferenceType(ref, result)) { return GenerateReferenceType(ref, result);
return 0; },
} [&](const sem::Struct* str) { //
} else if (auto* str = type->As<sem::Struct>()) { return GenerateStructType(str, result);
if (!GenerateStructType(str, result)) { },
return 0; [&](const sem::U32*) {
}
} else if (type->Is<sem::U32>()) {
push_type(spv::Op::OpTypeInt, push_type(spv::Op::OpTypeInt,
{result, Operand::Int(32), Operand::Int(0)}); {result, Operand::Int(32), Operand::Int(0)});
} else if (auto* vec = type->As<sem::Vector>()) { return true;
if (!GenerateVectorType(vec, result)) { },
return 0; [&](const sem::Vector* vec) { //
} return GenerateVectorType(vec, result);
} else if (type->Is<sem::Void>()) { },
[&](const sem::Void*) {
push_type(spv::Op::OpTypeVoid, {result}); push_type(spv::Op::OpTypeVoid, {result});
} else if (auto* tex = type->As<sem::Texture>()) { return true;
},
[&](const sem::StorageTexture* tex) {
if (!GenerateTextureType(tex, result)) { if (!GenerateTextureType(tex, result)) {
return 0; return false;
} }
if (auto* st = tex->As<sem::StorageTexture>()) { // Register all three access types of StorageTexture names. In
// Register all three access types of StorageTexture names. In SPIR-V, // SPIR-V, we must output a single type, while the variable is
// we must output a single type, while the variable is annotated with // annotated with the access type. Doing this ensures we de-dupe.
// the access type. Doing this ensures we de-dupe.
type_name_to_id_[builder_ type_name_to_id_[builder_
.create<sem::StorageTexture>( .create<sem::StorageTexture>(
st->dim(), st->texel_format(), tex->dim(), tex->texel_format(),
ast::Access::kRead, st->type()) ast::Access::kRead, tex->type())
->type_name()] = id; ->type_name()] = id;
type_name_to_id_[builder_ type_name_to_id_[builder_
.create<sem::StorageTexture>( .create<sem::StorageTexture>(
st->dim(), st->texel_format(), tex->dim(), tex->texel_format(),
ast::Access::kWrite, st->type()) ast::Access::kWrite, tex->type())
->type_name()] = id; ->type_name()] = id;
type_name_to_id_[builder_ type_name_to_id_[builder_
.create<sem::StorageTexture>( .create<sem::StorageTexture>(
st->dim(), st->texel_format(), tex->dim(), tex->texel_format(),
ast::Access::kReadWrite, st->type()) ast::Access::kReadWrite, tex->type())
->type_name()] = id; ->type_name()] = id;
} return true;
},
} else if (type->Is<sem::Sampler>()) { [&](const sem::Texture* tex) {
return GenerateTextureType(tex, result);
},
[&](const sem::Sampler*) {
push_type(spv::Op::OpTypeSampler, {result}); push_type(spv::Op::OpTypeSampler, {result});
// Register both of the sampler type names. In SPIR-V they're the same // Register both of the sampler type names. In SPIR-V they're the same
// sampler type, so we need to match that when we do the dedup check. // sampler type, so we need to match that when we do the dedup check.
type_name_to_id_["__sampler_sampler"] = id; type_name_to_id_["__sampler_sampler"] = id;
type_name_to_id_["__sampler_comparison"] = id; type_name_to_id_["__sampler_comparison"] = id;
return true;
} else { },
[&](Default) {
error_ = "unable to convert type: " + type->type_name(); error_ = "unable to convert type: " + type->type_name();
return false;
});
if (!ok) {
return 0; return 0;
} }
@ -3995,22 +4053,31 @@ bool Builder::GenerateTextureType(const sem::Texture* texture,
} }
if (dim == ast::TextureDimension::kCubeArray) { if (dim == ast::TextureDimension::kCubeArray) {
if (texture->Is<sem::SampledTexture>() || if (texture->IsAnyOf<sem::SampledTexture, sem::DepthTexture>()) {
texture->Is<sem::DepthTexture>()) {
push_capability(SpvCapabilitySampledCubeArray); push_capability(SpvCapabilitySampledCubeArray);
} }
} }
uint32_t type_id = 0u; uint32_t type_id = Switch(
if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) { texture,
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>()); [&](const sem::DepthTexture*) {
} else if (auto* s = texture->As<sem::SampledTexture>()) { return GenerateTypeIfNeeded(builder_.create<sem::F32>());
type_id = GenerateTypeIfNeeded(s->type()); },
} else if (auto* ms = texture->As<sem::MultisampledTexture>()) { [&](const sem::DepthMultisampledTexture*) {
type_id = GenerateTypeIfNeeded(ms->type()); return GenerateTypeIfNeeded(builder_.create<sem::F32>());
} else if (auto* st = texture->As<sem::StorageTexture>()) { },
type_id = GenerateTypeIfNeeded(st->type()); [&](const sem::SampledTexture* t) {
} return GenerateTypeIfNeeded(t->type());
},
[&](const sem::MultisampledTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](const sem::StorageTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](Default) -> uint32_t { //
return 0u;
});
if (type_id == 0u) { if (type_id == 0u) {
return false; return false;
} }

View File

@ -68,23 +68,17 @@ GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() { bool GeneratorImpl::Generate() {
// Generate global declarations in the order they appear in the module. // Generate global declarations in the order they appear in the module.
for (auto* decl : program_->AST().GlobalDeclarations()) { for (auto* decl : program_->AST().GlobalDeclarations()) {
if (auto* td = decl->As<ast::TypeDecl>()) { if (!Switch(
if (!EmitTypeDecl(td)) { decl, //
return false; [&](const ast::TypeDecl* td) { return EmitTypeDecl(td); },
} [&](const ast::Function* func) { return EmitFunction(func); },
} else if (auto* func = decl->As<ast::Function>()) { [&](const ast::Variable* var) { return EmitVariable(line(), var); },
if (!EmitFunction(func)) { [&](Default) {
return false;
}
} else if (auto* var = decl->As<ast::Variable>()) {
if (!EmitVariable(line(), var)) {
return false;
}
} else {
TINT_UNREACHABLE(Writer, diagnostics_); TINT_UNREACHABLE(Writer, diagnostics_);
return false; return false;
})) {
return false;
} }
if (decl != program_->AST().GlobalDeclarations().back()) { if (decl != program_->AST().GlobalDeclarations().back()) {
line(); line();
} }
@ -94,59 +88,64 @@ bool GeneratorImpl::Generate() {
} }
bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) { bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
if (auto* alias = ty->As<ast::Alias>()) { return Switch(
ty,
[&](const ast::Alias* alias) { //
auto out = line(); auto out = line();
out << "type " << program_->Symbols().NameFor(alias->name) << " = "; out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
if (!EmitType(out, alias->type)) { if (!EmitType(out, alias->type)) {
return false; return false;
} }
out << ";"; out << ";";
} else if (auto* str = ty->As<ast::Struct>()) { return true;
if (!EmitStructType(str)) { },
return false; [&](const ast::Struct* str) { //
} return EmitStructType(str);
} else { },
[&](Default) { //
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown declared type: " + std::string(ty->TypeInfo().name)); "unknown declared type: " + std::string(ty->TypeInfo().name));
return false; return false;
} });
return true;
} }
bool GeneratorImpl::EmitExpression(std::ostream& out, bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) { const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return EmitIndexAccessor(out, a); return EmitIndexAccessor(out, a);
} },
if (auto* b = expr->As<ast::BinaryExpression>()) { [&](const ast::BinaryExpression* b) { //
return EmitBinary(out, b); return EmitBinary(out, b);
} },
if (auto* b = expr->As<ast::BitcastExpression>()) { [&](const ast::BitcastExpression* b) { //
return EmitBitcast(out, b); return EmitBitcast(out, b);
} },
if (auto* c = expr->As<ast::CallExpression>()) { [&](const ast::CallExpression* c) { //
return EmitCall(out, c); return EmitCall(out, c);
} },
if (auto* i = expr->As<ast::IdentifierExpression>()) { [&](const ast::IdentifierExpression* i) { //
return EmitIdentifier(out, i); return EmitIdentifier(out, i);
} },
if (auto* l = expr->As<ast::LiteralExpression>()) { [&](const ast::LiteralExpression* l) { //
return EmitLiteral(out, l); return EmitLiteral(out, l);
} },
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { [&](const ast::MemberAccessorExpression* m) { //
return EmitMemberAccessor(out, m); return EmitMemberAccessor(out, m);
} },
if (expr->Is<ast::PhonyExpression>()) { [&](const ast::PhonyExpression*) { //
out << "_"; out << "_";
return true; return true;
} },
if (auto* u = expr->As<ast::UnaryOpExpression>()) { [&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u); return EmitUnaryOp(out, u);
} },
[&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown expression type"); diagnostics_.add_error(diag::System::Writer, "unknown expression type");
return false; return false;
});
} }
bool GeneratorImpl::EmitIndexAccessor( bool GeneratorImpl::EmitIndexAccessor(
@ -250,19 +249,28 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
bool GeneratorImpl::EmitLiteral(std::ostream& out, bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) { const ast::LiteralExpression* lit) {
if (auto* bl = lit->As<ast::BoolLiteralExpression>()) { return Switch(
lit,
[&](const ast::BoolLiteralExpression* bl) { //
out << (bl->value ? "true" : "false"); out << (bl->value ? "true" : "false");
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) { return true;
},
[&](const ast::FloatLiteralExpression* fl) { //
out << FloatToBitPreservingString(fl->value); out << FloatToBitPreservingString(fl->value);
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) { return true;
},
[&](const ast::SintLiteralExpression* sl) { //
out << sl->value; out << sl->value;
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) { return true;
},
[&](const ast::UintLiteralExpression* ul) { //
out << ul->value << "u"; out << ul->value << "u";
} else { return true;
},
[&](Default) { //
diagnostics_.add_error(diag::System::Writer, "unknown literal type"); diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false; return false;
} });
return true;
} }
bool GeneratorImpl::EmitIdentifier(std::ostream& out, bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@ -366,7 +374,9 @@ bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) {
} }
bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) { bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
if (auto* ary = ty->As<ast::Array>()) { return Switch(
ty,
[&](const ast::Array* ary) {
for (auto* attr : ary->attributes) { for (auto* attr : ary->attributes) {
if (auto* stride = attr->As<ast::StrideAttribute>()) { if (auto* stride = attr->As<ast::StrideAttribute>()) {
out << "@stride(" << stride->stride << ") "; out << "@stride(" << stride->stride << ") ";
@ -386,13 +396,21 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
} }
out << ">"; out << ">";
} else if (ty->Is<ast::Bool>()) { return true;
},
[&](const ast::Bool*) {
out << "bool"; out << "bool";
} else if (ty->Is<ast::F32>()) { return true;
},
[&](const ast::F32*) {
out << "f32"; out << "f32";
} else if (ty->Is<ast::I32>()) { return true;
},
[&](const ast::I32*) {
out << "i32"; out << "i32";
} else if (auto* mat = ty->As<ast::Matrix>()) { return true;
},
[&](const ast::Matrix* mat) {
out << "mat" << mat->columns << "x" << mat->rows; out << "mat" << mat->columns << "x" << mat->rows;
if (auto* el_ty = mat->type) { if (auto* el_ty = mat->type) {
out << "<"; out << "<";
@ -401,7 +419,9 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
} }
out << ">"; out << ">";
} }
} else if (auto* ptr = ty->As<ast::Pointer>()) { return true;
},
[&](const ast::Pointer* ptr) {
out << "ptr<" << ptr->storage_class << ", "; out << "ptr<" << ptr->storage_class << ", ";
if (!EmitType(out, ptr->type)) { if (!EmitType(out, ptr->type)) {
return false; return false;
@ -413,34 +433,58 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
} }
} }
out << ">"; out << ">";
} else if (auto* atomic = ty->As<ast::Atomic>()) { return true;
},
[&](const ast::Atomic* atomic) {
out << "atomic<"; out << "atomic<";
if (!EmitType(out, atomic->type)) { if (!EmitType(out, atomic->type)) {
return false; return false;
} }
out << ">"; out << ">";
} else if (auto* sampler = ty->As<ast::Sampler>()) { return true;
},
[&](const ast::Sampler* sampler) {
out << "sampler"; out << "sampler";
if (sampler->IsComparison()) { if (sampler->IsComparison()) {
out << "_comparison"; out << "_comparison";
} }
} else if (ty->Is<ast::ExternalTexture>()) { return true;
},
[&](const ast::ExternalTexture*) {
out << "texture_external"; out << "texture_external";
} else if (auto* texture = ty->As<ast::Texture>()) { return true;
},
[&](const ast::Texture* texture) {
out << "texture_"; out << "texture_";
if (texture->Is<ast::DepthTexture>()) { bool ok = Switch(
texture,
[&](const ast::DepthTexture*) { //
out << "depth_"; out << "depth_";
} else if (texture->Is<ast::DepthMultisampledTexture>()) { return true;
},
[&](const ast::DepthMultisampledTexture*) { //
out << "depth_multisampled_"; out << "depth_multisampled_";
} else if (texture->Is<ast::SampledTexture>()) { return true;
},
[&](const ast::SampledTexture*) { //
/* nothing to emit */ /* nothing to emit */
} else if (texture->Is<ast::MultisampledTexture>()) { return true;
},
[&](const ast::MultisampledTexture*) { //
out << "multisampled_"; out << "multisampled_";
} else if (texture->Is<ast::StorageTexture>()) { return true;
},
[&](const ast::StorageTexture*) { //
out << "storage_"; out << "storage_";
} else { return true;
diagnostics_.add_error(diag::System::Writer, "unknown texture type"); },
[&](Default) { //
diagnostics_.add_error(diag::System::Writer,
"unknown texture type");
return false;
});
if (!ok) {
return false; return false;
} }
@ -469,19 +513,25 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
return false; return false;
} }
if (auto* sampled = texture->As<ast::SampledTexture>()) { return Switch(
texture,
[&](const ast::SampledTexture* sampled) { //
out << "<"; out << "<";
if (!EmitType(out, sampled->type)) { if (!EmitType(out, sampled->type)) {
return false; return false;
} }
out << ">"; out << ">";
} else if (auto* ms = texture->As<ast::MultisampledTexture>()) { return true;
},
[&](const ast::MultisampledTexture* ms) { //
out << "<"; out << "<";
if (!EmitType(out, ms->type)) { if (!EmitType(out, ms->type)) {
return false; return false;
} }
out << ">"; out << ">";
} else if (auto* storage = texture->As<ast::StorageTexture>()) { return true;
},
[&](const ast::StorageTexture* storage) { //
out << "<"; out << "<";
if (!EmitImageFormat(out, storage->format)) { if (!EmitImageFormat(out, storage->format)) {
return false; return false;
@ -491,11 +541,17 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
return false; return false;
} }
out << ">"; out << ">";
} return true;
},
} else if (ty->Is<ast::U32>()) { [&](Default) { //
return true;
});
},
[&](const ast::U32*) {
out << "u32"; out << "u32";
} else if (auto* vec = ty->As<ast::Vector>()) { return true;
},
[&](const ast::Vector* vec) {
out << "vec" << vec->width; out << "vec" << vec->width;
if (auto* el_ty = vec->type) { if (auto* el_ty = vec->type) {
out << "<"; out << "<";
@ -504,17 +560,22 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
} }
out << ">"; out << ">";
} }
} else if (ty->Is<ast::Void>()) { return true;
},
[&](const ast::Void*) {
out << "void"; out << "void";
} else if (auto* tn = ty->As<ast::TypeName>()) { return true;
},
[&](const ast::TypeName* tn) {
out << program_->Symbols().NameFor(tn->name); out << program_->Symbols().NameFor(tn->name);
} else { return true;
},
[&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown type in EmitType: " + std::string(ty->TypeInfo().name)); "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
return false; return false;
} });
return true;
} }
bool GeneratorImpl::EmitStructType(const ast::Struct* str) { bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
@ -632,7 +693,9 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
} }
first = false; first = false;
out << "@"; out << "@";
if (auto* workgroup = attr->As<ast::WorkgroupAttribute>()) { bool ok = Switch(
attr,
[&](const ast::WorkgroupAttribute* workgroup) {
auto values = workgroup->Values(); auto values = workgroup->Values();
out << "workgroup_size("; out << "workgroup_size(";
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -646,43 +709,75 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
} }
} }
out << ")"; out << ")";
} else if (attr->Is<ast::StructBlockAttribute>()) { return true;
},
[&](const ast::StructBlockAttribute*) { //
out << "block"; out << "block";
} else if (auto* stage = attr->As<ast::StageAttribute>()) { return true;
},
[&](const ast::StageAttribute* stage) {
out << "stage(" << stage->stage << ")"; out << "stage(" << stage->stage << ")";
} else if (auto* binding = attr->As<ast::BindingAttribute>()) { return true;
},
[&](const ast::BindingAttribute* binding) {
out << "binding(" << binding->value << ")"; out << "binding(" << binding->value << ")";
} else if (auto* group = attr->As<ast::GroupAttribute>()) { return true;
},
[&](const ast::GroupAttribute* group) {
out << "group(" << group->value << ")"; out << "group(" << group->value << ")";
} else if (auto* location = attr->As<ast::LocationAttribute>()) { return true;
},
[&](const ast::LocationAttribute* location) {
out << "location(" << location->value << ")"; out << "location(" << location->value << ")";
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { return true;
},
[&](const ast::BuiltinAttribute* builtin) {
out << "builtin(" << builtin->builtin << ")"; out << "builtin(" << builtin->builtin << ")";
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
out << "interpolate(" << interpolate->type; out << "interpolate(" << interpolate->type;
if (interpolate->sampling != ast::InterpolationSampling::kNone) { if (interpolate->sampling != ast::InterpolationSampling::kNone) {
out << ", " << interpolate->sampling; out << ", " << interpolate->sampling;
} }
out << ")"; out << ")";
} else if (attr->Is<ast::InvariantAttribute>()) { return true;
},
[&](const ast::InvariantAttribute*) {
out << "invariant"; out << "invariant";
} else if (auto* override_attr = attr->As<ast::OverrideAttribute>()) { return true;
},
[&](const ast::OverrideAttribute* override_deco) {
out << "override"; out << "override";
if (override_attr->has_value) { if (override_deco->has_value) {
out << "(" << override_attr->value << ")"; out << "(" << override_deco->value << ")";
} }
} else if (auto* size = attr->As<ast::StructMemberSizeAttribute>()) { return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
out << "size(" << size->size << ")"; out << "size(" << size->size << ")";
} else if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) { return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
out << "align(" << align->align << ")"; out << "align(" << align->align << ")";
} else if (auto* stride = attr->As<ast::StrideAttribute>()) { return true;
},
[&](const ast::StrideAttribute* stride) {
out << "stride(" << stride->stride << ")"; out << "stride(" << stride->stride << ")";
} else if (auto* internal = attr->As<ast::InternalAttribute>()) { return true;
},
[&](const ast::InternalAttribute* internal) {
out << "internal(" << internal->InternalName() << ")"; out << "internal(" << internal->InternalName() << ")";
} else { return true;
},
[&](Default) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "Unsupported attribute '" << attr->TypeInfo().name << "'"; << "Unsupported attribute '" << attr->TypeInfo().name << "'";
return false; return false;
});
if (!ok) {
return false;
} }
} }
@ -809,55 +904,36 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
} }
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) { return Switch(
return EmitAssign(a); stmt, //
} [&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
if (auto* b = stmt->As<ast::BlockStatement>()) { [&](const ast::BlockStatement* b) { return EmitBlock(b); },
return EmitBlock(b); [&](const ast::BreakStatement* b) { return EmitBreak(b); },
} [&](const ast::CallStatement* c) {
if (auto* b = stmt->As<ast::BreakStatement>()) {
return EmitBreak(b);
}
if (auto* c = stmt->As<ast::CallStatement>()) {
auto out = line(); auto out = line();
if (!EmitCall(out, c->expr)) { if (!EmitCall(out, c->expr)) {
return false; return false;
} }
out << ";"; out << ";";
return true; return true;
} },
if (auto* c = stmt->As<ast::ContinueStatement>()) { [&](const ast::ContinueStatement* c) { return EmitContinue(c); },
return EmitContinue(c); [&](const ast::DiscardStatement* d) { return EmitDiscard(d); },
} [&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); },
if (auto* d = stmt->As<ast::DiscardStatement>()) { [&](const ast::IfStatement* i) { return EmitIf(i); },
return EmitDiscard(d); [&](const ast::LoopStatement* l) { return EmitLoop(l); },
} [&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
if (auto* f = stmt->As<ast::FallthroughStatement>()) { [&](const ast::ReturnStatement* r) { return EmitReturn(r); },
return EmitFallthrough(f); [&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
} [&](const ast::VariableDeclStatement* v) {
if (auto* i = stmt->As<ast::IfStatement>()) {
return EmitIf(i);
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
return EmitLoop(l);
}
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
return EmitForLoop(l);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(s);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(line(), v->variable); return EmitVariable(line(), v->variable);
} },
[&](Default) {
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name)); "unknown statement type: " + std::string(stmt->TypeInfo().name));
return false; return false;
});
} }
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) { bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {