Add tint::Switch()

A type dispatch helper with replaces chains of:

  if (auto* a = obj->As<A>()) {
    ...
  } else if (auto* b = obj->As<B>()) {
    ...
  } else {
    ...
  }

with:

  Switch(obj,
    [&](A* a) { ... },
    [&](B* b) { ... },
    [&](Default) { ... });

This new helper provides greater opportunities for optimizations, avoids
scoping issues with if-else blocks, and is slightly cleaner (IMO).

Bug: tint:1383
Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78543
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-02-04 15:38:23 +00:00 committed by Tint LUCI CQ
parent fa0d64b76d
commit de857e1c58
11 changed files with 2413 additions and 1570 deletions

View File

@ -1160,6 +1160,7 @@ if(TINT_BUILD_BENCHMARKS)
endif()
set(TINT_BENCHMARK_SRC
"castable_bench.cc"
"bench/benchmark.cc"
"reader/wgsl/parser_bench.cc"
)

View File

@ -35,16 +35,15 @@ Module::Module(ProgramID pid,
continue;
}
if (auto* ty = decl->As<ast::TypeDecl>()) {
type_decls_.push_back(ty);
} else if (auto* func = decl->As<Function>()) {
functions_.push_back(func);
} else if (auto* var = decl->As<Variable>()) {
global_variables_.push_back(var);
} else {
diag::List diagnostics;
TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
}
Switch(
decl, //
[&](const ast::TypeDecl* type) { type_decls_.push_back(type); },
[&](const Function* func) { functions_.push_back(func); },
[&](const Variable* var) { global_variables_.push_back(var); },
[&](Default) {
diag::List diagnostics;
TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
});
}
}
@ -101,19 +100,24 @@ void Module::Copy(CloneContext* ctx, const Module* src) {
<< "src global declaration was nullptr";
continue;
}
if (auto* type = decl->As<ast::TypeDecl>()) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
type_decls_.push_back(type);
} else if (auto* func = decl->As<Function>()) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
functions_.push_back(func);
} else if (auto* var = decl->As<Variable>()) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
global_variables_.push_back(var);
} else {
TINT_ICE(AST, ctx->dst->Diagnostics())
<< "Unknown global declaration type";
}
Switch(
decl,
[&](const ast::TypeDecl* type) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
type_decls_.push_back(type);
},
[&](const Function* func) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
functions_.push_back(func);
},
[&](const Variable* var) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
global_variables_.push_back(var);
},
[&](Default) {
TINT_ICE(AST, ctx->dst->Diagnostics())
<< "Unknown global declaration type";
});
}
}

View File

@ -101,30 +101,47 @@ bool TraverseExpressions(const ast::Expression* root,
}
}
if (auto* idx = expr->As<IndexAccessorExpression>()) {
push_pair(idx->object, idx->index);
} else if (auto* bin_op = expr->As<BinaryExpression>()) {
push_pair(bin_op->lhs, bin_op->rhs);
} else if (auto* bitcast = expr->As<BitcastExpression>()) {
to_visit.push_back(bitcast->expr);
} else if (auto* call = expr->As<CallExpression>()) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
// function name in the traversal.
// to_visit.push_back(call->func);
push_list(call->args);
} else if (auto* member = expr->As<MemberAccessorExpression>()) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
// member name in the traversal.
// push_pair(member->structure, member->member);
to_visit.push_back(member->structure);
} else if (auto* unary = expr->As<UnaryOpExpression>()) {
to_visit.push_back(unary->expr);
} else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
PhonyExpression>()) {
// Leaf expression
} else {
TINT_ICE(AST, diags) << "unhandled expression type: "
<< expr->TypeInfo().name;
bool ok = Switch(
expr,
[&](const IndexAccessorExpression* idx) {
push_pair(idx->object, idx->index);
return true;
},
[&](const BinaryExpression* bin_op) {
push_pair(bin_op->lhs, bin_op->rhs);
return true;
},
[&](const BitcastExpression* bitcast) {
to_visit.push_back(bitcast->expr);
return true;
},
[&](const CallExpression* call) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
// the function name in the traversal. to_visit.push_back(call->func);
push_list(call->args);
return true;
},
[&](const MemberAccessorExpression* member) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
// the member name in the traversal. push_pair(member->structure,
// member->member);
to_visit.push_back(member->structure);
return true;
},
[&](const UnaryOpExpression* unary) {
to_visit.push_back(unary->expr);
return true;
},
[&](Default) {
if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
PhonyExpression>()) {
return true; // Leaf expression
}
TINT_ICE(AST, diags)
<< "unhandled expression type: " << expr->TypeInfo().name;
return false;
});
if (!ok) {
return false;
}
}

View File

@ -453,6 +453,105 @@ class Castable : public BASE {
}
};
/// Default can be used as the default case for a Switch(), when all previous
/// cases failed to match.
///
/// Example:
/// ```
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* If not TypeA or TypeB */ });
/// ```
struct Default {};
/// Switch is used to dispatch one of the provided callback case handler
/// functions based on the type of `object` and the parameter type of the case
/// handlers. Switch will sequentially check the type of `object` against each
/// of the switch case handler functions, and will invoke the first case handler
/// function which has a parameter type that matches the object type. When a
/// case handler is matched, it will be called with the single argument of
/// `object` cast to the case handler's parameter type. Switch will invoke at
/// most one case handler. Each of the case functions must have the signature
/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
/// is the return type, consistent across all case handlers.
///
/// An optional default case function with the signature `R(Default)` can be
/// used as the last case. This default case will be called if all previous
/// cases failed to match.
///
/// Example:
/// ```
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ });
///
/// Switch(object,
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* Called if object is not TypeA or TypeB */ });
/// ```
///
/// @param object the object who's type is used to
/// @param first_case the first switch case
/// @param other_cases additional switch cases (optional)
/// @return the value returned by the called case. If no cases matched, then the
/// zero value for the consistent case type.
template <typename T, typename FIRST_CASE, typename... OTHER_CASES>
traits::ReturnType<FIRST_CASE> //
Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) {
using ReturnType = traits::ReturnType<FIRST_CASE>;
using CaseType = std::remove_pointer_t<traits::ParameterType<FIRST_CASE, 0>>;
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
static_assert(traits::SignatureOfT<FIRST_CASE>::parameter_count == 1,
"Switch case must have a single parameter");
if constexpr (std::is_same_v<CaseType, Default>) {
// Default case. Must be last.
(void)object; // 'object' is not used by the Default case.
static_assert(sizeof...(other_cases) == 0,
"Switch Default case must come last");
if constexpr (kHasReturnType) {
return first_case({});
} else {
first_case({});
return;
}
} else {
// Regular case.
static_assert(traits::IsTypeOrDerived<CaseType, CastableBase>::value,
"Switch case parameter is not a Castable pointer");
// Does the case match?
if (auto* ptr = As<CaseType>(object)) {
if constexpr (kHasReturnType) {
return first_case(ptr);
} else {
first_case(ptr);
return;
}
}
// Case did not match. Got any more cases to try?
if constexpr (sizeof...(other_cases) > 0) {
// Try the next cases...
if constexpr (kHasReturnType) {
auto res = Switch(object, std::forward<OTHER_CASES>(other_cases)...);
static_assert(std::is_same_v<decltype(res), ReturnType>,
"Switch case types do not have consistent return type");
return res;
} else {
Switch(object, std::forward<OTHER_CASES>(other_cases)...);
return;
}
} else {
// That was the last case. No cases matched.
if constexpr (kHasReturnType) {
return {};
} else {
return;
}
}
}
}
} // namespace tint
TINT_CASTABLE_POP_DISABLE_WARNINGS();

270
src/castable_bench.cc Normal file
View File

@ -0,0 +1,270 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bench/benchmark.h"
namespace tint {
namespace {
struct Base : public tint::Castable<Base> {};
struct A : public tint::Castable<A, Base> {};
struct AA : public tint::Castable<AA, A> {};
struct AAA : public tint::Castable<AAA, AA> {};
struct AAB : public tint::Castable<AAB, AA> {};
struct AAC : public tint::Castable<AAC, AA> {};
struct AB : public tint::Castable<AB, A> {};
struct ABA : public tint::Castable<ABA, AB> {};
struct ABB : public tint::Castable<ABB, AB> {};
struct ABC : public tint::Castable<ABC, AB> {};
struct AC : public tint::Castable<AC, A> {};
struct ACA : public tint::Castable<ACA, AC> {};
struct ACB : public tint::Castable<ACB, AC> {};
struct ACC : public tint::Castable<ACC, AC> {};
struct B : public tint::Castable<B, Base> {};
struct BA : public tint::Castable<BA, B> {};
struct BAA : public tint::Castable<BAA, BA> {};
struct BAB : public tint::Castable<BAB, BA> {};
struct BAC : public tint::Castable<BAC, BA> {};
struct BB : public tint::Castable<BB, B> {};
struct BBA : public tint::Castable<BBA, BB> {};
struct BBB : public tint::Castable<BBB, BB> {};
struct BBC : public tint::Castable<BBC, BB> {};
struct BC : public tint::Castable<BC, B> {};
struct BCA : public tint::Castable<BCA, BC> {};
struct BCB : public tint::Castable<BCB, BC> {};
struct BCC : public tint::Castable<BCC, BC> {};
struct C : public tint::Castable<C, Base> {};
struct CA : public tint::Castable<CA, C> {};
struct CAA : public tint::Castable<CAA, CA> {};
struct CAB : public tint::Castable<CAB, CA> {};
struct CAC : public tint::Castable<CAC, CA> {};
struct CB : public tint::Castable<CB, C> {};
struct CBA : public tint::Castable<CBA, CB> {};
struct CBB : public tint::Castable<CBB, CB> {};
struct CBC : public tint::Castable<CBC, CB> {};
struct CC : public tint::Castable<CC, C> {};
struct CCA : public tint::Castable<CCA, CC> {};
struct CCB : public tint::Castable<CCB, CC> {};
struct CCC : public tint::Castable<CCC, CC> {};
using AllTypes = std::tuple<Base,
A,
AA,
AAA,
AAB,
AAC,
AB,
ABA,
ABB,
ABC,
AC,
ACA,
ACB,
ACC,
B,
BA,
BAA,
BAB,
BAC,
BB,
BBA,
BBB,
BBC,
BC,
BCA,
BCB,
BCC,
C,
CA,
CAA,
CAB,
CAC,
CB,
CBA,
CBB,
CBC,
CC,
CCA,
CCB,
CCC>;
std::vector<std::unique_ptr<Base>> MakeObjects() {
std::vector<std::unique_ptr<Base>> out;
out.emplace_back(std::make_unique<Base>());
out.emplace_back(std::make_unique<A>());
out.emplace_back(std::make_unique<AA>());
out.emplace_back(std::make_unique<AAA>());
out.emplace_back(std::make_unique<AAB>());
out.emplace_back(std::make_unique<AAC>());
out.emplace_back(std::make_unique<AB>());
out.emplace_back(std::make_unique<ABA>());
out.emplace_back(std::make_unique<ABB>());
out.emplace_back(std::make_unique<ABC>());
out.emplace_back(std::make_unique<AC>());
out.emplace_back(std::make_unique<ACA>());
out.emplace_back(std::make_unique<ACB>());
out.emplace_back(std::make_unique<ACC>());
out.emplace_back(std::make_unique<B>());
out.emplace_back(std::make_unique<BA>());
out.emplace_back(std::make_unique<BAA>());
out.emplace_back(std::make_unique<BAB>());
out.emplace_back(std::make_unique<BAC>());
out.emplace_back(std::make_unique<BB>());
out.emplace_back(std::make_unique<BBA>());
out.emplace_back(std::make_unique<BBB>());
out.emplace_back(std::make_unique<BBC>());
out.emplace_back(std::make_unique<BC>());
out.emplace_back(std::make_unique<BCA>());
out.emplace_back(std::make_unique<BCB>());
out.emplace_back(std::make_unique<BCC>());
out.emplace_back(std::make_unique<C>());
out.emplace_back(std::make_unique<CA>());
out.emplace_back(std::make_unique<CAA>());
out.emplace_back(std::make_unique<CAB>());
out.emplace_back(std::make_unique<CAC>());
out.emplace_back(std::make_unique<CB>());
out.emplace_back(std::make_unique<CBA>());
out.emplace_back(std::make_unique<CBB>());
out.emplace_back(std::make_unique<CBC>());
out.emplace_back(std::make_unique<CC>());
out.emplace_back(std::make_unique<CCA>());
out.emplace_back(std::make_unique<CCB>());
out.emplace_back(std::make_unique<CCC>());
return out;
}
void CastableLargeSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); },
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); },
[&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); },
[&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); },
[&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); },
[&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); },
[&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); },
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
[&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); },
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
[&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); },
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
[&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); },
[&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); },
[&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); },
[&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); },
[&](const CA*) { ::benchmark::DoNotOptimize(i += 290); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
[&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); },
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); },
[&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); },
[&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); },
[&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); },
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableLargeSwitch);
void CastableMediumSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableMediumSwitch);
void CastableSmallSwitch(::benchmark::State& state) {
auto objects = MakeObjects();
size_t i = 0;
for (auto _ : state) {
auto* object = objects[i % objects.size()].get();
Switch(
object, //
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); },
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); },
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); });
i = (i * 31) ^ (i << 5);
}
}
BENCHMARK(CastableSmallSwitch);
} // namespace
} // namespace tint
TINT_INSTANTIATE_TYPEINFO(tint::Base);
TINT_INSTANTIATE_TYPEINFO(tint::A);
TINT_INSTANTIATE_TYPEINFO(tint::AA);
TINT_INSTANTIATE_TYPEINFO(tint::AAA);
TINT_INSTANTIATE_TYPEINFO(tint::AAB);
TINT_INSTANTIATE_TYPEINFO(tint::AAC);
TINT_INSTANTIATE_TYPEINFO(tint::AB);
TINT_INSTANTIATE_TYPEINFO(tint::ABA);
TINT_INSTANTIATE_TYPEINFO(tint::ABB);
TINT_INSTANTIATE_TYPEINFO(tint::ABC);
TINT_INSTANTIATE_TYPEINFO(tint::AC);
TINT_INSTANTIATE_TYPEINFO(tint::ACA);
TINT_INSTANTIATE_TYPEINFO(tint::ACB);
TINT_INSTANTIATE_TYPEINFO(tint::ACC);
TINT_INSTANTIATE_TYPEINFO(tint::B);
TINT_INSTANTIATE_TYPEINFO(tint::BA);
TINT_INSTANTIATE_TYPEINFO(tint::BAA);
TINT_INSTANTIATE_TYPEINFO(tint::BAB);
TINT_INSTANTIATE_TYPEINFO(tint::BAC);
TINT_INSTANTIATE_TYPEINFO(tint::BB);
TINT_INSTANTIATE_TYPEINFO(tint::BBA);
TINT_INSTANTIATE_TYPEINFO(tint::BBB);
TINT_INSTANTIATE_TYPEINFO(tint::BBC);
TINT_INSTANTIATE_TYPEINFO(tint::BC);
TINT_INSTANTIATE_TYPEINFO(tint::BCA);
TINT_INSTANTIATE_TYPEINFO(tint::BCB);
TINT_INSTANTIATE_TYPEINFO(tint::BCC);
TINT_INSTANTIATE_TYPEINFO(tint::C);
TINT_INSTANTIATE_TYPEINFO(tint::CA);
TINT_INSTANTIATE_TYPEINFO(tint::CAA);
TINT_INSTANTIATE_TYPEINFO(tint::CAB);
TINT_INSTANTIATE_TYPEINFO(tint::CAC);
TINT_INSTANTIATE_TYPEINFO(tint::CB);
TINT_INSTANTIATE_TYPEINFO(tint::CBA);
TINT_INSTANTIATE_TYPEINFO(tint::CBB);
TINT_INSTANTIATE_TYPEINFO(tint::CBC);
TINT_INSTANTIATE_TYPEINFO(tint::CC);
TINT_INSTANTIATE_TYPEINFO(tint::CCA);
TINT_INSTANTIATE_TYPEINFO(tint::CCB);
TINT_INSTANTIATE_TYPEINFO(tint::CCC);

View File

@ -252,6 +252,151 @@ TEST(Castable, As) {
ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
}
TEST(Castable, SwitchNoDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_amphibian = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Amphibian* amphibian) {
EXPECT_EQ(amphibian, frog.get());
frog_matched_amphibian = true;
});
EXPECT_TRUE(frog_matched_amphibian);
}
{
bool bear_matched_mammal = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Mammal* mammal) {
EXPECT_EQ(mammal, bear.get());
bear_matched_mammal = true;
});
EXPECT_TRUE(bear_matched_mammal);
}
{
bool gecko_matched_reptile = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Reptile* reptile) {
EXPECT_EQ(reptile, gecko.get());
gecko_matched_reptile = true;
});
EXPECT_TRUE(gecko_matched_reptile);
}
}
TEST(Castable, SwitchWithUnusedDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_amphibian = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Amphibian* amphibian) {
EXPECT_EQ(amphibian, frog.get());
frog_matched_amphibian = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(frog_matched_amphibian);
}
{
bool bear_matched_mammal = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Mammal* mammal) {
EXPECT_EQ(mammal, bear.get());
bear_matched_mammal = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(bear_matched_mammal);
}
{
bool gecko_matched_reptile = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Reptile* reptile) {
EXPECT_EQ(reptile, gecko.get());
gecko_matched_reptile = true;
},
[&](Default) { FAIL() << "default should not have been selected"; });
EXPECT_TRUE(gecko_matched_reptile);
}
}
TEST(Castable, SwitchDefault) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
{
bool frog_matched_default = false;
Switch(
frog.get(), //
[&](Reptile*) { FAIL() << "frog is not reptile"; },
[&](Mammal*) { FAIL() << "frog is not mammal"; },
[&](Default) { frog_matched_default = true; });
EXPECT_TRUE(frog_matched_default);
}
{
bool bear_matched_default = false;
Switch(
bear.get(), //
[&](Reptile*) { FAIL() << "bear is not reptile"; },
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
[&](Default) { bear_matched_default = true; });
EXPECT_TRUE(bear_matched_default);
}
{
bool gecko_matched_default = false;
Switch(
gecko.get(), //
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
[&](Default) { gecko_matched_default = true; });
EXPECT_TRUE(gecko_matched_default);
}
}
TEST(Castable, SwitchMatchFirst) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
{
bool frog_matched_animal = false;
Switch(
frog.get(),
[&](Animal* animal) {
EXPECT_EQ(animal, frog.get());
frog_matched_animal = true;
},
[&](Amphibian*) { FAIL() << "animal should have been matched first"; });
EXPECT_TRUE(frog_matched_animal);
}
{
bool frog_matched_amphibian = false;
Switch(
frog.get(),
[&](Amphibian* amphibain) {
EXPECT_EQ(amphibain, frog.get());
frog_matched_amphibian = true;
},
[&](Animal*) { FAIL() << "amphibian should have been matched first"; });
EXPECT_TRUE(frog_matched_amphibian);
}
}
} // namespace
TINT_INSTANTIATE_TYPEINFO(Animal);

View File

@ -953,7 +953,7 @@ const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* var_type,
ast::AttributeList* decos,
ast::AttributeList* attrs,
std::vector<int> index_prefix,
const Type* tip_type,
const Type* forced_param_type,
@ -966,105 +966,121 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name,
}
// Recursively flatten matrices, arrays, and structures.
if (auto* matrix_type = tip_type->As<Matrix>()) {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty,
forced_param_type, params, statements)) {
return false;
}
}
return success();
} else if (auto* array_type = tip_type->As<Array>()) {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty,
forced_param_type, params, statements)) {
return false;
}
}
return success();
} else if (auto* struct_type = tip_type->As<Struct>()) {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
ast::AttributeList member_decos(*decos);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_decos)) {
return false;
}
if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix,
members[i], forced_param_type, params,
statements)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_decos));
}
return success();
}
return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
vec_ty, forced_param_type, params,
statements)) {
return false;
}
}
return success();
},
[&](const Array* array_type) -> bool {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
elem_ty, forced_param_type, params,
statements)) {
return false;
}
}
return success();
},
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
ast::AttributeList member_attrs(*attrs);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_attrs)) {
return false;
}
if (!EmitPipelineInput(var_name, var_type, &member_attrs,
index_prefix, members[i], forced_param_type,
params, statements)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
const bool is_builtin =
ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const Type* param_type = is_builtin ? forced_param_type : tip_type;
const Type* param_type = is_builtin ? forced_param_type : tip_type;
const auto param_name = namer_.MakeDerivedName(var_name + "_param");
// Create the parameter.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
params->push_back(
builder_.Param(param_name, param_type->Build(builder_), *attrs));
const auto param_name = namer_.MakeDerivedName(var_name + "_param");
// Create the parameter.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple elements
// of a matrix, array, or structure. Normally that's disallowed but currently
// the SPIR-V reader will make duplicates when the entire AST is cloned
// at the top level of the SPIR-V reader flow. Consider rewriting this
// to avoid this node-sharing.
params->push_back(
builder_.Param(param_name, param_type->Build(builder_), *decos));
// Add a body statement to copy the parameter to the corresponding
// private variable.
const ast::Expression* param_value = builder_.Expr(param_name);
const ast::Expression* store_dest = builder_.Expr(var_name);
// Add a body statement to copy the parameter to the corresponding private
// variable.
const ast::Expression* param_value = builder_.Expr(param_name);
const ast::Expression* store_dest = builder_.Expr(var_name);
// Index into the LHS as needed.
auto* current_type =
var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
store_dest =
builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
store_dest =
builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
store_dest = builder_.MemberAccessor(
store_dest, builder_.Expr(parser_impl_.GetMemberName(
*struct_type, index)));
current_type = struct_type->members[index];
});
}
// Index into the LHS as needed.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
if (auto* matrix_type = current_type->As<Matrix>()) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
} else if (auto* array_type = current_type->As<Array>()) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
} else if (auto* struct_type = current_type->As<Struct>()) {
store_dest = builder_.MemberAccessor(
store_dest,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[index];
}
}
if (is_builtin && (tip_type != forced_param_type)) {
// The parameter will have the WGSL type, but we need bitcast to
// the variable store type.
param_value = create<ast::BitcastExpression>(
tip_type->Build(builder_), param_value);
}
if (is_builtin && (tip_type != forced_param_type)) {
// The parameter will have the WGSL type, but we need bitcast to
// the variable store type.
param_value =
create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
}
statements->push_back(builder_.Assign(store_dest, param_value));
statements->push_back(builder_.Assign(store_dest, param_value));
// Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(attrs);
// Increment the location attribute, in case more parameters will follow.
IncrementLocation(decos);
return success();
return success();
});
}
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
@ -1102,106 +1118,120 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
}
// Recursively flatten matrices, arrays, and structures.
if (auto* matrix_type = tip_type->As<Matrix>()) {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
forced_member_type, return_members,
return_exprs)) {
return false;
}
}
return success();
} else if (auto* array_type = tip_type->As<Array>()) {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
forced_member_type, return_members,
return_exprs)) {
return false;
}
}
return success();
} else if (auto* struct_type = tip_type->As<Struct>()) {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
ast::AttributeList member_decos(*decos);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_decos)) {
return false;
}
if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix,
members[i], forced_member_type, return_members,
return_exprs)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_decos));
}
return success();
}
return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
vec_ty, forced_member_type, return_members,
return_exprs)) {
return false;
}
}
return success();
},
[&](const Array* array_type) -> bool {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
elem_ty, forced_member_type, return_members,
return_exprs)) {
return false;
}
}
return success();
},
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
ast::AttributeList member_attrs(*decos);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type,
parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_attrs)) {
return false;
}
if (!EmitPipelineOutput(var_name, var_type, &member_attrs,
index_prefix, members[i], forced_member_type,
return_members, return_exprs)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
const bool is_builtin =
ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const Type* member_type = is_builtin ? forced_member_type : tip_type;
// Derive the member name directly from the variable name. They can't
// collide.
const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos));
const Type* member_type = is_builtin ? forced_member_type : tip_type;
// Derive the member name directly from the variable name. They can't
// collide.
const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple elements
// of a matrix, array, or structure. Normally that's disallowed but currently
// the SPIR-V reader will make duplicates when the entire AST is cloned
// at the top level of the SPIR-V reader flow. Consider rewriting this
// to avoid this node-sharing.
return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos));
// Create an expression to evaluate the part of the variable indexed by
// the index_prefix.
const ast::Expression* load_source = builder_.Expr(var_name);
// Create an expression to evaluate the part of the variable indexed by
// the index_prefix.
const ast::Expression* load_source = builder_.Expr(var_name);
// Index into the variable as needed to pick out the flattened member.
auto* current_type =
var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
load_source = builder_.MemberAccessor(
load_source, builder_.Expr(parser_impl_.GetMemberName(
*struct_type, index)));
current_type = struct_type->members[index];
});
}
// Index into the variable as needed to pick out the flattened member.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
if (auto* matrix_type = current_type->As<Matrix>()) {
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
} else if (auto* array_type = current_type->As<Array>()) {
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
current_type = array_type->type->UnwrapAlias();
} else if (auto* struct_type = current_type->As<Struct>()) {
load_source = builder_.MemberAccessor(
load_source,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[index];
}
}
if (is_builtin && (tip_type != forced_member_type)) {
// The member will have the WGSL type, but we need bitcast to
// the variable store type.
load_source = create<ast::BitcastExpression>(
forced_member_type->Build(builder_), load_source);
}
return_exprs->push_back(load_source);
if (is_builtin && (tip_type != forced_member_type)) {
// The member will have the WGSL type, but we need bitcast to
// the variable store type.
load_source = create<ast::BitcastExpression>(
forced_member_type->Build(builder_), load_source);
}
return_exprs->push_back(load_source);
// Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(decos);
// Increment the location attribute, in case more parameters will follow.
IncrementLocation(decos);
return success();
return success();
});
}
bool FunctionEmitter::EmitEntryPointAsWrapper() {

View File

@ -239,39 +239,41 @@ bool GeneratorImpl::Generate() {
}
last_kind = kind;
if (auto* global = decl->As<ast::Variable>()) {
if (!EmitGlobalVariable(global)) {
return false;
}
} else if (auto* str = decl->As<ast::Struct>()) {
auto* ty = builder_.Sem().Get(str);
auto storage_class_uses = ty->StorageClassUsage();
if (storage_class_uses.size() !=
(storage_class_uses.count(ast::StorageClass::kStorage) +
storage_class_uses.count(ast::StorageClass::kUniform))) {
// The structure is used as something other than a storage buffer or
// uniform buffer, so it needs to be emitted.
// Storage buffer are read and written to via a ByteAddressBuffer
// instead of true structure.
// Structures used as uniform buffer are read from an array of vectors
// instead of true structure.
if (!EmitStructType(current_buffer_, ty)) {
bool ok = Switch(
decl,
[&](const ast::Variable* global) { //
return EmitGlobalVariable(global);
},
[&](const ast::Struct* str) {
auto* ty = builder_.Sem().Get(str);
auto storage_class_uses = ty->StorageClassUsage();
if (storage_class_uses.size() !=
(storage_class_uses.count(ast::StorageClass::kStorage) +
storage_class_uses.count(ast::StorageClass::kUniform))) {
// The structure is used as something other than a storage buffer or
// uniform buffer, so it needs to be emitted.
// Storage buffer are read and written to via a ByteAddressBuffer
// instead of true structure.
// Structures used as uniform buffer are read from an array of
// vectors instead of true structure.
return EmitStructType(current_buffer_, ty);
}
return true;
},
[&](const ast::Function* func) {
if (func->IsEntryPoint()) {
return EmitEntryPointFunction(func);
}
return EmitFunction(func);
},
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled module-scope declaration: "
<< decl->TypeInfo().name;
return false;
}
}
} else if (auto* func = decl->As<ast::Function>()) {
if (func->IsEntryPoint()) {
if (!EmitEntryPointFunction(func)) {
return false;
}
} else {
if (!EmitFunction(func)) {
return false;
}
}
} else {
TINT_ICE(Writer, diagnostics_)
<< "unhandled module-scope declaration: " << decl->TypeInfo().name;
});
if (!ok) {
return false;
}
}
@ -929,22 +931,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
if (auto* func = target->As<sem::Function>()) {
return EmitFunctionCall(out, call, func);
}
if (auto* builtin = target->As<sem::Builtin>()) {
return EmitBuiltinCall(out, call, builtin);
}
if (auto* conv = target->As<sem::TypeConversion>()) {
return EmitTypeConversion(out, call, conv);
}
if (auto* ctor = target->As<sem::TypeConstructor>()) {
return EmitTypeConstructor(out, call, ctor);
}
TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name;
return false;
return Switch(
target,
[&](const sem::Function* func) {
return EmitFunctionCall(out, call, func);
},
[&](const sem::Builtin* builtin) {
return EmitBuiltinCall(out, call, builtin);
},
[&](const sem::TypeConversion* conv) {
return EmitTypeConversion(out, call, conv);
},
[&](const sem::TypeConstructor* ctor) {
return EmitTypeConstructor(out, call, ctor);
},
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled call target: " << target->TypeInfo().name;
return false;
});
}
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@ -2639,35 +2644,38 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
return EmitIndexAccessor(out, a);
}
if (auto* b = expr->As<ast::BinaryExpression>()) {
return EmitBinary(out, b);
}
if (auto* b = expr->As<ast::BitcastExpression>()) {
return EmitBitcast(out, b);
}
if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i);
}
if (auto* l = expr->As<ast::LiteralExpression>()) {
return EmitLiteral(out, l);
}
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
return EmitMemberAccessor(out, m);
}
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
return EmitUnaryOp(out, u);
}
diagnostics_.add_error(
diag::System::Writer,
"unknown expression type: " + std::string(expr->TypeInfo().name));
return false;
return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return EmitIndexAccessor(out, a);
},
[&](const ast::BinaryExpression* b) { //
return EmitBinary(out, b);
},
[&](const ast::BitcastExpression* b) { //
return EmitBitcast(out, b);
},
[&](const ast::CallExpression* c) { //
return EmitCall(out, c);
},
[&](const ast::IdentifierExpression* i) { //
return EmitIdentifier(out, i);
},
[&](const ast::LiteralExpression* l) { //
return EmitLiteral(out, l);
},
[&](const ast::MemberAccessorExpression* m) { //
return EmitMemberAccessor(out, m);
},
[&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u);
},
[&](Default) { //
diagnostics_.add_error(
diag::System::Writer,
"unknown expression type: " + std::string(expr->TypeInfo().name));
return false;
});
}
bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@ -3127,80 +3135,108 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) {
if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
out << (l->value ? "true" : "false");
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
if (std::isinf(fl->value)) {
out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
} else if (std::isnan(fl->value)) {
out << "asfloat(0x7fc00000u)";
} else {
out << FloatToString(fl->value) << "f";
}
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
out << sl->value;
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
out << ul->value << "u";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
}
return true;
return Switch(
lit,
[&](const ast::BoolLiteralExpression* l) {
out << (l->value ? "true" : "false");
return true;
},
[&](const ast::FloatLiteralExpression* fl) {
if (std::isinf(fl->value)) {
out << (fl->value >= 0 ? "asfloat(0x7f800000u)"
: "asfloat(0xff800000u)");
} else if (std::isnan(fl->value)) {
out << "asfloat(0x7fc00000u)";
} else {
out << FloatToString(fl->value) << "f";
}
return true;
},
[&](const ast::SintLiteralExpression* sl) {
out << sl->value;
return true;
},
[&](const ast::UintLiteralExpression* ul) {
out << ul->value << "u";
return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
});
}
bool GeneratorImpl::EmitValue(std::ostream& out,
const sem::Type* type,
int value) {
if (type->Is<sem::Bool>()) {
out << (value == 0 ? "false" : "true");
} else if (type->Is<sem::F32>()) {
out << value << ".0f";
} else if (type->Is<sem::I32>()) {
out << value;
} else if (type->Is<sem::U32>()) {
out << value << "u";
} else if (auto* vec = type->As<sem::Vector>()) {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < vec->Width(); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, vec->type(), value)) {
return Switch(
type,
[&](const sem::Bool*) {
out << (value == 0 ? "false" : "true");
return true;
},
[&](const sem::F32*) {
out << value << ".0f";
return true;
},
[&](const sem::I32*) {
out << value;
return true;
},
[&](const sem::U32*) {
out << value << "u";
return true;
},
[&](const sem::Vector* vec) {
if (!EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < vec->Width(); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, vec->type(), value)) {
return false;
}
}
return true;
},
[&](const sem::Matrix* mat) {
if (!EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kReadWrite, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, mat->type(), value)) {
return false;
}
}
return true;
},
[&](const sem::Struct*) {
out << "(";
TINT_DEFER(out << ")" << value);
return EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kUndefined, "");
},
[&](const sem::Array*) {
out << "(";
TINT_DEFER(out << ")" << value);
return EmitType(out, type, ast::StorageClass::kNone,
ast::Access::kUndefined, "");
},
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
"Invalid type for value emission: " + type->type_name());
return false;
}
}
} else if (auto* mat = type->As<sem::Matrix>()) {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
"")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) {
out << ", ";
}
if (!EmitValue(out, mat->type(), value)) {
return false;
}
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
out << "(";
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
"")) {
return false;
}
out << ")" << value;
} else {
diagnostics_.add_error(
diag::System::Writer,
"Invalid type for value emission: " + type->type_name());
return false;
}
return true;
});
}
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
@ -3375,56 +3411,59 @@ bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
}
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return EmitAssign(a);
}
if (auto* b = stmt->As<ast::BlockStatement>()) {
return EmitBlock(b);
}
if (auto* b = stmt->As<ast::BreakStatement>()) {
return EmitBreak(b);
}
if (auto* c = stmt->As<ast::CallStatement>()) {
auto out = line();
if (!EmitCall(out, c->expr)) {
return false;
}
out << ";";
return true;
}
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return EmitContinue(c);
}
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return EmitDiscard(d);
}
if (stmt->As<ast::FallthroughStatement>()) {
line() << "/* fallthrough */";
return true;
}
if (auto* i = stmt->As<ast::IfStatement>()) {
return EmitIf(i);
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
return EmitLoop(l);
}
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
return EmitForLoop(l);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(s);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(v->variable);
}
diagnostics_.add_error(
diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name));
return false;
return Switch(
stmt,
[&](const ast::AssignmentStatement* a) { //
return EmitAssign(a);
},
[&](const ast::BlockStatement* b) { //
return EmitBlock(b);
},
[&](const ast::BreakStatement* b) { //
return EmitBreak(b);
},
[&](const ast::CallStatement* c) { //
auto out = line();
if (!EmitCall(out, c->expr)) {
return false;
}
out << ";";
return true;
},
[&](const ast::ContinueStatement* c) { //
return EmitContinue(c);
},
[&](const ast::DiscardStatement* d) { //
return EmitDiscard(d);
},
[&](const ast::FallthroughStatement*) { //
line() << "/* fallthrough */";
return true;
},
[&](const ast::IfStatement* i) { //
return EmitIf(i);
},
[&](const ast::LoopStatement* l) { //
return EmitLoop(l);
},
[&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l);
},
[&](const ast::ReturnStatement* r) { //
return EmitReturn(r);
},
[&](const ast::SwitchStatement* s) { //
return EmitSwitch(s);
},
[&](const ast::VariableDeclStatement* v) { //
return EmitVariable(v->variable);
},
[&](Default) { //
diagnostics_.add_error(
diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name));
return false;
});
}
bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
@ -3516,156 +3555,181 @@ bool GeneratorImpl::EmitType(std::ostream& out,
break;
}
if (auto* ary = type->As<sem::Array>()) {
const sem::Type* base_type = ary;
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
return Switch(
type,
[&](const sem::Array* ary) {
const sem::Type* base_type = ary;
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which "
"should "
"have been transformed into a ByteAddressBuffer";
return false;
}
sizes.push_back(arr->Count());
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {
return false;
}
if (!name.empty()) {
out << " " << name;
if (name_printed) {
*name_printed = true;
}
}
for (uint32_t size : sizes) {
out << "[" << size << "]";
}
return true;
},
[&](const sem::Bool*) {
out << "bool";
return true;
},
[&](const sem::F32*) {
out << "float";
return true;
},
[&](const sem::I32*) {
out << "int";
return true;
},
[&](const sem::Matrix* mat) {
if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false;
}
// Note: HLSL's matrices are declared as <type>NxM, where N is the
// number of rows and M is the number of columns. Despite HLSL's
// matrices being column-major by default, the index operator and
// constructors actually operate on row-vectors, where as WGSL operates
// on column vectors. To simplify everything we use the transpose of the
// matrices. See:
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
out << mat->columns() << "x" << mat->rows();
return true;
},
[&](const sem::Pointer*) {
TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which should "
"have been transformed into a ByteAddressBuffer";
<< "Attempting to emit pointer type. These should have been "
"removed with the InlinePointerLets transform";
return false;
}
sizes.push_back(arr->Count());
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {
return false;
}
if (!name.empty()) {
out << " " << name;
if (name_printed) {
*name_printed = true;
}
}
for (uint32_t size : sizes) {
out << "[" << size << "]";
}
} else if (type->Is<sem::Bool>()) {
out << "bool";
} else if (type->Is<sem::F32>()) {
out << "float";
} else if (type->Is<sem::I32>()) {
out << "int";
} else if (auto* mat = type->As<sem::Matrix>()) {
if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false;
}
// Note: HLSL's matrices are declared as <type>NxM, where N is the number of
// rows and M is the number of columns. Despite HLSL's matrices being
// column-major by default, the index operator and constructors actually
// operate on row-vectors, where as WGSL operates on column vectors.
// To simplify everything we use the transpose of the matrices.
// See:
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
out << mat->columns() << "x" << mat->rows();
} else if (type->Is<sem::Pointer>()) {
TINT_ICE(Writer, diagnostics_)
<< "Attempting to emit pointer type. These should have been removed "
"with the InlinePointerLets transform";
return false;
} else if (auto* sampler = type->As<sem::Sampler>()) {
out << "Sampler";
if (sampler->IsComparison()) {
out << "Comparison";
}
out << "State";
} else if (auto* str = type->As<sem::Struct>()) {
out << StructName(str);
} else if (auto* tex = type->As<sem::Texture>()) {
auto* storage = tex->As<sem::StorageTexture>();
auto* ms = tex->As<sem::MultisampledTexture>();
auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
auto* sampled = tex->As<sem::SampledTexture>();
},
[&](const sem::Sampler* sampler) {
out << "Sampler";
if (sampler->IsComparison()) {
out << "Comparison";
}
out << "State";
return true;
},
[&](const sem::Struct* str) {
out << StructName(str);
return true;
},
[&](const sem::Texture* tex) {
auto* storage = tex->As<sem::StorageTexture>();
auto* ms = tex->As<sem::MultisampledTexture>();
auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
auto* sampled = tex->As<sem::SampledTexture>();
if (storage && storage->access() != ast::Access::kRead) {
out << "RW";
}
out << "Texture";
if (storage && storage->access() != ast::Access::kRead) {
out << "RW";
}
out << "Texture";
switch (tex->dim()) {
case ast::TextureDimension::k1d:
out << "1D";
break;
case ast::TextureDimension::k2d:
out << ((ms || depth_ms) ? "2DMS" : "2D");
break;
case ast::TextureDimension::k2dArray:
out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
break;
case ast::TextureDimension::k3d:
out << "3D";
break;
case ast::TextureDimension::kCube:
out << "Cube";
break;
case ast::TextureDimension::kCubeArray:
out << "CubeArray";
break;
default:
TINT_UNREACHABLE(Writer, diagnostics_)
<< "unexpected TextureDimension " << tex->dim();
return false;
}
switch (tex->dim()) {
case ast::TextureDimension::k1d:
out << "1D";
break;
case ast::TextureDimension::k2d:
out << ((ms || depth_ms) ? "2DMS" : "2D");
break;
case ast::TextureDimension::k2dArray:
out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
break;
case ast::TextureDimension::k3d:
out << "3D";
break;
case ast::TextureDimension::kCube:
out << "Cube";
break;
case ast::TextureDimension::kCubeArray:
out << "CubeArray";
break;
default:
TINT_UNREACHABLE(Writer, diagnostics_)
<< "unexpected TextureDimension " << tex->dim();
return false;
}
if (storage) {
auto* component = image_format_to_rwtexture_type(storage->texel_format());
if (component == nullptr) {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported StorageTexture TexelFormat: "
<< static_cast<int>(storage->texel_format());
if (storage) {
auto* component =
image_format_to_rwtexture_type(storage->texel_format());
if (component == nullptr) {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported StorageTexture TexelFormat: "
<< static_cast<int>(storage->texel_format());
return false;
}
out << "<" << component << ">";
} else if (depth_ms) {
out << "<float4>";
} else if (sampled || ms) {
auto* subtype = sampled ? sampled->type() : ms->type();
out << "<";
if (subtype->Is<sem::F32>()) {
out << "float4";
} else if (subtype->Is<sem::I32>()) {
out << "int4";
} else if (subtype->Is<sem::U32>()) {
out << "uint4";
} else {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported multisampled texture type";
return false;
}
out << ">";
}
return true;
},
[&](const sem::U32*) {
out << "uint";
return true;
},
[&](const sem::Vector* vec) {
auto width = vec->Width();
if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
out << "float" << width;
} else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
out << "int" << width;
} else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
out << "uint" << width;
} else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
out << "bool" << width;
} else {
out << "vector<";
if (!EmitType(out, vec->type(), storage_class, access, "")) {
return false;
}
out << ", " << width << ">";
}
return true;
},
[&](const sem::Atomic* atomic) {
return EmitType(out, atomic->Type(), storage_class, access, name);
},
[&](const sem::Void*) {
out << "void";
return true;
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer,
"unknown type in EmitType");
return false;
}
out << "<" << component << ">";
} else if (depth_ms) {
out << "<float4>";
} else if (sampled || ms) {
auto* subtype = sampled ? sampled->type() : ms->type();
out << "<";
if (subtype->Is<sem::F32>()) {
out << "float4";
} else if (subtype->Is<sem::I32>()) {
out << "int4";
} else if (subtype->Is<sem::U32>()) {
out << "uint4";
} else {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported multisampled texture type";
return false;
}
out << ">";
}
} else if (type->Is<sem::U32>()) {
out << "uint";
} else if (auto* vec = type->As<sem::Vector>()) {
auto width = vec->Width();
if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
out << "float" << width;
} else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
out << "int" << width;
} else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
out << "uint" << width;
} else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
out << "bool" << width;
} else {
out << "vector<";
if (!EmitType(out, vec->type(), storage_class, access, "")) {
return false;
}
out << ", " << width << ">";
}
} else if (auto* atomic = type->As<sem::Atomic>()) {
if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
return false;
}
} else if (type->Is<sem::Void>()) {
out << "void";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
return false;
}
return true;
});
}
bool GeneratorImpl::EmitTypeAndName(std::ostream& out,

File diff suppressed because it is too large Load Diff

View File

@ -560,33 +560,37 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
}
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
return GenerateAccessorExpression(a);
}
if (auto* b = expr->As<ast::BinaryExpression>()) {
return GenerateBinaryExpression(b);
}
if (auto* b = expr->As<ast::BitcastExpression>()) {
return GenerateBitcastExpression(b);
}
if (auto* c = expr->As<ast::CallExpression>()) {
return GenerateCallExpression(c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return GenerateIdentifierExpression(i);
}
if (auto* l = expr->As<ast::LiteralExpression>()) {
return GenerateLiteralIfNeeded(nullptr, l);
}
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
return GenerateAccessorExpression(m);
}
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
return GenerateUnaryOpExpression(u);
}
error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
return 0;
return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return GenerateAccessorExpression(a);
},
[&](const ast::BinaryExpression* b) { //
return GenerateBinaryExpression(b);
},
[&](const ast::BitcastExpression* b) { //
return GenerateBitcastExpression(b);
},
[&](const ast::CallExpression* c) { //
return GenerateCallExpression(c);
},
[&](const ast::IdentifierExpression* i) { //
return GenerateIdentifierExpression(i);
},
[&](const ast::LiteralExpression* l) { //
return GenerateLiteralIfNeeded(nullptr, l);
},
[&](const ast::MemberAccessorExpression* m) { //
return GenerateAccessorExpression(m);
},
[&](const ast::UnaryOpExpression* u) { //
return GenerateUnaryOpExpression(u);
},
[&](Default) -> uint32_t {
error_ =
"unknown expression type: " + std::string(expr->TypeInfo().name);
return 0;
});
}
bool Builder::GenerateFunction(const ast::Function* func_ast) {
@ -861,33 +865,56 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
push_type(spv::Op::OpVariable, std::move(ops));
for (auto* attr : var->attributes) {
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
Operand::Int(
ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
} else if (auto* location = attr->As<ast::LocationAttribute>()) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
Operand::Int(location->value)});
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
AddInterpolationDecorations(var_id, interpolate->type,
interpolate->sampling);
} else if (attr->Is<ast::InvariantAttribute>()) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
} else if (auto* binding = attr->As<ast::BindingAttribute>()) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
Operand::Int(binding->value)});
} else if (auto* group = attr->As<ast::GroupAttribute>()) {
push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
Operand::Int(SpvDecorationDescriptorSet),
Operand::Int(group->value)});
} else if (attr->Is<ast::OverrideAttribute>()) {
// Spec constants are handled elsewhere
} else if (!attr->Is<ast::InternalAttribute>()) {
error_ = "unknown attribute";
bool ok = Switch(
attr,
[&](const ast::BuiltinAttribute* builtin) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
Operand::Int(ConvertBuiltin(builtin->builtin,
sem->StorageClass()))});
return true;
},
[&](const ast::LocationAttribute* location) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
Operand::Int(location->value)});
return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
AddInterpolationDecorations(var_id, interpolate->type,
interpolate->sampling);
return true;
},
[&](const ast::InvariantAttribute*) {
push_annot(
spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
return true;
},
[&](const ast::BindingAttribute* binding) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
Operand::Int(binding->value)});
return true;
},
[&](const ast::GroupAttribute* group) {
push_annot(
spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
Operand::Int(group->value)});
return true;
},
[&](const ast::OverrideAttribute*) {
return true; // Spec constants are handled elsewhere
},
[&](const ast::InternalAttribute*) {
return true; // ignored
},
[&](Default) {
error_ = "unknown attribute";
return false;
});
if (!ok) {
return false;
}
}
@ -1123,19 +1150,21 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
// promoted to storage with the VarForDynamicIndex transform.
for (auto* accessor : accessors) {
if (auto* array = accessor->As<ast::IndexAccessorExpression>()) {
if (!GenerateIndexAccessor(array, &info)) {
return 0;
}
} else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
if (!GenerateMemberAccessor(member, &info)) {
return 0;
}
} else {
error_ =
"invalid accessor in list: " + std::string(accessor->TypeInfo().name);
return 0;
bool ok = Switch(
accessor,
[&](const ast::IndexAccessorExpression* array) {
return GenerateIndexAccessor(array, &info);
},
[&](const ast::MemberAccessorExpression* member) {
return GenerateMemberAccessor(member, &info);
},
[&](Default) {
error_ = "invalid accessor in list: " +
std::string(accessor->TypeInfo().name);
return false;
});
if (!ok) {
return false;
}
}
@ -1653,21 +1682,28 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
constant.constant_id = global->ConstantId();
}
if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
constant.kind = ScalarConstant::Kind::kBool;
constant.value.b = l->value;
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
constant.kind = ScalarConstant::Kind::kI32;
constant.value.i32 = sl->value;
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
constant.kind = ScalarConstant::Kind::kU32;
constant.value.u32 = ul->value;
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
constant.kind = ScalarConstant::Kind::kF32;
constant.value.f32 = fl->value;
} else {
error_ = "unknown literal type";
return 0;
Switch(
lit,
[&](const ast::BoolLiteralExpression* l) {
constant.kind = ScalarConstant::Kind::kBool;
constant.value.b = l->value;
},
[&](const ast::SintLiteralExpression* sl) {
constant.kind = ScalarConstant::Kind::kI32;
constant.value.i32 = sl->value;
},
[&](const ast::UintLiteralExpression* ul) {
constant.kind = ScalarConstant::Kind::kU32;
constant.value.u32 = ul->value;
},
[&](const ast::FloatLiteralExpression* fl) {
constant.kind = ScalarConstant::Kind::kF32;
constant.value.f32 = fl->value;
},
[&](Default) { error_ = "unknown literal type"; });
if (!error_.empty()) {
return false;
}
return GenerateConstantIfNeeded(constant);
@ -2209,19 +2245,25 @@ bool Builder::GenerateBlockStatementWithoutScoping(
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
if (auto* func = target->As<sem::Function>()) {
return GenerateFunctionCall(call, func);
}
if (auto* builtin = target->As<sem::Builtin>()) {
return GenerateBuiltinCall(call, builtin);
}
if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
return GenerateTypeConstructorOrConversion(call, nullptr);
}
TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name;
return false;
return Switch(
target,
[&](const sem::Function* func) {
return GenerateFunctionCall(call, func);
},
[&](const sem::Builtin* builtin) {
return GenerateBuiltinCall(call, builtin);
},
[&](const sem::TypeConversion*) {
return GenerateTypeConstructorOrConversion(call, nullptr);
},
[&](const sem::TypeConstructor*) {
return GenerateTypeConstructorOrConversion(call, nullptr);
},
[&](Default) -> uint32_t {
TINT_ICE(Writer, builder_.Diagnostics())
<< "unhandled call target: " << target->TypeInfo().name;
return 0;
});
}
uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
@ -3790,46 +3832,49 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
}
bool Builder::GenerateStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return GenerateAssignStatement(a);
}
if (auto* b = stmt->As<ast::BlockStatement>()) {
return GenerateBlockStatement(b);
}
if (auto* b = stmt->As<ast::BreakStatement>()) {
return GenerateBreakStatement(b);
}
if (auto* c = stmt->As<ast::CallStatement>()) {
return GenerateCallExpression(c->expr) != 0;
}
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return GenerateContinueStatement(c);
}
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return GenerateDiscardStatement(d);
}
if (stmt->Is<ast::FallthroughStatement>()) {
// Do nothing here, the fallthrough gets handled by the switch code.
return true;
}
if (auto* i = stmt->As<ast::IfStatement>()) {
return GenerateIfStatement(i);
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
return GenerateLoopStatement(l);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return GenerateReturnStatement(r);
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return GenerateSwitchStatement(s);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return GenerateVariableDeclStatement(v);
}
error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
return false;
return Switch(
stmt,
[&](const ast::AssignmentStatement* a) {
return GenerateAssignStatement(a);
},
[&](const ast::BlockStatement* b) { //
return GenerateBlockStatement(b);
},
[&](const ast::BreakStatement* b) { //
return GenerateBreakStatement(b);
},
[&](const ast::CallStatement* c) {
return GenerateCallExpression(c->expr) != 0;
},
[&](const ast::ContinueStatement* c) {
return GenerateContinueStatement(c);
},
[&](const ast::DiscardStatement* d) {
return GenerateDiscardStatement(d);
},
[&](const ast::FallthroughStatement*) {
// Do nothing here, the fallthrough gets handled by the switch code.
return true;
},
[&](const ast::IfStatement* i) { //
return GenerateIfStatement(i);
},
[&](const ast::LoopStatement* l) { //
return GenerateLoopStatement(l);
},
[&](const ast::ReturnStatement* r) { //
return GenerateReturnStatement(r);
},
[&](const ast::SwitchStatement* s) { //
return GenerateSwitchStatement(s);
},
[&](const ast::VariableDeclStatement* v) {
return GenerateVariableDeclStatement(v);
},
[&](Default) {
error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
return false;
});
}
bool Builder::GenerateVariableDeclStatement(
@ -3872,78 +3917,91 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
auto result = result_op();
auto id = result.to_i();
if (auto* arr = type->As<sem::Array>()) {
if (!GenerateArrayType(arr, result)) {
return 0;
}
} else if (type->Is<sem::Bool>()) {
push_type(spv::Op::OpTypeBool, {result});
} else if (type->Is<sem::F32>()) {
push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
} else if (type->Is<sem::I32>()) {
push_type(spv::Op::OpTypeInt,
{result, Operand::Int(32), Operand::Int(1)});
} else if (auto* mat = type->As<sem::Matrix>()) {
if (!GenerateMatrixType(mat, result)) {
return 0;
}
} else if (auto* ptr = type->As<sem::Pointer>()) {
if (!GeneratePointerType(ptr, result)) {
return 0;
}
} else if (auto* ref = type->As<sem::Reference>()) {
if (!GenerateReferenceType(ref, result)) {
return 0;
}
} else if (auto* str = type->As<sem::Struct>()) {
if (!GenerateStructType(str, result)) {
return 0;
}
} else if (type->Is<sem::U32>()) {
push_type(spv::Op::OpTypeInt,
{result, Operand::Int(32), Operand::Int(0)});
} else if (auto* vec = type->As<sem::Vector>()) {
if (!GenerateVectorType(vec, result)) {
return 0;
}
} else if (type->Is<sem::Void>()) {
push_type(spv::Op::OpTypeVoid, {result});
} else if (auto* tex = type->As<sem::Texture>()) {
if (!GenerateTextureType(tex, result)) {
return 0;
}
bool ok = Switch(
type,
[&](const sem::Array* arr) { //
return GenerateArrayType(arr, result);
},
[&](const sem::Bool*) {
push_type(spv::Op::OpTypeBool, {result});
return true;
},
[&](const sem::F32*) {
push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
return true;
},
[&](const sem::I32*) {
push_type(spv::Op::OpTypeInt,
{result, Operand::Int(32), Operand::Int(1)});
return true;
},
[&](const sem::Matrix* mat) { //
return GenerateMatrixType(mat, result);
},
[&](const sem::Pointer* ptr) { //
return GeneratePointerType(ptr, result);
},
[&](const sem::Reference* ref) { //
return GenerateReferenceType(ref, result);
},
[&](const sem::Struct* str) { //
return GenerateStructType(str, result);
},
[&](const sem::U32*) {
push_type(spv::Op::OpTypeInt,
{result, Operand::Int(32), Operand::Int(0)});
return true;
},
[&](const sem::Vector* vec) { //
return GenerateVectorType(vec, result);
},
[&](const sem::Void*) {
push_type(spv::Op::OpTypeVoid, {result});
return true;
},
[&](const sem::StorageTexture* tex) {
if (!GenerateTextureType(tex, result)) {
return false;
}
if (auto* st = tex->As<sem::StorageTexture>()) {
// Register all three access types of StorageTexture names. In SPIR-V,
// we must output a single type, while the variable is annotated with
// the access type. Doing this ensures we de-dupe.
type_name_to_id_[builder_
.create<sem::StorageTexture>(
st->dim(), st->texel_format(),
ast::Access::kRead, st->type())
->type_name()] = id;
type_name_to_id_[builder_
.create<sem::StorageTexture>(
st->dim(), st->texel_format(),
ast::Access::kWrite, st->type())
->type_name()] = id;
type_name_to_id_[builder_
.create<sem::StorageTexture>(
st->dim(), st->texel_format(),
ast::Access::kReadWrite, st->type())
->type_name()] = id;
}
// Register all three access types of StorageTexture names. In
// SPIR-V, we must output a single type, while the variable is
// annotated with the access type. Doing this ensures we de-dupe.
type_name_to_id_[builder_
.create<sem::StorageTexture>(
tex->dim(), tex->texel_format(),
ast::Access::kRead, tex->type())
->type_name()] = id;
type_name_to_id_[builder_
.create<sem::StorageTexture>(
tex->dim(), tex->texel_format(),
ast::Access::kWrite, tex->type())
->type_name()] = id;
type_name_to_id_[builder_
.create<sem::StorageTexture>(
tex->dim(), tex->texel_format(),
ast::Access::kReadWrite, tex->type())
->type_name()] = id;
return true;
},
[&](const sem::Texture* tex) {
return GenerateTextureType(tex, result);
},
[&](const sem::Sampler*) {
push_type(spv::Op::OpTypeSampler, {result});
} else if (type->Is<sem::Sampler>()) {
push_type(spv::Op::OpTypeSampler, {result});
// Register both of the sampler type names. In SPIR-V they're the same
// sampler type, so we need to match that when we do the dedup check.
type_name_to_id_["__sampler_sampler"] = id;
type_name_to_id_["__sampler_comparison"] = id;
return true;
},
[&](Default) {
error_ = "unable to convert type: " + type->type_name();
return false;
});
// Register both of the sampler type names. In SPIR-V they're the same
// sampler type, so we need to match that when we do the dedup check.
type_name_to_id_["__sampler_sampler"] = id;
type_name_to_id_["__sampler_comparison"] = id;
} else {
error_ = "unable to convert type: " + type->type_name();
if (!ok) {
return 0;
}
@ -3995,22 +4053,31 @@ bool Builder::GenerateTextureType(const sem::Texture* texture,
}
if (dim == ast::TextureDimension::kCubeArray) {
if (texture->Is<sem::SampledTexture>() ||
texture->Is<sem::DepthTexture>()) {
if (texture->IsAnyOf<sem::SampledTexture, sem::DepthTexture>()) {
push_capability(SpvCapabilitySampledCubeArray);
}
}
uint32_t type_id = 0u;
if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
} else if (auto* s = texture->As<sem::SampledTexture>()) {
type_id = GenerateTypeIfNeeded(s->type());
} else if (auto* ms = texture->As<sem::MultisampledTexture>()) {
type_id = GenerateTypeIfNeeded(ms->type());
} else if (auto* st = texture->As<sem::StorageTexture>()) {
type_id = GenerateTypeIfNeeded(st->type());
}
uint32_t type_id = Switch(
texture,
[&](const sem::DepthTexture*) {
return GenerateTypeIfNeeded(builder_.create<sem::F32>());
},
[&](const sem::DepthMultisampledTexture*) {
return GenerateTypeIfNeeded(builder_.create<sem::F32>());
},
[&](const sem::SampledTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](const sem::MultisampledTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](const sem::StorageTexture* t) {
return GenerateTypeIfNeeded(t->type());
},
[&](Default) -> uint32_t { //
return 0u;
});
if (type_id == 0u) {
return false;
}

View File

@ -68,23 +68,17 @@ GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() {
// Generate global declarations in the order they appear in the module.
for (auto* decl : program_->AST().GlobalDeclarations()) {
if (auto* td = decl->As<ast::TypeDecl>()) {
if (!EmitTypeDecl(td)) {
return false;
}
} else if (auto* func = decl->As<ast::Function>()) {
if (!EmitFunction(func)) {
return false;
}
} else if (auto* var = decl->As<ast::Variable>()) {
if (!EmitVariable(line(), var)) {
return false;
}
} else {
TINT_UNREACHABLE(Writer, diagnostics_);
if (!Switch(
decl, //
[&](const ast::TypeDecl* td) { return EmitTypeDecl(td); },
[&](const ast::Function* func) { return EmitFunction(func); },
[&](const ast::Variable* var) { return EmitVariable(line(), var); },
[&](Default) {
TINT_UNREACHABLE(Writer, diagnostics_);
return false;
})) {
return false;
}
if (decl != program_->AST().GlobalDeclarations().back()) {
line();
}
@ -94,59 +88,64 @@ bool GeneratorImpl::Generate() {
}
bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
if (auto* alias = ty->As<ast::Alias>()) {
auto out = line();
out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
if (!EmitType(out, alias->type)) {
return false;
}
out << ";";
} else if (auto* str = ty->As<ast::Struct>()) {
if (!EmitStructType(str)) {
return false;
}
} else {
diagnostics_.add_error(
diag::System::Writer,
"unknown declared type: " + std::string(ty->TypeInfo().name));
return false;
}
return true;
return Switch(
ty,
[&](const ast::Alias* alias) { //
auto out = line();
out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
if (!EmitType(out, alias->type)) {
return false;
}
out << ";";
return true;
},
[&](const ast::Struct* str) { //
return EmitStructType(str);
},
[&](Default) { //
diagnostics_.add_error(
diag::System::Writer,
"unknown declared type: " + std::string(ty->TypeInfo().name));
return false;
});
}
bool GeneratorImpl::EmitExpression(std::ostream& out,
const ast::Expression* expr) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
return EmitIndexAccessor(out, a);
}
if (auto* b = expr->As<ast::BinaryExpression>()) {
return EmitBinary(out, b);
}
if (auto* b = expr->As<ast::BitcastExpression>()) {
return EmitBitcast(out, b);
}
if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c);
}
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i);
}
if (auto* l = expr->As<ast::LiteralExpression>()) {
return EmitLiteral(out, l);
}
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
return EmitMemberAccessor(out, m);
}
if (expr->Is<ast::PhonyExpression>()) {
out << "_";
return true;
}
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
return EmitUnaryOp(out, u);
}
diagnostics_.add_error(diag::System::Writer, "unknown expression type");
return false;
return Switch(
expr,
[&](const ast::IndexAccessorExpression* a) { //
return EmitIndexAccessor(out, a);
},
[&](const ast::BinaryExpression* b) { //
return EmitBinary(out, b);
},
[&](const ast::BitcastExpression* b) { //
return EmitBitcast(out, b);
},
[&](const ast::CallExpression* c) { //
return EmitCall(out, c);
},
[&](const ast::IdentifierExpression* i) { //
return EmitIdentifier(out, i);
},
[&](const ast::LiteralExpression* l) { //
return EmitLiteral(out, l);
},
[&](const ast::MemberAccessorExpression* m) { //
return EmitMemberAccessor(out, m);
},
[&](const ast::PhonyExpression*) { //
out << "_";
return true;
},
[&](const ast::UnaryOpExpression* u) { //
return EmitUnaryOp(out, u);
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown expression type");
return false;
});
}
bool GeneratorImpl::EmitIndexAccessor(
@ -250,19 +249,28 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) {
if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
out << (bl->value ? "true" : "false");
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
out << FloatToBitPreservingString(fl->value);
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
out << sl->value;
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
out << ul->value << "u";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
}
return true;
return Switch(
lit,
[&](const ast::BoolLiteralExpression* bl) { //
out << (bl->value ? "true" : "false");
return true;
},
[&](const ast::FloatLiteralExpression* fl) { //
out << FloatToBitPreservingString(fl->value);
return true;
},
[&](const ast::SintLiteralExpression* sl) { //
out << sl->value;
return true;
},
[&](const ast::UintLiteralExpression* ul) { //
out << ul->value << "u";
return true;
},
[&](Default) { //
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
return false;
});
}
bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@ -366,155 +374,208 @@ bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) {
}
bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
if (auto* ary = ty->As<ast::Array>()) {
for (auto* attr : ary->attributes) {
if (auto* stride = attr->As<ast::StrideAttribute>()) {
out << "@stride(" << stride->stride << ") ";
}
}
return Switch(
ty,
[&](const ast::Array* ary) {
for (auto* attr : ary->attributes) {
if (auto* stride = attr->As<ast::StrideAttribute>()) {
out << "@stride(" << stride->stride << ") ";
}
}
out << "array<";
if (!EmitType(out, ary->type)) {
return false;
}
out << "array<";
if (!EmitType(out, ary->type)) {
return false;
}
if (!ary->IsRuntimeArray()) {
out << ", ";
if (!EmitExpression(out, ary->count)) {
return false;
}
}
if (!ary->IsRuntimeArray()) {
out << ", ";
if (!EmitExpression(out, ary->count)) {
return false;
}
}
out << ">";
} else if (ty->Is<ast::Bool>()) {
out << "bool";
} else if (ty->Is<ast::F32>()) {
out << "f32";
} else if (ty->Is<ast::I32>()) {
out << "i32";
} else if (auto* mat = ty->As<ast::Matrix>()) {
out << "mat" << mat->columns << "x" << mat->rows;
if (auto* el_ty = mat->type) {
out << "<";
if (!EmitType(out, el_ty)) {
return false;
}
out << ">";
}
} else if (auto* ptr = ty->As<ast::Pointer>()) {
out << "ptr<" << ptr->storage_class << ", ";
if (!EmitType(out, ptr->type)) {
return false;
}
if (ptr->access != ast::Access::kUndefined) {
out << ", ";
if (!EmitAccess(out, ptr->access)) {
return false;
}
}
out << ">";
} else if (auto* atomic = ty->As<ast::Atomic>()) {
out << "atomic<";
if (!EmitType(out, atomic->type)) {
return false;
}
out << ">";
} else if (auto* sampler = ty->As<ast::Sampler>()) {
out << "sampler";
out << ">";
return true;
},
[&](const ast::Bool*) {
out << "bool";
return true;
},
[&](const ast::F32*) {
out << "f32";
return true;
},
[&](const ast::I32*) {
out << "i32";
return true;
},
[&](const ast::Matrix* mat) {
out << "mat" << mat->columns << "x" << mat->rows;
if (auto* el_ty = mat->type) {
out << "<";
if (!EmitType(out, el_ty)) {
return false;
}
out << ">";
}
return true;
},
[&](const ast::Pointer* ptr) {
out << "ptr<" << ptr->storage_class << ", ";
if (!EmitType(out, ptr->type)) {
return false;
}
if (ptr->access != ast::Access::kUndefined) {
out << ", ";
if (!EmitAccess(out, ptr->access)) {
return false;
}
}
out << ">";
return true;
},
[&](const ast::Atomic* atomic) {
out << "atomic<";
if (!EmitType(out, atomic->type)) {
return false;
}
out << ">";
return true;
},
[&](const ast::Sampler* sampler) {
out << "sampler";
if (sampler->IsComparison()) {
out << "_comparison";
}
} else if (ty->Is<ast::ExternalTexture>()) {
out << "texture_external";
} else if (auto* texture = ty->As<ast::Texture>()) {
out << "texture_";
if (texture->Is<ast::DepthTexture>()) {
out << "depth_";
} else if (texture->Is<ast::DepthMultisampledTexture>()) {
out << "depth_multisampled_";
} else if (texture->Is<ast::SampledTexture>()) {
/* nothing to emit */
} else if (texture->Is<ast::MultisampledTexture>()) {
out << "multisampled_";
} else if (texture->Is<ast::StorageTexture>()) {
out << "storage_";
} else {
diagnostics_.add_error(diag::System::Writer, "unknown texture type");
return false;
}
if (sampler->IsComparison()) {
out << "_comparison";
}
return true;
},
[&](const ast::ExternalTexture*) {
out << "texture_external";
return true;
},
[&](const ast::Texture* texture) {
out << "texture_";
bool ok = Switch(
texture,
[&](const ast::DepthTexture*) { //
out << "depth_";
return true;
},
[&](const ast::DepthMultisampledTexture*) { //
out << "depth_multisampled_";
return true;
},
[&](const ast::SampledTexture*) { //
/* nothing to emit */
return true;
},
[&](const ast::MultisampledTexture*) { //
out << "multisampled_";
return true;
},
[&](const ast::StorageTexture*) { //
out << "storage_";
return true;
},
[&](Default) { //
diagnostics_.add_error(diag::System::Writer,
"unknown texture type");
return false;
});
if (!ok) {
return false;
}
switch (texture->dim) {
case ast::TextureDimension::k1d:
out << "1d";
break;
case ast::TextureDimension::k2d:
out << "2d";
break;
case ast::TextureDimension::k2dArray:
out << "2d_array";
break;
case ast::TextureDimension::k3d:
out << "3d";
break;
case ast::TextureDimension::kCube:
out << "cube";
break;
case ast::TextureDimension::kCubeArray:
out << "cube_array";
break;
default:
diagnostics_.add_error(diag::System::Writer,
"unknown texture dimension");
return false;
}
switch (texture->dim) {
case ast::TextureDimension::k1d:
out << "1d";
break;
case ast::TextureDimension::k2d:
out << "2d";
break;
case ast::TextureDimension::k2dArray:
out << "2d_array";
break;
case ast::TextureDimension::k3d:
out << "3d";
break;
case ast::TextureDimension::kCube:
out << "cube";
break;
case ast::TextureDimension::kCubeArray:
out << "cube_array";
break;
default:
diagnostics_.add_error(diag::System::Writer,
"unknown texture dimension");
return false;
}
if (auto* sampled = texture->As<ast::SampledTexture>()) {
out << "<";
if (!EmitType(out, sampled->type)) {
return Switch(
texture,
[&](const ast::SampledTexture* sampled) { //
out << "<";
if (!EmitType(out, sampled->type)) {
return false;
}
out << ">";
return true;
},
[&](const ast::MultisampledTexture* ms) { //
out << "<";
if (!EmitType(out, ms->type)) {
return false;
}
out << ">";
return true;
},
[&](const ast::StorageTexture* storage) { //
out << "<";
if (!EmitImageFormat(out, storage->format)) {
return false;
}
out << ", ";
if (!EmitAccess(out, storage->access)) {
return false;
}
out << ">";
return true;
},
[&](Default) { //
return true;
});
},
[&](const ast::U32*) {
out << "u32";
return true;
},
[&](const ast::Vector* vec) {
out << "vec" << vec->width;
if (auto* el_ty = vec->type) {
out << "<";
if (!EmitType(out, el_ty)) {
return false;
}
out << ">";
}
return true;
},
[&](const ast::Void*) {
out << "void";
return true;
},
[&](const ast::TypeName* tn) {
out << program_->Symbols().NameFor(tn->name);
return true;
},
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
"unknown type in EmitType: " + std::string(ty->TypeInfo().name));
return false;
}
out << ">";
} else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
out << "<";
if (!EmitType(out, ms->type)) {
return false;
}
out << ">";
} else if (auto* storage = texture->As<ast::StorageTexture>()) {
out << "<";
if (!EmitImageFormat(out, storage->format)) {
return false;
}
out << ", ";
if (!EmitAccess(out, storage->access)) {
return false;
}
out << ">";
}
} else if (ty->Is<ast::U32>()) {
out << "u32";
} else if (auto* vec = ty->As<ast::Vector>()) {
out << "vec" << vec->width;
if (auto* el_ty = vec->type) {
out << "<";
if (!EmitType(out, el_ty)) {
return false;
}
out << ">";
}
} else if (ty->Is<ast::Void>()) {
out << "void";
} else if (auto* tn = ty->As<ast::TypeName>()) {
out << program_->Symbols().NameFor(tn->name);
} else {
diagnostics_.add_error(
diag::System::Writer,
"unknown type in EmitType: " + std::string(ty->TypeInfo().name));
return false;
}
return true;
});
}
bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
@ -632,56 +693,90 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
}
first = false;
out << "@";
if (auto* workgroup = attr->As<ast::WorkgroupAttribute>()) {
auto values = workgroup->Values();
out << "workgroup_size(";
for (int i = 0; i < 3; i++) {
if (values[i]) {
if (i > 0) {
out << ", ";
bool ok = Switch(
attr,
[&](const ast::WorkgroupAttribute* workgroup) {
auto values = workgroup->Values();
out << "workgroup_size(";
for (int i = 0; i < 3; i++) {
if (values[i]) {
if (i > 0) {
out << ", ";
}
if (!EmitExpression(out, values[i])) {
return false;
}
}
}
if (!EmitExpression(out, values[i])) {
return false;
out << ")";
return true;
},
[&](const ast::StructBlockAttribute*) { //
out << "block";
return true;
},
[&](const ast::StageAttribute* stage) {
out << "stage(" << stage->stage << ")";
return true;
},
[&](const ast::BindingAttribute* binding) {
out << "binding(" << binding->value << ")";
return true;
},
[&](const ast::GroupAttribute* group) {
out << "group(" << group->value << ")";
return true;
},
[&](const ast::LocationAttribute* location) {
out << "location(" << location->value << ")";
return true;
},
[&](const ast::BuiltinAttribute* builtin) {
out << "builtin(" << builtin->builtin << ")";
return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
out << "interpolate(" << interpolate->type;
if (interpolate->sampling != ast::InterpolationSampling::kNone) {
out << ", " << interpolate->sampling;
}
}
}
out << ")";
} else if (attr->Is<ast::StructBlockAttribute>()) {
out << "block";
} else if (auto* stage = attr->As<ast::StageAttribute>()) {
out << "stage(" << stage->stage << ")";
} else if (auto* binding = attr->As<ast::BindingAttribute>()) {
out << "binding(" << binding->value << ")";
} else if (auto* group = attr->As<ast::GroupAttribute>()) {
out << "group(" << group->value << ")";
} else if (auto* location = attr->As<ast::LocationAttribute>()) {
out << "location(" << location->value << ")";
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
out << "builtin(" << builtin->builtin << ")";
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
out << "interpolate(" << interpolate->type;
if (interpolate->sampling != ast::InterpolationSampling::kNone) {
out << ", " << interpolate->sampling;
}
out << ")";
} else if (attr->Is<ast::InvariantAttribute>()) {
out << "invariant";
} else if (auto* override_attr = attr->As<ast::OverrideAttribute>()) {
out << "override";
if (override_attr->has_value) {
out << "(" << override_attr->value << ")";
}
} else if (auto* size = attr->As<ast::StructMemberSizeAttribute>()) {
out << "size(" << size->size << ")";
} else if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) {
out << "align(" << align->align << ")";
} else if (auto* stride = attr->As<ast::StrideAttribute>()) {
out << "stride(" << stride->stride << ")";
} else if (auto* internal = attr->As<ast::InternalAttribute>()) {
out << "internal(" << internal->InternalName() << ")";
} else {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported attribute '" << attr->TypeInfo().name << "'";
out << ")";
return true;
},
[&](const ast::InvariantAttribute*) {
out << "invariant";
return true;
},
[&](const ast::OverrideAttribute* override_deco) {
out << "override";
if (override_deco->has_value) {
out << "(" << override_deco->value << ")";
}
return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
out << "size(" << size->size << ")";
return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
out << "align(" << align->align << ")";
return true;
},
[&](const ast::StrideAttribute* stride) {
out << "stride(" << stride->stride << ")";
return true;
},
[&](const ast::InternalAttribute* internal) {
out << "internal(" << internal->InternalName() << ")";
return true;
},
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "Unsupported attribute '" << attr->TypeInfo().name << "'";
return false;
});
if (!ok) {
return false;
}
}
@ -809,55 +904,36 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
}
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return EmitAssign(a);
}
if (auto* b = stmt->As<ast::BlockStatement>()) {
return EmitBlock(b);
}
if (auto* b = stmt->As<ast::BreakStatement>()) {
return EmitBreak(b);
}
if (auto* c = stmt->As<ast::CallStatement>()) {
auto out = line();
if (!EmitCall(out, c->expr)) {
return false;
}
out << ";";
return true;
}
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return EmitContinue(c);
}
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return EmitDiscard(d);
}
if (auto* f = stmt->As<ast::FallthroughStatement>()) {
return EmitFallthrough(f);
}
if (auto* i = stmt->As<ast::IfStatement>()) {
return EmitIf(i);
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
return EmitLoop(l);
}
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
return EmitForLoop(l);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(s);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(line(), v->variable);
}
diagnostics_.add_error(
diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name));
return false;
return Switch(
stmt, //
[&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
[&](const ast::BlockStatement* b) { return EmitBlock(b); },
[&](const ast::BreakStatement* b) { return EmitBreak(b); },
[&](const ast::CallStatement* c) {
auto out = line();
if (!EmitCall(out, c->expr)) {
return false;
}
out << ";";
return true;
},
[&](const ast::ContinueStatement* c) { return EmitContinue(c); },
[&](const ast::DiscardStatement* d) { return EmitDiscard(d); },
[&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); },
[&](const ast::IfStatement* i) { return EmitIf(i); },
[&](const ast::LoopStatement* l) { return EmitLoop(l); },
[&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
[&](const ast::ReturnStatement* r) { return EmitReturn(r); },
[&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
[&](const ast::VariableDeclStatement* v) {
return EmitVariable(line(), v->variable);
},
[&](Default) {
diagnostics_.add_error(
diag::System::Writer,
"unknown statement type: " + std::string(stmt->TypeInfo().name));
return false;
});
}
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {