Add tint::Switch()
A type dispatch helper with replaces chains of: if (auto* a = obj->As<A>()) { ... } else if (auto* b = obj->As<B>()) { ... } else { ... } with: Switch(obj, [&](A* a) { ... }, [&](B* b) { ... }, [&](Default) { ... }); This new helper provides greater opportunities for optimizations, avoids scoping issues with if-else blocks, and is slightly cleaner (IMO). Bug: tint:1383 Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78543 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
fa0d64b76d
commit
de857e1c58
|
@ -1160,6 +1160,7 @@ if(TINT_BUILD_BENCHMARKS)
|
|||
endif()
|
||||
|
||||
set(TINT_BENCHMARK_SRC
|
||||
"castable_bench.cc"
|
||||
"bench/benchmark.cc"
|
||||
"reader/wgsl/parser_bench.cc"
|
||||
)
|
||||
|
|
|
@ -35,16 +35,15 @@ Module::Module(ProgramID pid,
|
|||
continue;
|
||||
}
|
||||
|
||||
if (auto* ty = decl->As<ast::TypeDecl>()) {
|
||||
type_decls_.push_back(ty);
|
||||
} else if (auto* func = decl->As<Function>()) {
|
||||
functions_.push_back(func);
|
||||
} else if (auto* var = decl->As<Variable>()) {
|
||||
global_variables_.push_back(var);
|
||||
} else {
|
||||
Switch(
|
||||
decl, //
|
||||
[&](const ast::TypeDecl* type) { type_decls_.push_back(type); },
|
||||
[&](const Function* func) { functions_.push_back(func); },
|
||||
[&](const Variable* var) { global_variables_.push_back(var); },
|
||||
[&](Default) {
|
||||
diag::List diagnostics;
|
||||
TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,19 +100,24 @@ void Module::Copy(CloneContext* ctx, const Module* src) {
|
|||
<< "src global declaration was nullptr";
|
||||
continue;
|
||||
}
|
||||
if (auto* type = decl->As<ast::TypeDecl>()) {
|
||||
Switch(
|
||||
decl,
|
||||
[&](const ast::TypeDecl* type) {
|
||||
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
|
||||
type_decls_.push_back(type);
|
||||
} else if (auto* func = decl->As<Function>()) {
|
||||
},
|
||||
[&](const Function* func) {
|
||||
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
|
||||
functions_.push_back(func);
|
||||
} else if (auto* var = decl->As<Variable>()) {
|
||||
},
|
||||
[&](const Variable* var) {
|
||||
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
|
||||
global_variables_.push_back(var);
|
||||
} else {
|
||||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(AST, ctx->dst->Diagnostics())
|
||||
<< "Unknown global declaration type";
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -101,30 +101,47 @@ bool TraverseExpressions(const ast::Expression* root,
|
|||
}
|
||||
}
|
||||
|
||||
if (auto* idx = expr->As<IndexAccessorExpression>()) {
|
||||
bool ok = Switch(
|
||||
expr,
|
||||
[&](const IndexAccessorExpression* idx) {
|
||||
push_pair(idx->object, idx->index);
|
||||
} else if (auto* bin_op = expr->As<BinaryExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const BinaryExpression* bin_op) {
|
||||
push_pair(bin_op->lhs, bin_op->rhs);
|
||||
} else if (auto* bitcast = expr->As<BitcastExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const BitcastExpression* bitcast) {
|
||||
to_visit.push_back(bitcast->expr);
|
||||
} else if (auto* call = expr->As<CallExpression>()) {
|
||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
|
||||
// function name in the traversal.
|
||||
// to_visit.push_back(call->func);
|
||||
return true;
|
||||
},
|
||||
[&](const CallExpression* call) {
|
||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
||||
// the function name in the traversal. to_visit.push_back(call->func);
|
||||
push_list(call->args);
|
||||
} else if (auto* member = expr->As<MemberAccessorExpression>()) {
|
||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
|
||||
// member name in the traversal.
|
||||
// push_pair(member->structure, member->member);
|
||||
return true;
|
||||
},
|
||||
[&](const MemberAccessorExpression* member) {
|
||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
||||
// the member name in the traversal. push_pair(member->structure,
|
||||
// member->member);
|
||||
to_visit.push_back(member->structure);
|
||||
} else if (auto* unary = expr->As<UnaryOpExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const UnaryOpExpression* unary) {
|
||||
to_visit.push_back(unary->expr);
|
||||
} else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
|
||||
PhonyExpression>()) {
|
||||
// Leaf expression
|
||||
} else {
|
||||
TINT_ICE(AST, diags) << "unhandled expression type: "
|
||||
<< expr->TypeInfo().name;
|
||||
return true; // Leaf expression
|
||||
}
|
||||
TINT_ICE(AST, diags)
|
||||
<< "unhandled expression type: " << expr->TypeInfo().name;
|
||||
return false;
|
||||
});
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -453,6 +453,105 @@ class Castable : public BASE {
|
|||
}
|
||||
};
|
||||
|
||||
/// Default can be used as the default case for a Switch(), when all previous
|
||||
/// cases failed to match.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// Switch(object,
|
||||
/// [&](TypeA*) { /* ... */ },
|
||||
/// [&](TypeB*) { /* ... */ },
|
||||
/// [&](Default) { /* If not TypeA or TypeB */ });
|
||||
/// ```
|
||||
struct Default {};
|
||||
|
||||
/// Switch is used to dispatch one of the provided callback case handler
|
||||
/// functions based on the type of `object` and the parameter type of the case
|
||||
/// handlers. Switch will sequentially check the type of `object` against each
|
||||
/// of the switch case handler functions, and will invoke the first case handler
|
||||
/// function which has a parameter type that matches the object type. When a
|
||||
/// case handler is matched, it will be called with the single argument of
|
||||
/// `object` cast to the case handler's parameter type. Switch will invoke at
|
||||
/// most one case handler. Each of the case functions must have the signature
|
||||
/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
|
||||
/// is the return type, consistent across all case handlers.
|
||||
///
|
||||
/// An optional default case function with the signature `R(Default)` can be
|
||||
/// used as the last case. This default case will be called if all previous
|
||||
/// cases failed to match.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// Switch(object,
|
||||
/// [&](TypeA*) { /* ... */ },
|
||||
/// [&](TypeB*) { /* ... */ });
|
||||
///
|
||||
/// Switch(object,
|
||||
/// [&](TypeA*) { /* ... */ },
|
||||
/// [&](TypeB*) { /* ... */ },
|
||||
/// [&](Default) { /* Called if object is not TypeA or TypeB */ });
|
||||
/// ```
|
||||
///
|
||||
/// @param object the object who's type is used to
|
||||
/// @param first_case the first switch case
|
||||
/// @param other_cases additional switch cases (optional)
|
||||
/// @return the value returned by the called case. If no cases matched, then the
|
||||
/// zero value for the consistent case type.
|
||||
template <typename T, typename FIRST_CASE, typename... OTHER_CASES>
|
||||
traits::ReturnType<FIRST_CASE> //
|
||||
Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) {
|
||||
using ReturnType = traits::ReturnType<FIRST_CASE>;
|
||||
using CaseType = std::remove_pointer_t<traits::ParameterType<FIRST_CASE, 0>>;
|
||||
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
|
||||
static_assert(traits::SignatureOfT<FIRST_CASE>::parameter_count == 1,
|
||||
"Switch case must have a single parameter");
|
||||
if constexpr (std::is_same_v<CaseType, Default>) {
|
||||
// Default case. Must be last.
|
||||
(void)object; // 'object' is not used by the Default case.
|
||||
static_assert(sizeof...(other_cases) == 0,
|
||||
"Switch Default case must come last");
|
||||
if constexpr (kHasReturnType) {
|
||||
return first_case({});
|
||||
} else {
|
||||
first_case({});
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// Regular case.
|
||||
static_assert(traits::IsTypeOrDerived<CaseType, CastableBase>::value,
|
||||
"Switch case parameter is not a Castable pointer");
|
||||
// Does the case match?
|
||||
if (auto* ptr = As<CaseType>(object)) {
|
||||
if constexpr (kHasReturnType) {
|
||||
return first_case(ptr);
|
||||
} else {
|
||||
first_case(ptr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Case did not match. Got any more cases to try?
|
||||
if constexpr (sizeof...(other_cases) > 0) {
|
||||
// Try the next cases...
|
||||
if constexpr (kHasReturnType) {
|
||||
auto res = Switch(object, std::forward<OTHER_CASES>(other_cases)...);
|
||||
static_assert(std::is_same_v<decltype(res), ReturnType>,
|
||||
"Switch case types do not have consistent return type");
|
||||
return res;
|
||||
} else {
|
||||
Switch(object, std::forward<OTHER_CASES>(other_cases)...);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// That was the last case. No cases matched.
|
||||
if constexpr (kHasReturnType) {
|
||||
return {};
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tint
|
||||
|
||||
TINT_CASTABLE_POP_DISABLE_WARNINGS();
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "bench/benchmark.h"
|
||||
|
||||
namespace tint {
|
||||
namespace {
|
||||
|
||||
struct Base : public tint::Castable<Base> {};
|
||||
struct A : public tint::Castable<A, Base> {};
|
||||
struct AA : public tint::Castable<AA, A> {};
|
||||
struct AAA : public tint::Castable<AAA, AA> {};
|
||||
struct AAB : public tint::Castable<AAB, AA> {};
|
||||
struct AAC : public tint::Castable<AAC, AA> {};
|
||||
struct AB : public tint::Castable<AB, A> {};
|
||||
struct ABA : public tint::Castable<ABA, AB> {};
|
||||
struct ABB : public tint::Castable<ABB, AB> {};
|
||||
struct ABC : public tint::Castable<ABC, AB> {};
|
||||
struct AC : public tint::Castable<AC, A> {};
|
||||
struct ACA : public tint::Castable<ACA, AC> {};
|
||||
struct ACB : public tint::Castable<ACB, AC> {};
|
||||
struct ACC : public tint::Castable<ACC, AC> {};
|
||||
struct B : public tint::Castable<B, Base> {};
|
||||
struct BA : public tint::Castable<BA, B> {};
|
||||
struct BAA : public tint::Castable<BAA, BA> {};
|
||||
struct BAB : public tint::Castable<BAB, BA> {};
|
||||
struct BAC : public tint::Castable<BAC, BA> {};
|
||||
struct BB : public tint::Castable<BB, B> {};
|
||||
struct BBA : public tint::Castable<BBA, BB> {};
|
||||
struct BBB : public tint::Castable<BBB, BB> {};
|
||||
struct BBC : public tint::Castable<BBC, BB> {};
|
||||
struct BC : public tint::Castable<BC, B> {};
|
||||
struct BCA : public tint::Castable<BCA, BC> {};
|
||||
struct BCB : public tint::Castable<BCB, BC> {};
|
||||
struct BCC : public tint::Castable<BCC, BC> {};
|
||||
struct C : public tint::Castable<C, Base> {};
|
||||
struct CA : public tint::Castable<CA, C> {};
|
||||
struct CAA : public tint::Castable<CAA, CA> {};
|
||||
struct CAB : public tint::Castable<CAB, CA> {};
|
||||
struct CAC : public tint::Castable<CAC, CA> {};
|
||||
struct CB : public tint::Castable<CB, C> {};
|
||||
struct CBA : public tint::Castable<CBA, CB> {};
|
||||
struct CBB : public tint::Castable<CBB, CB> {};
|
||||
struct CBC : public tint::Castable<CBC, CB> {};
|
||||
struct CC : public tint::Castable<CC, C> {};
|
||||
struct CCA : public tint::Castable<CCA, CC> {};
|
||||
struct CCB : public tint::Castable<CCB, CC> {};
|
||||
struct CCC : public tint::Castable<CCC, CC> {};
|
||||
|
||||
using AllTypes = std::tuple<Base,
|
||||
A,
|
||||
AA,
|
||||
AAA,
|
||||
AAB,
|
||||
AAC,
|
||||
AB,
|
||||
ABA,
|
||||
ABB,
|
||||
ABC,
|
||||
AC,
|
||||
ACA,
|
||||
ACB,
|
||||
ACC,
|
||||
B,
|
||||
BA,
|
||||
BAA,
|
||||
BAB,
|
||||
BAC,
|
||||
BB,
|
||||
BBA,
|
||||
BBB,
|
||||
BBC,
|
||||
BC,
|
||||
BCA,
|
||||
BCB,
|
||||
BCC,
|
||||
C,
|
||||
CA,
|
||||
CAA,
|
||||
CAB,
|
||||
CAC,
|
||||
CB,
|
||||
CBA,
|
||||
CBB,
|
||||
CBC,
|
||||
CC,
|
||||
CCA,
|
||||
CCB,
|
||||
CCC>;
|
||||
|
||||
std::vector<std::unique_ptr<Base>> MakeObjects() {
|
||||
std::vector<std::unique_ptr<Base>> out;
|
||||
out.emplace_back(std::make_unique<Base>());
|
||||
out.emplace_back(std::make_unique<A>());
|
||||
out.emplace_back(std::make_unique<AA>());
|
||||
out.emplace_back(std::make_unique<AAA>());
|
||||
out.emplace_back(std::make_unique<AAB>());
|
||||
out.emplace_back(std::make_unique<AAC>());
|
||||
out.emplace_back(std::make_unique<AB>());
|
||||
out.emplace_back(std::make_unique<ABA>());
|
||||
out.emplace_back(std::make_unique<ABB>());
|
||||
out.emplace_back(std::make_unique<ABC>());
|
||||
out.emplace_back(std::make_unique<AC>());
|
||||
out.emplace_back(std::make_unique<ACA>());
|
||||
out.emplace_back(std::make_unique<ACB>());
|
||||
out.emplace_back(std::make_unique<ACC>());
|
||||
out.emplace_back(std::make_unique<B>());
|
||||
out.emplace_back(std::make_unique<BA>());
|
||||
out.emplace_back(std::make_unique<BAA>());
|
||||
out.emplace_back(std::make_unique<BAB>());
|
||||
out.emplace_back(std::make_unique<BAC>());
|
||||
out.emplace_back(std::make_unique<BB>());
|
||||
out.emplace_back(std::make_unique<BBA>());
|
||||
out.emplace_back(std::make_unique<BBB>());
|
||||
out.emplace_back(std::make_unique<BBC>());
|
||||
out.emplace_back(std::make_unique<BC>());
|
||||
out.emplace_back(std::make_unique<BCA>());
|
||||
out.emplace_back(std::make_unique<BCB>());
|
||||
out.emplace_back(std::make_unique<BCC>());
|
||||
out.emplace_back(std::make_unique<C>());
|
||||
out.emplace_back(std::make_unique<CA>());
|
||||
out.emplace_back(std::make_unique<CAA>());
|
||||
out.emplace_back(std::make_unique<CAB>());
|
||||
out.emplace_back(std::make_unique<CAC>());
|
||||
out.emplace_back(std::make_unique<CB>());
|
||||
out.emplace_back(std::make_unique<CBA>());
|
||||
out.emplace_back(std::make_unique<CBB>());
|
||||
out.emplace_back(std::make_unique<CBC>());
|
||||
out.emplace_back(std::make_unique<CC>());
|
||||
out.emplace_back(std::make_unique<CCA>());
|
||||
out.emplace_back(std::make_unique<CCB>());
|
||||
out.emplace_back(std::make_unique<CCC>());
|
||||
return out;
|
||||
}
|
||||
|
||||
void CastableLargeSwitch(::benchmark::State& state) {
|
||||
auto objects = MakeObjects();
|
||||
size_t i = 0;
|
||||
for (auto _ : state) {
|
||||
auto* object = objects[i % objects.size()].get();
|
||||
Switch(
|
||||
object, //
|
||||
[&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); },
|
||||
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); },
|
||||
[&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); },
|
||||
[&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); },
|
||||
[&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); },
|
||||
[&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); },
|
||||
[&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); },
|
||||
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
|
||||
[&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); },
|
||||
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
|
||||
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
|
||||
[&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); },
|
||||
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
|
||||
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
|
||||
[&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); },
|
||||
[&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); },
|
||||
[&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); },
|
||||
[&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); },
|
||||
[&](const CA*) { ::benchmark::DoNotOptimize(i += 290); },
|
||||
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
|
||||
[&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); },
|
||||
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); },
|
||||
[&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); },
|
||||
[&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); },
|
||||
[&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); },
|
||||
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
|
||||
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
|
||||
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
|
||||
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
|
||||
i = (i * 31) ^ (i << 5);
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(CastableLargeSwitch);
|
||||
|
||||
void CastableMediumSwitch(::benchmark::State& state) {
|
||||
auto objects = MakeObjects();
|
||||
size_t i = 0;
|
||||
for (auto _ : state) {
|
||||
auto* object = objects[i % objects.size()].get();
|
||||
Switch(
|
||||
object, //
|
||||
[&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
|
||||
[&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
|
||||
[&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
|
||||
[&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
|
||||
[&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
|
||||
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
|
||||
[&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
|
||||
[&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
|
||||
[&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
|
||||
[&](Default) { ::benchmark::DoNotOptimize(i += 123); });
|
||||
i = (i * 31) ^ (i << 5);
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(CastableMediumSwitch);
|
||||
|
||||
void CastableSmallSwitch(::benchmark::State& state) {
|
||||
auto objects = MakeObjects();
|
||||
size_t i = 0;
|
||||
for (auto _ : state) {
|
||||
auto* object = objects[i % objects.size()].get();
|
||||
Switch(
|
||||
object, //
|
||||
[&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); },
|
||||
[&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); },
|
||||
[&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); });
|
||||
i = (i * 31) ^ (i << 5);
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(CastableSmallSwitch);
|
||||
|
||||
} // namespace
|
||||
} // namespace tint
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::Base);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::A);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::AA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::AAA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::AAB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::AAC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::AB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::ABA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::ABB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::ABC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::AC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::ACA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::ACB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::ACC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::B);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BAA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BAB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BAC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BBA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BBB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BBC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BCA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BCB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::BCC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::C);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CAA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CAB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CAC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CBA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CBB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CBC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CC);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CCA);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CCB);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::CCC);
|
|
@ -252,6 +252,151 @@ TEST(Castable, As) {
|
|||
ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
|
||||
}
|
||||
|
||||
TEST(Castable, SwitchNoDefault) {
|
||||
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
|
||||
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
|
||||
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
|
||||
{
|
||||
bool frog_matched_amphibian = false;
|
||||
Switch(
|
||||
frog.get(), //
|
||||
[&](Reptile*) { FAIL() << "frog is not reptile"; },
|
||||
[&](Mammal*) { FAIL() << "frog is not mammal"; },
|
||||
[&](Amphibian* amphibian) {
|
||||
EXPECT_EQ(amphibian, frog.get());
|
||||
frog_matched_amphibian = true;
|
||||
});
|
||||
EXPECT_TRUE(frog_matched_amphibian);
|
||||
}
|
||||
{
|
||||
bool bear_matched_mammal = false;
|
||||
Switch(
|
||||
bear.get(), //
|
||||
[&](Reptile*) { FAIL() << "bear is not reptile"; },
|
||||
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
|
||||
[&](Mammal* mammal) {
|
||||
EXPECT_EQ(mammal, bear.get());
|
||||
bear_matched_mammal = true;
|
||||
});
|
||||
EXPECT_TRUE(bear_matched_mammal);
|
||||
}
|
||||
{
|
||||
bool gecko_matched_reptile = false;
|
||||
Switch(
|
||||
gecko.get(), //
|
||||
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
|
||||
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
|
||||
[&](Reptile* reptile) {
|
||||
EXPECT_EQ(reptile, gecko.get());
|
||||
gecko_matched_reptile = true;
|
||||
});
|
||||
EXPECT_TRUE(gecko_matched_reptile);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Castable, SwitchWithUnusedDefault) {
|
||||
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
|
||||
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
|
||||
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
|
||||
{
|
||||
bool frog_matched_amphibian = false;
|
||||
Switch(
|
||||
frog.get(), //
|
||||
[&](Reptile*) { FAIL() << "frog is not reptile"; },
|
||||
[&](Mammal*) { FAIL() << "frog is not mammal"; },
|
||||
[&](Amphibian* amphibian) {
|
||||
EXPECT_EQ(amphibian, frog.get());
|
||||
frog_matched_amphibian = true;
|
||||
},
|
||||
[&](Default) { FAIL() << "default should not have been selected"; });
|
||||
EXPECT_TRUE(frog_matched_amphibian);
|
||||
}
|
||||
{
|
||||
bool bear_matched_mammal = false;
|
||||
Switch(
|
||||
bear.get(), //
|
||||
[&](Reptile*) { FAIL() << "bear is not reptile"; },
|
||||
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
|
||||
[&](Mammal* mammal) {
|
||||
EXPECT_EQ(mammal, bear.get());
|
||||
bear_matched_mammal = true;
|
||||
},
|
||||
[&](Default) { FAIL() << "default should not have been selected"; });
|
||||
EXPECT_TRUE(bear_matched_mammal);
|
||||
}
|
||||
{
|
||||
bool gecko_matched_reptile = false;
|
||||
Switch(
|
||||
gecko.get(), //
|
||||
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
|
||||
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
|
||||
[&](Reptile* reptile) {
|
||||
EXPECT_EQ(reptile, gecko.get());
|
||||
gecko_matched_reptile = true;
|
||||
},
|
||||
[&](Default) { FAIL() << "default should not have been selected"; });
|
||||
EXPECT_TRUE(gecko_matched_reptile);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Castable, SwitchDefault) {
|
||||
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
|
||||
std::unique_ptr<Animal> bear = std::make_unique<Bear>();
|
||||
std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
|
||||
{
|
||||
bool frog_matched_default = false;
|
||||
Switch(
|
||||
frog.get(), //
|
||||
[&](Reptile*) { FAIL() << "frog is not reptile"; },
|
||||
[&](Mammal*) { FAIL() << "frog is not mammal"; },
|
||||
[&](Default) { frog_matched_default = true; });
|
||||
EXPECT_TRUE(frog_matched_default);
|
||||
}
|
||||
{
|
||||
bool bear_matched_default = false;
|
||||
Switch(
|
||||
bear.get(), //
|
||||
[&](Reptile*) { FAIL() << "bear is not reptile"; },
|
||||
[&](Amphibian*) { FAIL() << "bear is not amphibian"; },
|
||||
[&](Default) { bear_matched_default = true; });
|
||||
EXPECT_TRUE(bear_matched_default);
|
||||
}
|
||||
{
|
||||
bool gecko_matched_default = false;
|
||||
Switch(
|
||||
gecko.get(), //
|
||||
[&](Mammal*) { FAIL() << "gecko is not mammal"; },
|
||||
[&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
|
||||
[&](Default) { gecko_matched_default = true; });
|
||||
EXPECT_TRUE(gecko_matched_default);
|
||||
}
|
||||
}
|
||||
TEST(Castable, SwitchMatchFirst) {
|
||||
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
|
||||
{
|
||||
bool frog_matched_animal = false;
|
||||
Switch(
|
||||
frog.get(),
|
||||
[&](Animal* animal) {
|
||||
EXPECT_EQ(animal, frog.get());
|
||||
frog_matched_animal = true;
|
||||
},
|
||||
[&](Amphibian*) { FAIL() << "animal should have been matched first"; });
|
||||
EXPECT_TRUE(frog_matched_animal);
|
||||
}
|
||||
{
|
||||
bool frog_matched_amphibian = false;
|
||||
Switch(
|
||||
frog.get(),
|
||||
[&](Amphibian* amphibain) {
|
||||
EXPECT_EQ(amphibain, frog.get());
|
||||
frog_matched_amphibian = true;
|
||||
},
|
||||
[&](Animal*) { FAIL() << "amphibian should have been matched first"; });
|
||||
EXPECT_TRUE(frog_matched_amphibian);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(Animal);
|
||||
|
|
|
@ -953,7 +953,7 @@ const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
|
|||
|
||||
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
|
||||
const Type* var_type,
|
||||
ast::AttributeList* decos,
|
||||
ast::AttributeList* attrs,
|
||||
std::vector<int> index_prefix,
|
||||
const Type* tip_type,
|
||||
const Type* forced_param_type,
|
||||
|
@ -966,19 +966,23 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name,
|
|||
}
|
||||
|
||||
// Recursively flatten matrices, arrays, and structures.
|
||||
if (auto* matrix_type = tip_type->As<Matrix>()) {
|
||||
return Switch(
|
||||
tip_type,
|
||||
[&](const Matrix* matrix_type) -> bool {
|
||||
index_prefix.push_back(0);
|
||||
const auto num_columns = static_cast<int>(matrix_type->columns);
|
||||
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
|
||||
for (int col = 0; col < num_columns; col++) {
|
||||
index_prefix.back() = col;
|
||||
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty,
|
||||
forced_param_type, params, statements)) {
|
||||
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
|
||||
vec_ty, forced_param_type, params,
|
||||
statements)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
} else if (auto* array_type = tip_type->As<Array>()) {
|
||||
},
|
||||
[&](const Array* array_type) -> bool {
|
||||
if (array_type->size == 0) {
|
||||
return Fail() << "runtime-size array not allowed on pipeline IO";
|
||||
}
|
||||
|
@ -986,85 +990,97 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name,
|
|||
const Type* elem_ty = array_type->type;
|
||||
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
|
||||
index_prefix.back() = i;
|
||||
if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty,
|
||||
forced_param_type, params, statements)) {
|
||||
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
|
||||
elem_ty, forced_param_type, params,
|
||||
statements)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
} else if (auto* struct_type = tip_type->As<Struct>()) {
|
||||
},
|
||||
[&](const Struct* struct_type) -> bool {
|
||||
const auto& members = struct_type->members;
|
||||
index_prefix.push_back(0);
|
||||
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
|
||||
index_prefix.back() = i;
|
||||
ast::AttributeList member_decos(*decos);
|
||||
ast::AttributeList member_attrs(*attrs);
|
||||
if (!parser_impl_.ConvertPipelineDecorations(
|
||||
struct_type,
|
||||
parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
|
||||
&member_decos)) {
|
||||
&member_attrs)) {
|
||||
return false;
|
||||
}
|
||||
if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix,
|
||||
members[i], forced_param_type, params,
|
||||
statements)) {
|
||||
if (!EmitPipelineInput(var_name, var_type, &member_attrs,
|
||||
index_prefix, members[i], forced_param_type,
|
||||
params, statements)) {
|
||||
return false;
|
||||
}
|
||||
// Copy the location as updated by nested expansion of the member.
|
||||
parser_impl_.SetLocation(decos, GetLocation(member_decos));
|
||||
parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
|
||||
},
|
||||
[&](Default) {
|
||||
const bool is_builtin =
|
||||
ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
|
||||
|
||||
const Type* param_type = is_builtin ? forced_param_type : tip_type;
|
||||
|
||||
const auto param_name = namer_.MakeDerivedName(var_name + "_param");
|
||||
// Create the parameter.
|
||||
// TODO(dneto): Note: If the parameter has non-location decorations,
|
||||
// then those decoration AST nodes will be reused between multiple elements
|
||||
// of a matrix, array, or structure. Normally that's disallowed but currently
|
||||
// the SPIR-V reader will make duplicates when the entire AST is cloned
|
||||
// at the top level of the SPIR-V reader flow. Consider rewriting this
|
||||
// to avoid this node-sharing.
|
||||
// then those decoration AST nodes will be reused between multiple
|
||||
// elements of a matrix, array, or structure. Normally that's
|
||||
// disallowed but currently the SPIR-V reader will make duplicates when
|
||||
// the entire AST is cloned at the top level of the SPIR-V reader flow.
|
||||
// Consider rewriting this to avoid this node-sharing.
|
||||
params->push_back(
|
||||
builder_.Param(param_name, param_type->Build(builder_), *decos));
|
||||
builder_.Param(param_name, param_type->Build(builder_), *attrs));
|
||||
|
||||
// Add a body statement to copy the parameter to the corresponding private
|
||||
// variable.
|
||||
// Add a body statement to copy the parameter to the corresponding
|
||||
// private variable.
|
||||
const ast::Expression* param_value = builder_.Expr(param_name);
|
||||
const ast::Expression* store_dest = builder_.Expr(var_name);
|
||||
|
||||
// Index into the LHS as needed.
|
||||
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
|
||||
auto* current_type =
|
||||
var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
|
||||
for (auto index : index_prefix) {
|
||||
if (auto* matrix_type = current_type->As<Matrix>()) {
|
||||
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
|
||||
Switch(
|
||||
current_type,
|
||||
[&](const Matrix* matrix_type) {
|
||||
store_dest =
|
||||
builder_.IndexAccessor(store_dest, builder_.Expr(index));
|
||||
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
|
||||
} else if (auto* array_type = current_type->As<Array>()) {
|
||||
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
|
||||
},
|
||||
[&](const Array* array_type) {
|
||||
store_dest =
|
||||
builder_.IndexAccessor(store_dest, builder_.Expr(index));
|
||||
current_type = array_type->type->UnwrapAlias();
|
||||
} else if (auto* struct_type = current_type->As<Struct>()) {
|
||||
},
|
||||
[&](const Struct* struct_type) {
|
||||
store_dest = builder_.MemberAccessor(
|
||||
store_dest,
|
||||
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
|
||||
store_dest, builder_.Expr(parser_impl_.GetMemberName(
|
||||
*struct_type, index)));
|
||||
current_type = struct_type->members[index];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (is_builtin && (tip_type != forced_param_type)) {
|
||||
// The parameter will have the WGSL type, but we need bitcast to
|
||||
// the variable store type.
|
||||
param_value =
|
||||
create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
|
||||
param_value = create<ast::BitcastExpression>(
|
||||
tip_type->Build(builder_), param_value);
|
||||
}
|
||||
|
||||
statements->push_back(builder_.Assign(store_dest, param_value));
|
||||
|
||||
// Increment the location attribute, in case more parameters will follow.
|
||||
IncrementLocation(decos);
|
||||
// Increment the location attribute, in case more parameters will
|
||||
// follow.
|
||||
IncrementLocation(attrs);
|
||||
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
|
||||
|
@ -1102,20 +1118,23 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
|
|||
}
|
||||
|
||||
// Recursively flatten matrices, arrays, and structures.
|
||||
if (auto* matrix_type = tip_type->As<Matrix>()) {
|
||||
return Switch(
|
||||
tip_type,
|
||||
[&](const Matrix* matrix_type) -> bool {
|
||||
index_prefix.push_back(0);
|
||||
const auto num_columns = static_cast<int>(matrix_type->columns);
|
||||
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
|
||||
for (int col = 0; col < num_columns; col++) {
|
||||
index_prefix.back() = col;
|
||||
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
|
||||
forced_member_type, return_members,
|
||||
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
|
||||
vec_ty, forced_member_type, return_members,
|
||||
return_exprs)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
} else if (auto* array_type = tip_type->As<Array>()) {
|
||||
},
|
||||
[&](const Array* array_type) -> bool {
|
||||
if (array_type->size == 0) {
|
||||
return Fail() << "runtime-size array not allowed on pipeline IO";
|
||||
}
|
||||
|
@ -1123,37 +1142,39 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
|
|||
const Type* elem_ty = array_type->type;
|
||||
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
|
||||
index_prefix.back() = i;
|
||||
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
|
||||
forced_member_type, return_members,
|
||||
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
|
||||
elem_ty, forced_member_type, return_members,
|
||||
return_exprs)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
} else if (auto* struct_type = tip_type->As<Struct>()) {
|
||||
},
|
||||
[&](const Struct* struct_type) -> bool {
|
||||
const auto& members = struct_type->members;
|
||||
index_prefix.push_back(0);
|
||||
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
|
||||
index_prefix.back() = i;
|
||||
ast::AttributeList member_decos(*decos);
|
||||
ast::AttributeList member_attrs(*decos);
|
||||
if (!parser_impl_.ConvertPipelineDecorations(
|
||||
struct_type,
|
||||
parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
|
||||
&member_decos)) {
|
||||
&member_attrs)) {
|
||||
return false;
|
||||
}
|
||||
if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix,
|
||||
members[i], forced_member_type, return_members,
|
||||
return_exprs)) {
|
||||
if (!EmitPipelineOutput(var_name, var_type, &member_attrs,
|
||||
index_prefix, members[i], forced_member_type,
|
||||
return_members, return_exprs)) {
|
||||
return false;
|
||||
}
|
||||
// Copy the location as updated by nested expansion of the member.
|
||||
parser_impl_.SetLocation(decos, GetLocation(member_decos));
|
||||
parser_impl_.SetLocation(decos, GetLocation(member_attrs));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
|
||||
},
|
||||
[&](Default) {
|
||||
const bool is_builtin =
|
||||
ast::HasAttribute<ast::BuiltinAttribute>(*decos);
|
||||
|
||||
const Type* member_type = is_builtin ? forced_member_type : tip_type;
|
||||
// Derive the member name directly from the variable name. They can't
|
||||
|
@ -1161,11 +1182,11 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
|
|||
const auto member_name = namer_.MakeDerivedName(var_name);
|
||||
// Create the member.
|
||||
// TODO(dneto): Note: If the parameter has non-location decorations,
|
||||
// then those decoration AST nodes will be reused between multiple elements
|
||||
// of a matrix, array, or structure. Normally that's disallowed but currently
|
||||
// the SPIR-V reader will make duplicates when the entire AST is cloned
|
||||
// at the top level of the SPIR-V reader flow. Consider rewriting this
|
||||
// to avoid this node-sharing.
|
||||
// then those decoration AST nodes will be reused between multiple
|
||||
// elements of a matrix, array, or structure. Normally that's
|
||||
// disallowed but currently the SPIR-V reader will make duplicates when
|
||||
// the entire AST is cloned at the top level of the SPIR-V reader flow.
|
||||
// Consider rewriting this to avoid this node-sharing.
|
||||
return_members->push_back(
|
||||
builder_.Member(member_name, member_type->Build(builder_), *decos));
|
||||
|
||||
|
@ -1174,20 +1195,27 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
|
|||
const ast::Expression* load_source = builder_.Expr(var_name);
|
||||
|
||||
// Index into the variable as needed to pick out the flattened member.
|
||||
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
|
||||
auto* current_type =
|
||||
var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
|
||||
for (auto index : index_prefix) {
|
||||
if (auto* matrix_type = current_type->As<Matrix>()) {
|
||||
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
|
||||
Switch(
|
||||
current_type,
|
||||
[&](const Matrix* matrix_type) {
|
||||
load_source =
|
||||
builder_.IndexAccessor(load_source, builder_.Expr(index));
|
||||
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
|
||||
} else if (auto* array_type = current_type->As<Array>()) {
|
||||
load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
|
||||
},
|
||||
[&](const Array* array_type) {
|
||||
load_source =
|
||||
builder_.IndexAccessor(load_source, builder_.Expr(index));
|
||||
current_type = array_type->type->UnwrapAlias();
|
||||
} else if (auto* struct_type = current_type->As<Struct>()) {
|
||||
},
|
||||
[&](const Struct* struct_type) {
|
||||
load_source = builder_.MemberAccessor(
|
||||
load_source,
|
||||
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
|
||||
load_source, builder_.Expr(parser_impl_.GetMemberName(
|
||||
*struct_type, index)));
|
||||
current_type = struct_type->members[index];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (is_builtin && (tip_type != forced_member_type)) {
|
||||
|
@ -1198,10 +1226,12 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
|
|||
}
|
||||
return_exprs->push_back(load_source);
|
||||
|
||||
// Increment the location attribute, in case more parameters will follow.
|
||||
// Increment the location attribute, in case more parameters will
|
||||
// follow.
|
||||
IncrementLocation(decos);
|
||||
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
bool FunctionEmitter::EmitEntryPointAsWrapper() {
|
||||
|
|
|
@ -239,11 +239,12 @@ bool GeneratorImpl::Generate() {
|
|||
}
|
||||
last_kind = kind;
|
||||
|
||||
if (auto* global = decl->As<ast::Variable>()) {
|
||||
if (!EmitGlobalVariable(global)) {
|
||||
return false;
|
||||
}
|
||||
} else if (auto* str = decl->As<ast::Struct>()) {
|
||||
bool ok = Switch(
|
||||
decl,
|
||||
[&](const ast::Variable* global) { //
|
||||
return EmitGlobalVariable(global);
|
||||
},
|
||||
[&](const ast::Struct* str) {
|
||||
auto* ty = builder_.Sem().Get(str);
|
||||
auto storage_class_uses = ty->StorageClassUsage();
|
||||
if (storage_class_uses.size() !=
|
||||
|
@ -253,25 +254,26 @@ bool GeneratorImpl::Generate() {
|
|||
// uniform buffer, so it needs to be emitted.
|
||||
// Storage buffer are read and written to via a ByteAddressBuffer
|
||||
// instead of true structure.
|
||||
// Structures used as uniform buffer are read from an array of vectors
|
||||
// instead of true structure.
|
||||
if (!EmitStructType(current_buffer_, ty)) {
|
||||
return false;
|
||||
// Structures used as uniform buffer are read from an array of
|
||||
// vectors instead of true structure.
|
||||
return EmitStructType(current_buffer_, ty);
|
||||
}
|
||||
}
|
||||
} else if (auto* func = decl->As<ast::Function>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Function* func) {
|
||||
if (func->IsEntryPoint()) {
|
||||
if (!EmitEntryPointFunction(func)) {
|
||||
return false;
|
||||
return EmitEntryPointFunction(func);
|
||||
}
|
||||
} else {
|
||||
if (!EmitFunction(func)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return EmitFunction(func);
|
||||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "unhandled module-scope declaration: " << decl->TypeInfo().name;
|
||||
<< "unhandled module-scope declaration: "
|
||||
<< decl->TypeInfo().name;
|
||||
return false;
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -929,22 +931,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
|
|||
const ast::CallExpression* expr) {
|
||||
auto* call = builder_.Sem().Get(expr);
|
||||
auto* target = call->Target();
|
||||
|
||||
if (auto* func = target->As<sem::Function>()) {
|
||||
return Switch(
|
||||
target,
|
||||
[&](const sem::Function* func) {
|
||||
return EmitFunctionCall(out, call, func);
|
||||
}
|
||||
if (auto* builtin = target->As<sem::Builtin>()) {
|
||||
},
|
||||
[&](const sem::Builtin* builtin) {
|
||||
return EmitBuiltinCall(out, call, builtin);
|
||||
}
|
||||
if (auto* conv = target->As<sem::TypeConversion>()) {
|
||||
},
|
||||
[&](const sem::TypeConversion* conv) {
|
||||
return EmitTypeConversion(out, call, conv);
|
||||
}
|
||||
if (auto* ctor = target->As<sem::TypeConstructor>()) {
|
||||
},
|
||||
[&](const sem::TypeConstructor* ctor) {
|
||||
return EmitTypeConstructor(out, call, ctor);
|
||||
}
|
||||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "unhandled call target: " << target->TypeInfo().name;
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
|
||||
|
@ -2639,35 +2644,38 @@ bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
|
|||
|
||||
bool GeneratorImpl::EmitExpression(std::ostream& out,
|
||||
const ast::Expression* expr) {
|
||||
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
|
||||
return Switch(
|
||||
expr,
|
||||
[&](const ast::IndexAccessorExpression* a) { //
|
||||
return EmitIndexAccessor(out, a);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BinaryExpression>()) {
|
||||
},
|
||||
[&](const ast::BinaryExpression* b) { //
|
||||
return EmitBinary(out, b);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BitcastExpression>()) {
|
||||
},
|
||||
[&](const ast::BitcastExpression* b) { //
|
||||
return EmitBitcast(out, b);
|
||||
}
|
||||
if (auto* c = expr->As<ast::CallExpression>()) {
|
||||
},
|
||||
[&](const ast::CallExpression* c) { //
|
||||
return EmitCall(out, c);
|
||||
}
|
||||
if (auto* i = expr->As<ast::IdentifierExpression>()) {
|
||||
},
|
||||
[&](const ast::IdentifierExpression* i) { //
|
||||
return EmitIdentifier(out, i);
|
||||
}
|
||||
if (auto* l = expr->As<ast::LiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::LiteralExpression* l) { //
|
||||
return EmitLiteral(out, l);
|
||||
}
|
||||
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
|
||||
},
|
||||
[&](const ast::MemberAccessorExpression* m) { //
|
||||
return EmitMemberAccessor(out, m);
|
||||
}
|
||||
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
|
||||
},
|
||||
[&](const ast::UnaryOpExpression* u) { //
|
||||
return EmitUnaryOp(out, u);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) { //
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown expression type: " + std::string(expr->TypeInfo().name));
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitIdentifier(std::ostream& out,
|
||||
|
@ -3127,41 +3135,61 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
|
|||
|
||||
bool GeneratorImpl::EmitLiteral(std::ostream& out,
|
||||
const ast::LiteralExpression* lit) {
|
||||
if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
|
||||
return Switch(
|
||||
lit,
|
||||
[&](const ast::BoolLiteralExpression* l) {
|
||||
out << (l->value ? "true" : "false");
|
||||
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* fl) {
|
||||
if (std::isinf(fl->value)) {
|
||||
out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
|
||||
out << (fl->value >= 0 ? "asfloat(0x7f800000u)"
|
||||
: "asfloat(0xff800000u)");
|
||||
} else if (std::isnan(fl->value)) {
|
||||
out << "asfloat(0x7fc00000u)";
|
||||
} else {
|
||||
out << FloatToString(fl->value) << "f";
|
||||
}
|
||||
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::SintLiteralExpression* sl) {
|
||||
out << sl->value;
|
||||
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::UintLiteralExpression* ul) {
|
||||
out << ul->value << "u";
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitValue(std::ostream& out,
|
||||
const sem::Type* type,
|
||||
int value) {
|
||||
if (type->Is<sem::Bool>()) {
|
||||
return Switch(
|
||||
type,
|
||||
[&](const sem::Bool*) {
|
||||
out << (value == 0 ? "false" : "true");
|
||||
} else if (type->Is<sem::F32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
out << value << ".0f";
|
||||
} else if (type->Is<sem::I32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
out << value;
|
||||
} else if (type->Is<sem::U32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::U32*) {
|
||||
out << value << "u";
|
||||
} else if (auto* vec = type->As<sem::Vector>()) {
|
||||
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
|
||||
"")) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Vector* vec) {
|
||||
if (!EmitType(out, type, ast::StorageClass::kNone,
|
||||
ast::Access::kReadWrite, "")) {
|
||||
return false;
|
||||
}
|
||||
ScopedParen sp(out);
|
||||
|
@ -3173,9 +3201,11 @@ bool GeneratorImpl::EmitValue(std::ostream& out,
|
|||
return false;
|
||||
}
|
||||
}
|
||||
} else if (auto* mat = type->As<sem::Matrix>()) {
|
||||
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
|
||||
"")) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
if (!EmitType(out, type, ast::StorageClass::kNone,
|
||||
ast::Access::kReadWrite, "")) {
|
||||
return false;
|
||||
}
|
||||
ScopedParen sp(out);
|
||||
|
@ -3187,20 +3217,26 @@ bool GeneratorImpl::EmitValue(std::ostream& out,
|
|||
return false;
|
||||
}
|
||||
}
|
||||
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Struct*) {
|
||||
out << "(";
|
||||
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
|
||||
"")) {
|
||||
return false;
|
||||
}
|
||||
out << ")" << value;
|
||||
} else {
|
||||
TINT_DEFER(out << ")" << value);
|
||||
return EmitType(out, type, ast::StorageClass::kNone,
|
||||
ast::Access::kUndefined, "");
|
||||
},
|
||||
[&](const sem::Array*) {
|
||||
out << "(";
|
||||
TINT_DEFER(out << ")" << value);
|
||||
return EmitType(out, type, ast::StorageClass::kNone,
|
||||
ast::Access::kUndefined, "");
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"Invalid type for value emission: " + type->type_name());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
|
||||
|
@ -3375,56 +3411,59 @@ bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
|
||||
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
|
||||
return Switch(
|
||||
stmt,
|
||||
[&](const ast::AssignmentStatement* a) { //
|
||||
return EmitAssign(a);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BlockStatement>()) {
|
||||
},
|
||||
[&](const ast::BlockStatement* b) { //
|
||||
return EmitBlock(b);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BreakStatement>()) {
|
||||
},
|
||||
[&](const ast::BreakStatement* b) { //
|
||||
return EmitBreak(b);
|
||||
}
|
||||
if (auto* c = stmt->As<ast::CallStatement>()) {
|
||||
},
|
||||
[&](const ast::CallStatement* c) { //
|
||||
auto out = line();
|
||||
if (!EmitCall(out, c->expr)) {
|
||||
return false;
|
||||
}
|
||||
out << ";";
|
||||
return true;
|
||||
}
|
||||
if (auto* c = stmt->As<ast::ContinueStatement>()) {
|
||||
},
|
||||
[&](const ast::ContinueStatement* c) { //
|
||||
return EmitContinue(c);
|
||||
}
|
||||
if (auto* d = stmt->As<ast::DiscardStatement>()) {
|
||||
},
|
||||
[&](const ast::DiscardStatement* d) { //
|
||||
return EmitDiscard(d);
|
||||
}
|
||||
if (stmt->As<ast::FallthroughStatement>()) {
|
||||
},
|
||||
[&](const ast::FallthroughStatement*) { //
|
||||
line() << "/* fallthrough */";
|
||||
return true;
|
||||
}
|
||||
if (auto* i = stmt->As<ast::IfStatement>()) {
|
||||
},
|
||||
[&](const ast::IfStatement* i) { //
|
||||
return EmitIf(i);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::LoopStatement>()) {
|
||||
},
|
||||
[&](const ast::LoopStatement* l) { //
|
||||
return EmitLoop(l);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
|
||||
},
|
||||
[&](const ast::ForLoopStatement* l) { //
|
||||
return EmitForLoop(l);
|
||||
}
|
||||
if (auto* r = stmt->As<ast::ReturnStatement>()) {
|
||||
},
|
||||
[&](const ast::ReturnStatement* r) { //
|
||||
return EmitReturn(r);
|
||||
}
|
||||
if (auto* s = stmt->As<ast::SwitchStatement>()) {
|
||||
},
|
||||
[&](const ast::SwitchStatement* s) { //
|
||||
return EmitSwitch(s);
|
||||
}
|
||||
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
|
||||
},
|
||||
[&](const ast::VariableDeclStatement* v) { //
|
||||
return EmitVariable(v->variable);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) { //
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown statement type: " + std::string(stmt->TypeInfo().name));
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
|
||||
|
@ -3516,13 +3555,16 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
break;
|
||||
}
|
||||
|
||||
if (auto* ary = type->As<sem::Array>()) {
|
||||
return Switch(
|
||||
type,
|
||||
[&](const sem::Array* ary) {
|
||||
const sem::Type* base_type = ary;
|
||||
std::vector<uint32_t> sizes;
|
||||
while (auto* arr = base_type->As<sem::Array>()) {
|
||||
if (arr->IsRuntimeSized()) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "Runtime arrays may only exist in storage buffers, which should "
|
||||
<< "Runtime arrays may only exist in storage buffers, which "
|
||||
"should "
|
||||
"have been transformed into a ByteAddressBuffer";
|
||||
return false;
|
||||
}
|
||||
|
@ -3541,38 +3583,53 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
for (uint32_t size : sizes) {
|
||||
out << "[" << size << "]";
|
||||
}
|
||||
} else if (type->Is<sem::Bool>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Bool*) {
|
||||
out << "bool";
|
||||
} else if (type->Is<sem::F32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
out << "float";
|
||||
} else if (type->Is<sem::I32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
out << "int";
|
||||
} else if (auto* mat = type->As<sem::Matrix>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
if (!EmitType(out, mat->type(), storage_class, access, "")) {
|
||||
return false;
|
||||
}
|
||||
// Note: HLSL's matrices are declared as <type>NxM, where N is the number of
|
||||
// rows and M is the number of columns. Despite HLSL's matrices being
|
||||
// column-major by default, the index operator and constructors actually
|
||||
// operate on row-vectors, where as WGSL operates on column vectors.
|
||||
// To simplify everything we use the transpose of the matrices.
|
||||
// See:
|
||||
// Note: HLSL's matrices are declared as <type>NxM, where N is the
|
||||
// number of rows and M is the number of columns. Despite HLSL's
|
||||
// matrices being column-major by default, the index operator and
|
||||
// constructors actually operate on row-vectors, where as WGSL operates
|
||||
// on column vectors. To simplify everything we use the transpose of the
|
||||
// matrices. See:
|
||||
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
|
||||
out << mat->columns() << "x" << mat->rows();
|
||||
} else if (type->Is<sem::Pointer>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Pointer*) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "Attempting to emit pointer type. These should have been removed "
|
||||
"with the InlinePointerLets transform";
|
||||
<< "Attempting to emit pointer type. These should have been "
|
||||
"removed with the InlinePointerLets transform";
|
||||
return false;
|
||||
} else if (auto* sampler = type->As<sem::Sampler>()) {
|
||||
},
|
||||
[&](const sem::Sampler* sampler) {
|
||||
out << "Sampler";
|
||||
if (sampler->IsComparison()) {
|
||||
out << "Comparison";
|
||||
}
|
||||
out << "State";
|
||||
} else if (auto* str = type->As<sem::Struct>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Struct* str) {
|
||||
out << StructName(str);
|
||||
} else if (auto* tex = type->As<sem::Texture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Texture* tex) {
|
||||
auto* storage = tex->As<sem::StorageTexture>();
|
||||
auto* ms = tex->As<sem::MultisampledTexture>();
|
||||
auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
|
||||
|
@ -3609,7 +3666,8 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
}
|
||||
|
||||
if (storage) {
|
||||
auto* component = image_format_to_rwtexture_type(storage->texel_format());
|
||||
auto* component =
|
||||
image_format_to_rwtexture_type(storage->texel_format());
|
||||
if (component == nullptr) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "Unsupported StorageTexture TexelFormat: "
|
||||
|
@ -3635,9 +3693,13 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
}
|
||||
out << ">";
|
||||
}
|
||||
} else if (type->Is<sem::U32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::U32*) {
|
||||
out << "uint";
|
||||
} else if (auto* vec = type->As<sem::Vector>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Vector* vec) {
|
||||
auto width = vec->Width();
|
||||
if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
|
||||
out << "float" << width;
|
||||
|
@ -3654,18 +3716,20 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
}
|
||||
out << ", " << width << ">";
|
||||
}
|
||||
} else if (auto* atomic = type->As<sem::Atomic>()) {
|
||||
if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
|
||||
return false;
|
||||
}
|
||||
} else if (type->Is<sem::Void>()) {
|
||||
out << "void";
|
||||
} else {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Atomic* atomic) {
|
||||
return EmitType(out, atomic->Type(), storage_class, access, name);
|
||||
},
|
||||
[&](const sem::Void*) {
|
||||
out << "void";
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
"unknown type in EmitType");
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
|
||||
|
|
|
@ -538,23 +538,25 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
|
|||
const ast::CallExpression* expr) {
|
||||
auto* call = program_->Sem().Get(expr);
|
||||
auto* target = call->Target();
|
||||
|
||||
if (auto* func = target->As<sem::Function>()) {
|
||||
return Switch(
|
||||
target,
|
||||
[&](const sem::Function* func) {
|
||||
return EmitFunctionCall(out, call, func);
|
||||
}
|
||||
if (auto* builtin = target->As<sem::Builtin>()) {
|
||||
},
|
||||
[&](const sem::Builtin* builtin) {
|
||||
return EmitBuiltinCall(out, call, builtin);
|
||||
}
|
||||
if (auto* conv = target->As<sem::TypeConversion>()) {
|
||||
},
|
||||
[&](const sem::TypeConversion* conv) {
|
||||
return EmitTypeConversion(out, call, conv);
|
||||
}
|
||||
if (auto* ctor = target->As<sem::TypeConstructor>()) {
|
||||
},
|
||||
[&](const sem::TypeConstructor* ctor) {
|
||||
return EmitTypeConstructor(out, call, ctor);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "unhandled call target: " << target->TypeInfo().name;
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
|
||||
|
@ -1476,106 +1478,128 @@ bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
|
||||
if (type->Is<sem::Bool>()) {
|
||||
return Switch(
|
||||
type,
|
||||
[&](const sem::Bool*) {
|
||||
out << "false";
|
||||
} else if (type->Is<sem::F32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
out << "0.0f";
|
||||
} else if (type->Is<sem::I32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
out << "0";
|
||||
} else if (type->Is<sem::U32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::U32*) {
|
||||
out << "0u";
|
||||
} else if (auto* vec = type->As<sem::Vector>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Vector* vec) { //
|
||||
return EmitZeroValue(out, vec->type());
|
||||
} else if (auto* mat = type->As<sem::Matrix>()) {
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
if (!EmitType(out, mat, "")) {
|
||||
return false;
|
||||
}
|
||||
out << "(";
|
||||
if (!EmitZeroValue(out, mat->type())) {
|
||||
return false;
|
||||
}
|
||||
out << ")";
|
||||
} else if (auto* arr = type->As<sem::Array>()) {
|
||||
TINT_DEFER(out << ")");
|
||||
return EmitZeroValue(out, mat->type());
|
||||
},
|
||||
[&](const sem::Array* arr) {
|
||||
out << "{";
|
||||
if (!EmitZeroValue(out, arr->ElemType())) {
|
||||
return false;
|
||||
}
|
||||
out << "}";
|
||||
} else if (type->As<sem::Struct>()) {
|
||||
TINT_DEFER(out << "}");
|
||||
return EmitZeroValue(out, arr->ElemType());
|
||||
},
|
||||
[&](const sem::Struct*) {
|
||||
out << "{}";
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"Invalid type for zero emission: " + type->type_name());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitLiteral(std::ostream& out,
|
||||
const ast::LiteralExpression* lit) {
|
||||
if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
|
||||
return Switch(
|
||||
lit,
|
||||
[&](const ast::BoolLiteralExpression* l) {
|
||||
out << (l->value ? "true" : "false");
|
||||
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
|
||||
if (std::isinf(fl->value)) {
|
||||
out << (fl->value >= 0 ? "INFINITY" : "-INFINITY");
|
||||
} else if (std::isnan(fl->value)) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* l) {
|
||||
if (std::isinf(l->value)) {
|
||||
out << (l->value >= 0 ? "INFINITY" : "-INFINITY");
|
||||
} else if (std::isnan(l->value)) {
|
||||
out << "NAN";
|
||||
} else {
|
||||
out << FloatToString(fl->value) << "f";
|
||||
}
|
||||
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
|
||||
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary
|
||||
// minus and `2147483648` as separate tokens, and the latter doesn't
|
||||
// fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
|
||||
// issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
|
||||
// ensures the expression type is `int`.
|
||||
const auto int_min = std::numeric_limits<int32_t>::min();
|
||||
if (sl->ValueAsI32() == int_min) {
|
||||
out << "(" << int_min + 1 << " - 1)";
|
||||
} else {
|
||||
out << sl->value;
|
||||
}
|
||||
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
|
||||
out << ul->value << "u";
|
||||
} else {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
|
||||
return false;
|
||||
out << FloatToString(l->value) << "f";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const ast::SintLiteralExpression* l) {
|
||||
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary
|
||||
// minus and `2147483648` as separate tokens, and the latter doesn't
|
||||
// fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To
|
||||
// avoid issues with `long` to `int` casts, emit `(2147483647 - 1)`
|
||||
// instead, which ensures the expression type is `int`.
|
||||
const auto int_min = std::numeric_limits<int32_t>::min();
|
||||
if (l->ValueAsI32() == int_min) {
|
||||
out << "(" << int_min + 1 << " - 1)";
|
||||
} else {
|
||||
out << l->value;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const ast::UintLiteralExpression* l) {
|
||||
out << l->value << "u";
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitExpression(std::ostream& out,
|
||||
const ast::Expression* expr) {
|
||||
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
|
||||
return Switch(
|
||||
expr,
|
||||
[&](const ast::IndexAccessorExpression* a) { //
|
||||
return EmitIndexAccessor(out, a);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BinaryExpression>()) {
|
||||
},
|
||||
[&](const ast::BinaryExpression* b) { //
|
||||
return EmitBinary(out, b);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BitcastExpression>()) {
|
||||
},
|
||||
[&](const ast::BitcastExpression* b) { //
|
||||
return EmitBitcast(out, b);
|
||||
}
|
||||
if (auto* c = expr->As<ast::CallExpression>()) {
|
||||
},
|
||||
[&](const ast::CallExpression* c) { //
|
||||
return EmitCall(out, c);
|
||||
}
|
||||
if (auto* i = expr->As<ast::IdentifierExpression>()) {
|
||||
},
|
||||
[&](const ast::IdentifierExpression* i) { //
|
||||
return EmitIdentifier(out, i);
|
||||
}
|
||||
if (auto* l = expr->As<ast::LiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::LiteralExpression* l) { //
|
||||
return EmitLiteral(out, l);
|
||||
}
|
||||
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
|
||||
},
|
||||
[&](const ast::MemberAccessorExpression* m) { //
|
||||
return EmitMemberAccessor(out, m);
|
||||
}
|
||||
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
|
||||
},
|
||||
[&](const ast::UnaryOpExpression* u) { //
|
||||
return EmitUnaryOp(out, u);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) { //
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown expression type: " + std::string(expr->TypeInfo().name));
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) {
|
||||
|
@ -2106,57 +2130,60 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
|
||||
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
|
||||
return Switch(
|
||||
stmt,
|
||||
[&](const ast::AssignmentStatement* a) { //
|
||||
return EmitAssign(a);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BlockStatement>()) {
|
||||
},
|
||||
[&](const ast::BlockStatement* b) { //
|
||||
return EmitBlock(b);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BreakStatement>()) {
|
||||
},
|
||||
[&](const ast::BreakStatement* b) { //
|
||||
return EmitBreak(b);
|
||||
}
|
||||
if (auto* c = stmt->As<ast::CallStatement>()) {
|
||||
},
|
||||
[&](const ast::CallStatement* c) { //
|
||||
auto out = line();
|
||||
if (!EmitCall(out, c->expr)) {
|
||||
if (!EmitCall(out, c->expr)) { //
|
||||
return false;
|
||||
}
|
||||
out << ";";
|
||||
return true;
|
||||
}
|
||||
if (auto* c = stmt->As<ast::ContinueStatement>()) {
|
||||
},
|
||||
[&](const ast::ContinueStatement* c) { //
|
||||
return EmitContinue(c);
|
||||
}
|
||||
if (auto* d = stmt->As<ast::DiscardStatement>()) {
|
||||
},
|
||||
[&](const ast::DiscardStatement* d) { //
|
||||
return EmitDiscard(d);
|
||||
}
|
||||
if (stmt->As<ast::FallthroughStatement>()) {
|
||||
},
|
||||
[&](const ast::FallthroughStatement*) { //
|
||||
line() << "/* fallthrough */";
|
||||
return true;
|
||||
}
|
||||
if (auto* i = stmt->As<ast::IfStatement>()) {
|
||||
},
|
||||
[&](const ast::IfStatement* i) { //
|
||||
return EmitIf(i);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::LoopStatement>()) {
|
||||
},
|
||||
[&](const ast::LoopStatement* l) { //
|
||||
return EmitLoop(l);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
|
||||
},
|
||||
[&](const ast::ForLoopStatement* l) { //
|
||||
return EmitForLoop(l);
|
||||
}
|
||||
if (auto* r = stmt->As<ast::ReturnStatement>()) {
|
||||
},
|
||||
[&](const ast::ReturnStatement* r) { //
|
||||
return EmitReturn(r);
|
||||
}
|
||||
if (auto* s = stmt->As<ast::SwitchStatement>()) {
|
||||
},
|
||||
[&](const ast::SwitchStatement* s) { //
|
||||
return EmitSwitch(s);
|
||||
}
|
||||
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
|
||||
},
|
||||
[&](const ast::VariableDeclStatement* v) { //
|
||||
auto* var = program_->Sem().Get(v->variable);
|
||||
return EmitVariable(var);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown statement type: " + std::string(stmt->TypeInfo().name));
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
|
||||
|
@ -2204,7 +2231,10 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
if (name_printed) {
|
||||
*name_printed = false;
|
||||
}
|
||||
if (auto* atomic = type->As<sem::Atomic>()) {
|
||||
|
||||
return Switch(
|
||||
type,
|
||||
[&](const sem::Atomic* atomic) {
|
||||
if (atomic->Type()->Is<sem::I32>()) {
|
||||
out << "atomic_int";
|
||||
return true;
|
||||
|
@ -2216,9 +2246,8 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "unhandled atomic type " << atomic->Type()->type_name();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* ary = type->As<sem::Array>()) {
|
||||
},
|
||||
[&](const sem::Array* ary) {
|
||||
const sem::Type* base_type = ary;
|
||||
std::vector<uint32_t> sizes;
|
||||
while (auto* arr = base_type->As<sem::Array>()) {
|
||||
|
@ -2242,32 +2271,27 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
out << "[" << size << "]";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type->Is<sem::Bool>()) {
|
||||
},
|
||||
[&](const sem::Bool*) {
|
||||
out << "bool";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type->Is<sem::F32>()) {
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
out << "float";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type->Is<sem::I32>()) {
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
out << "int";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* mat = type->As<sem::Matrix>()) {
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
if (!EmitType(out, mat->type(), "")) {
|
||||
return false;
|
||||
}
|
||||
out << mat->columns() << "x" << mat->rows();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* ptr = type->As<sem::Pointer>()) {
|
||||
},
|
||||
[&](const sem::Pointer* ptr) {
|
||||
if (ptr->Access() == ast::Access::kRead) {
|
||||
out << "const ";
|
||||
}
|
||||
|
@ -2293,21 +2317,18 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type->Is<sem::Sampler>()) {
|
||||
},
|
||||
[&](const sem::Sampler*) {
|
||||
out << "sampler";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* str = type->As<sem::Struct>()) {
|
||||
// The struct type emits as just the name. The declaration would be emitted
|
||||
// as part of emitting the declared types.
|
||||
},
|
||||
[&](const sem::Struct* str) {
|
||||
// The struct type emits as just the name. The declaration would be
|
||||
// emitted as part of emitting the declared types.
|
||||
out << StructName(str);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* tex = type->As<sem::Texture>()) {
|
||||
},
|
||||
[&](const sem::Texture* tex) {
|
||||
if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
|
||||
out << "depth";
|
||||
} else {
|
||||
|
@ -2343,11 +2364,19 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
out << "_ms";
|
||||
}
|
||||
out << "<";
|
||||
if (tex->Is<sem::DepthTexture>()) {
|
||||
TINT_DEFER(out << ">");
|
||||
|
||||
return Switch(
|
||||
tex,
|
||||
[&](const sem::DepthTexture*) {
|
||||
out << "float, access::sample";
|
||||
} else if (tex->Is<sem::DepthMultisampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::DepthMultisampledTexture*) {
|
||||
out << "float, access::read";
|
||||
} else if (auto* storage = tex->As<sem::StorageTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::StorageTexture* storage) {
|
||||
if (!EmitType(out, storage->type(), "")) {
|
||||
return false;
|
||||
}
|
||||
|
@ -2358,49 +2387,54 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
} else if (storage->access() == ast::Access::kWrite) {
|
||||
out << ", access::write";
|
||||
} else {
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"Invalid access control for storage texture");
|
||||
return false;
|
||||
}
|
||||
} else if (auto* ms = tex->As<sem::MultisampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::MultisampledTexture* ms) {
|
||||
if (!EmitType(out, ms->type(), "")) {
|
||||
return false;
|
||||
}
|
||||
out << ", access::read";
|
||||
} else if (auto* sampled = tex->As<sem::SampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::SampledTexture* sampled) {
|
||||
if (!EmitType(out, sampled->type(), "")) {
|
||||
return false;
|
||||
}
|
||||
out << ", access::sample";
|
||||
} else {
|
||||
diagnostics_.add_error(diag::System::Writer, "invalid texture type");
|
||||
return false;
|
||||
}
|
||||
out << ">";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type->Is<sem::U32>()) {
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
"invalid texture type");
|
||||
return false;
|
||||
});
|
||||
},
|
||||
[&](const sem::U32*) {
|
||||
out << "uint";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* vec = type->As<sem::Vector>()) {
|
||||
},
|
||||
[&](const sem::Vector* vec) {
|
||||
if (!EmitType(out, vec->type(), "")) {
|
||||
return false;
|
||||
}
|
||||
out << vec->Width();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type->Is<sem::Void>()) {
|
||||
},
|
||||
[&](const sem::Void*) {
|
||||
out << "void";
|
||||
return true;
|
||||
}
|
||||
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown type in EmitType: " + type->type_name());
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
|
||||
|
@ -2542,18 +2576,23 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
|
|||
// Emit attributes
|
||||
if (auto* decl = mem->Declaration()) {
|
||||
for (auto* attr : decl->attributes) {
|
||||
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
|
||||
bool ok = Switch(
|
||||
attr,
|
||||
[&](const ast::BuiltinAttribute* builtin) {
|
||||
auto name = builtin_to_attribute(builtin->builtin);
|
||||
if (name.empty()) {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown builtin");
|
||||
return false;
|
||||
}
|
||||
out << " [[" << name << "]]";
|
||||
} else if (auto* loc = attr->As<ast::LocationAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::LocationAttribute* loc) {
|
||||
auto& pipeline_stage_uses = str->PipelineStageUses();
|
||||
if (pipeline_stage_uses.size() != 1) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "invalid entry point IO struct uses";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (pipeline_stage_uses.count(
|
||||
|
@ -2570,9 +2609,12 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
|
|||
out << " [[color(" + std::to_string(loc->value) + ")]]";
|
||||
} else {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "invalid use of location attribute";
|
||||
<< "invalid use of location decoration";
|
||||
return false;
|
||||
}
|
||||
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InterpolateAttribute* interpolate) {
|
||||
auto name = interpolation_to_attribute(interpolate->type,
|
||||
interpolate->sampling);
|
||||
if (name.empty()) {
|
||||
|
@ -2581,16 +2623,25 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
|
|||
return false;
|
||||
}
|
||||
out << " [[" << name << "]]";
|
||||
} else if (attr->Is<ast::InvariantAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InvariantAttribute*) {
|
||||
if (invariant_define_name_.empty()) {
|
||||
invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
|
||||
}
|
||||
out << " " << invariant_define_name_;
|
||||
} else if (!attr->IsAnyOf<ast::StructMemberOffsetAttribute,
|
||||
ast::StructMemberAlignAttribute,
|
||||
ast::StructMemberSizeAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StructMemberOffsetAttribute*) { return true; },
|
||||
[&](const ast::StructMemberAlignAttribute*) { return true; },
|
||||
[&](const ast::StructMemberSizeAttribute*) { return true; },
|
||||
[&](Default) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "unhandled struct member attribute: " << attr->Name();
|
||||
return false;
|
||||
});
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2796,13 +2847,22 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
|
|||
|
||||
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
|
||||
const sem::Type* ty) {
|
||||
if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
|
||||
return Switch(
|
||||
ty,
|
||||
|
||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
// 2.1 Scalar Data Types
|
||||
return {4, 4};
|
||||
}
|
||||
[&](const sem::U32*) {
|
||||
return SizeAndAlign{4, 4};
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
return SizeAndAlign{4, 4};
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
return SizeAndAlign{4, 4};
|
||||
},
|
||||
|
||||
if (auto* vec = ty->As<sem::Vector>()) {
|
||||
[&](const sem::Vector* vec) {
|
||||
auto num_els = vec->Width();
|
||||
auto* el_ty = vec->type();
|
||||
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
|
||||
|
@ -2817,9 +2877,12 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
|
|||
return SizeAndAlign{num_els * 4, num_els * 4};
|
||||
}
|
||||
}
|
||||
}
|
||||
TINT_UNREACHABLE(Writer, diagnostics_)
|
||||
<< "Unhandled vector element type " << el_ty->TypeInfo().name;
|
||||
return SizeAndAlign{};
|
||||
},
|
||||
|
||||
if (auto* mat = ty->As<sem::Matrix>()) {
|
||||
[&](const sem::Matrix* mat) {
|
||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
// 2.3 Matrix Data Types
|
||||
auto cols = mat->columns();
|
||||
|
@ -2841,32 +2904,39 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
|
|||
return table[(3 * (cols - 2)) + (rows - 2)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* arr = ty->As<sem::Array>()) {
|
||||
TINT_UNREACHABLE(Writer, diagnostics_)
|
||||
<< "Unhandled matrix element type " << el_ty->TypeInfo().name;
|
||||
return SizeAndAlign{};
|
||||
},
|
||||
|
||||
[&](const sem::Array* arr) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "arrays with explicit strides should have "
|
||||
"removed with the PadArrayElements transform";
|
||||
return {};
|
||||
return SizeAndAlign{};
|
||||
}
|
||||
auto num_els = std::max<uint32_t>(arr->Count(), 1);
|
||||
return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
|
||||
}
|
||||
},
|
||||
|
||||
if (auto* str = ty->As<sem::Struct>()) {
|
||||
// TODO(crbug.com/tint/650): There's an assumption here that MSL's default
|
||||
// structure size and alignment matches WGSL's. We need to confirm this.
|
||||
[&](const sem::Struct* str) {
|
||||
// TODO(crbug.com/tint/650): There's an assumption here that MSL's
|
||||
// default structure size and alignment matches WGSL's. We need to
|
||||
// confirm this.
|
||||
return SizeAndAlign{str->Size(), str->Align()};
|
||||
}
|
||||
},
|
||||
|
||||
if (auto* atomic = ty->As<sem::Atomic>()) {
|
||||
[&](const sem::Atomic* atomic) {
|
||||
return MslPackedTypeSizeAndAlign(atomic->Type());
|
||||
}
|
||||
},
|
||||
|
||||
[&](Default) {
|
||||
TINT_UNREACHABLE(Writer, diagnostics_)
|
||||
<< "Unhandled type " << ty->TypeInfo().name;
|
||||
return {};
|
||||
return SizeAndAlign{};
|
||||
});
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
|
|
|
@ -560,33 +560,37 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
|
|||
}
|
||||
|
||||
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
|
||||
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
|
||||
return Switch(
|
||||
expr,
|
||||
[&](const ast::IndexAccessorExpression* a) { //
|
||||
return GenerateAccessorExpression(a);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BinaryExpression>()) {
|
||||
},
|
||||
[&](const ast::BinaryExpression* b) { //
|
||||
return GenerateBinaryExpression(b);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BitcastExpression>()) {
|
||||
},
|
||||
[&](const ast::BitcastExpression* b) { //
|
||||
return GenerateBitcastExpression(b);
|
||||
}
|
||||
if (auto* c = expr->As<ast::CallExpression>()) {
|
||||
},
|
||||
[&](const ast::CallExpression* c) { //
|
||||
return GenerateCallExpression(c);
|
||||
}
|
||||
if (auto* i = expr->As<ast::IdentifierExpression>()) {
|
||||
},
|
||||
[&](const ast::IdentifierExpression* i) { //
|
||||
return GenerateIdentifierExpression(i);
|
||||
}
|
||||
if (auto* l = expr->As<ast::LiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::LiteralExpression* l) { //
|
||||
return GenerateLiteralIfNeeded(nullptr, l);
|
||||
}
|
||||
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
|
||||
},
|
||||
[&](const ast::MemberAccessorExpression* m) { //
|
||||
return GenerateAccessorExpression(m);
|
||||
}
|
||||
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
|
||||
},
|
||||
[&](const ast::UnaryOpExpression* u) { //
|
||||
return GenerateUnaryOpExpression(u);
|
||||
}
|
||||
|
||||
error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
|
||||
},
|
||||
[&](Default) -> uint32_t {
|
||||
error_ =
|
||||
"unknown expression type: " + std::string(expr->TypeInfo().name);
|
||||
return 0;
|
||||
});
|
||||
}
|
||||
|
||||
bool Builder::GenerateFunction(const ast::Function* func_ast) {
|
||||
|
@ -861,34 +865,57 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
|
|||
push_type(spv::Op::OpVariable, std::move(ops));
|
||||
|
||||
for (auto* attr : var->attributes) {
|
||||
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
|
||||
bool ok = Switch(
|
||||
attr,
|
||||
[&](const ast::BuiltinAttribute* builtin) {
|
||||
push_annot(spv::Op::OpDecorate,
|
||||
{Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
|
||||
Operand::Int(
|
||||
ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
|
||||
} else if (auto* location = attr->As<ast::LocationAttribute>()) {
|
||||
Operand::Int(ConvertBuiltin(builtin->builtin,
|
||||
sem->StorageClass()))});
|
||||
return true;
|
||||
},
|
||||
[&](const ast::LocationAttribute* location) {
|
||||
push_annot(spv::Op::OpDecorate,
|
||||
{Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
|
||||
Operand::Int(location->value)});
|
||||
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InterpolateAttribute* interpolate) {
|
||||
AddInterpolationDecorations(var_id, interpolate->type,
|
||||
interpolate->sampling);
|
||||
} else if (attr->Is<ast::InvariantAttribute>()) {
|
||||
push_annot(spv::Op::OpDecorate,
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InvariantAttribute*) {
|
||||
push_annot(
|
||||
spv::Op::OpDecorate,
|
||||
{Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
|
||||
} else if (auto* binding = attr->As<ast::BindingAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::BindingAttribute* binding) {
|
||||
push_annot(spv::Op::OpDecorate,
|
||||
{Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
|
||||
Operand::Int(binding->value)});
|
||||
} else if (auto* group = attr->As<ast::GroupAttribute>()) {
|
||||
push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
|
||||
Operand::Int(SpvDecorationDescriptorSet),
|
||||
return true;
|
||||
},
|
||||
[&](const ast::GroupAttribute* group) {
|
||||
push_annot(
|
||||
spv::Op::OpDecorate,
|
||||
{Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
|
||||
Operand::Int(group->value)});
|
||||
} else if (attr->Is<ast::OverrideAttribute>()) {
|
||||
// Spec constants are handled elsewhere
|
||||
} else if (!attr->Is<ast::InternalAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::OverrideAttribute*) {
|
||||
return true; // Spec constants are handled elsewhere
|
||||
},
|
||||
[&](const ast::InternalAttribute*) {
|
||||
return true; // ignored
|
||||
},
|
||||
[&](Default) {
|
||||
error_ = "unknown attribute";
|
||||
return false;
|
||||
});
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1123,19 +1150,21 @@ uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
|
|||
// promoted to storage with the VarForDynamicIndex transform.
|
||||
|
||||
for (auto* accessor : accessors) {
|
||||
if (auto* array = accessor->As<ast::IndexAccessorExpression>()) {
|
||||
if (!GenerateIndexAccessor(array, &info)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
|
||||
if (!GenerateMemberAccessor(member, &info)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
} else {
|
||||
error_ =
|
||||
"invalid accessor in list: " + std::string(accessor->TypeInfo().name);
|
||||
return 0;
|
||||
bool ok = Switch(
|
||||
accessor,
|
||||
[&](const ast::IndexAccessorExpression* array) {
|
||||
return GenerateIndexAccessor(array, &info);
|
||||
},
|
||||
[&](const ast::MemberAccessorExpression* member) {
|
||||
return GenerateMemberAccessor(member, &info);
|
||||
},
|
||||
[&](Default) {
|
||||
error_ = "invalid accessor in list: " +
|
||||
std::string(accessor->TypeInfo().name);
|
||||
return false;
|
||||
});
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1653,21 +1682,28 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
|
|||
constant.constant_id = global->ConstantId();
|
||||
}
|
||||
|
||||
if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
|
||||
Switch(
|
||||
lit,
|
||||
[&](const ast::BoolLiteralExpression* l) {
|
||||
constant.kind = ScalarConstant::Kind::kBool;
|
||||
constant.value.b = l->value;
|
||||
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::SintLiteralExpression* sl) {
|
||||
constant.kind = ScalarConstant::Kind::kI32;
|
||||
constant.value.i32 = sl->value;
|
||||
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::UintLiteralExpression* ul) {
|
||||
constant.kind = ScalarConstant::Kind::kU32;
|
||||
constant.value.u32 = ul->value;
|
||||
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* fl) {
|
||||
constant.kind = ScalarConstant::Kind::kF32;
|
||||
constant.value.f32 = fl->value;
|
||||
} else {
|
||||
error_ = "unknown literal type";
|
||||
return 0;
|
||||
},
|
||||
[&](Default) { error_ = "unknown literal type"; });
|
||||
|
||||
if (!error_.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return GenerateConstantIfNeeded(constant);
|
||||
|
@ -2209,19 +2245,25 @@ bool Builder::GenerateBlockStatementWithoutScoping(
|
|||
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
|
||||
auto* call = builder_.Sem().Get(expr);
|
||||
auto* target = call->Target();
|
||||
|
||||
if (auto* func = target->As<sem::Function>()) {
|
||||
return Switch(
|
||||
target,
|
||||
[&](const sem::Function* func) {
|
||||
return GenerateFunctionCall(call, func);
|
||||
}
|
||||
if (auto* builtin = target->As<sem::Builtin>()) {
|
||||
},
|
||||
[&](const sem::Builtin* builtin) {
|
||||
return GenerateBuiltinCall(call, builtin);
|
||||
}
|
||||
if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
|
||||
},
|
||||
[&](const sem::TypeConversion*) {
|
||||
return GenerateTypeConstructorOrConversion(call, nullptr);
|
||||
}
|
||||
},
|
||||
[&](const sem::TypeConstructor*) {
|
||||
return GenerateTypeConstructorOrConversion(call, nullptr);
|
||||
},
|
||||
[&](Default) -> uint32_t {
|
||||
TINT_ICE(Writer, builder_.Diagnostics())
|
||||
<< "unhandled call target: " << target->TypeInfo().name;
|
||||
return false;
|
||||
return 0;
|
||||
});
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
|
||||
|
@ -3790,46 +3832,49 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
|
|||
}
|
||||
|
||||
bool Builder::GenerateStatement(const ast::Statement* stmt) {
|
||||
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
|
||||
return Switch(
|
||||
stmt,
|
||||
[&](const ast::AssignmentStatement* a) {
|
||||
return GenerateAssignStatement(a);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BlockStatement>()) {
|
||||
},
|
||||
[&](const ast::BlockStatement* b) { //
|
||||
return GenerateBlockStatement(b);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BreakStatement>()) {
|
||||
},
|
||||
[&](const ast::BreakStatement* b) { //
|
||||
return GenerateBreakStatement(b);
|
||||
}
|
||||
if (auto* c = stmt->As<ast::CallStatement>()) {
|
||||
},
|
||||
[&](const ast::CallStatement* c) {
|
||||
return GenerateCallExpression(c->expr) != 0;
|
||||
}
|
||||
if (auto* c = stmt->As<ast::ContinueStatement>()) {
|
||||
},
|
||||
[&](const ast::ContinueStatement* c) {
|
||||
return GenerateContinueStatement(c);
|
||||
}
|
||||
if (auto* d = stmt->As<ast::DiscardStatement>()) {
|
||||
},
|
||||
[&](const ast::DiscardStatement* d) {
|
||||
return GenerateDiscardStatement(d);
|
||||
}
|
||||
if (stmt->Is<ast::FallthroughStatement>()) {
|
||||
},
|
||||
[&](const ast::FallthroughStatement*) {
|
||||
// Do nothing here, the fallthrough gets handled by the switch code.
|
||||
return true;
|
||||
}
|
||||
if (auto* i = stmt->As<ast::IfStatement>()) {
|
||||
},
|
||||
[&](const ast::IfStatement* i) { //
|
||||
return GenerateIfStatement(i);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::LoopStatement>()) {
|
||||
},
|
||||
[&](const ast::LoopStatement* l) { //
|
||||
return GenerateLoopStatement(l);
|
||||
}
|
||||
if (auto* r = stmt->As<ast::ReturnStatement>()) {
|
||||
},
|
||||
[&](const ast::ReturnStatement* r) { //
|
||||
return GenerateReturnStatement(r);
|
||||
}
|
||||
if (auto* s = stmt->As<ast::SwitchStatement>()) {
|
||||
},
|
||||
[&](const ast::SwitchStatement* s) { //
|
||||
return GenerateSwitchStatement(s);
|
||||
}
|
||||
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
|
||||
},
|
||||
[&](const ast::VariableDeclStatement* v) {
|
||||
return GenerateVariableDeclStatement(v);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) {
|
||||
error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool Builder::GenerateVariableDeclStatement(
|
||||
|
@ -3872,78 +3917,91 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
|
|||
return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
|
||||
auto result = result_op();
|
||||
auto id = result.to_i();
|
||||
if (auto* arr = type->As<sem::Array>()) {
|
||||
if (!GenerateArrayType(arr, result)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (type->Is<sem::Bool>()) {
|
||||
bool ok = Switch(
|
||||
type,
|
||||
[&](const sem::Array* arr) { //
|
||||
return GenerateArrayType(arr, result);
|
||||
},
|
||||
[&](const sem::Bool*) {
|
||||
push_type(spv::Op::OpTypeBool, {result});
|
||||
} else if (type->Is<sem::F32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
|
||||
} else if (type->Is<sem::I32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
push_type(spv::Op::OpTypeInt,
|
||||
{result, Operand::Int(32), Operand::Int(1)});
|
||||
} else if (auto* mat = type->As<sem::Matrix>()) {
|
||||
if (!GenerateMatrixType(mat, result)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (auto* ptr = type->As<sem::Pointer>()) {
|
||||
if (!GeneratePointerType(ptr, result)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (auto* ref = type->As<sem::Reference>()) {
|
||||
if (!GenerateReferenceType(ref, result)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (auto* str = type->As<sem::Struct>()) {
|
||||
if (!GenerateStructType(str, result)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (type->Is<sem::U32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Matrix* mat) { //
|
||||
return GenerateMatrixType(mat, result);
|
||||
},
|
||||
[&](const sem::Pointer* ptr) { //
|
||||
return GeneratePointerType(ptr, result);
|
||||
},
|
||||
[&](const sem::Reference* ref) { //
|
||||
return GenerateReferenceType(ref, result);
|
||||
},
|
||||
[&](const sem::Struct* str) { //
|
||||
return GenerateStructType(str, result);
|
||||
},
|
||||
[&](const sem::U32*) {
|
||||
push_type(spv::Op::OpTypeInt,
|
||||
{result, Operand::Int(32), Operand::Int(0)});
|
||||
} else if (auto* vec = type->As<sem::Vector>()) {
|
||||
if (!GenerateVectorType(vec, result)) {
|
||||
return 0;
|
||||
}
|
||||
} else if (type->Is<sem::Void>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Vector* vec) { //
|
||||
return GenerateVectorType(vec, result);
|
||||
},
|
||||
[&](const sem::Void*) {
|
||||
push_type(spv::Op::OpTypeVoid, {result});
|
||||
} else if (auto* tex = type->As<sem::Texture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::StorageTexture* tex) {
|
||||
if (!GenerateTextureType(tex, result)) {
|
||||
return 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* st = tex->As<sem::StorageTexture>()) {
|
||||
// Register all three access types of StorageTexture names. In SPIR-V,
|
||||
// we must output a single type, while the variable is annotated with
|
||||
// the access type. Doing this ensures we de-dupe.
|
||||
// Register all three access types of StorageTexture names. In
|
||||
// SPIR-V, we must output a single type, while the variable is
|
||||
// annotated with the access type. Doing this ensures we de-dupe.
|
||||
type_name_to_id_[builder_
|
||||
.create<sem::StorageTexture>(
|
||||
st->dim(), st->texel_format(),
|
||||
ast::Access::kRead, st->type())
|
||||
tex->dim(), tex->texel_format(),
|
||||
ast::Access::kRead, tex->type())
|
||||
->type_name()] = id;
|
||||
type_name_to_id_[builder_
|
||||
.create<sem::StorageTexture>(
|
||||
st->dim(), st->texel_format(),
|
||||
ast::Access::kWrite, st->type())
|
||||
tex->dim(), tex->texel_format(),
|
||||
ast::Access::kWrite, tex->type())
|
||||
->type_name()] = id;
|
||||
type_name_to_id_[builder_
|
||||
.create<sem::StorageTexture>(
|
||||
st->dim(), st->texel_format(),
|
||||
ast::Access::kReadWrite, st->type())
|
||||
tex->dim(), tex->texel_format(),
|
||||
ast::Access::kReadWrite, tex->type())
|
||||
->type_name()] = id;
|
||||
}
|
||||
|
||||
} else if (type->Is<sem::Sampler>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Texture* tex) {
|
||||
return GenerateTextureType(tex, result);
|
||||
},
|
||||
[&](const sem::Sampler*) {
|
||||
push_type(spv::Op::OpTypeSampler, {result});
|
||||
|
||||
// Register both of the sampler type names. In SPIR-V they're the same
|
||||
// sampler type, so we need to match that when we do the dedup check.
|
||||
type_name_to_id_["__sampler_sampler"] = id;
|
||||
type_name_to_id_["__sampler_comparison"] = id;
|
||||
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
error_ = "unable to convert type: " + type->type_name();
|
||||
return false;
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -3995,22 +4053,31 @@ bool Builder::GenerateTextureType(const sem::Texture* texture,
|
|||
}
|
||||
|
||||
if (dim == ast::TextureDimension::kCubeArray) {
|
||||
if (texture->Is<sem::SampledTexture>() ||
|
||||
texture->Is<sem::DepthTexture>()) {
|
||||
if (texture->IsAnyOf<sem::SampledTexture, sem::DepthTexture>()) {
|
||||
push_capability(SpvCapabilitySampledCubeArray);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t type_id = 0u;
|
||||
if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
|
||||
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
|
||||
} else if (auto* s = texture->As<sem::SampledTexture>()) {
|
||||
type_id = GenerateTypeIfNeeded(s->type());
|
||||
} else if (auto* ms = texture->As<sem::MultisampledTexture>()) {
|
||||
type_id = GenerateTypeIfNeeded(ms->type());
|
||||
} else if (auto* st = texture->As<sem::StorageTexture>()) {
|
||||
type_id = GenerateTypeIfNeeded(st->type());
|
||||
}
|
||||
uint32_t type_id = Switch(
|
||||
texture,
|
||||
[&](const sem::DepthTexture*) {
|
||||
return GenerateTypeIfNeeded(builder_.create<sem::F32>());
|
||||
},
|
||||
[&](const sem::DepthMultisampledTexture*) {
|
||||
return GenerateTypeIfNeeded(builder_.create<sem::F32>());
|
||||
},
|
||||
[&](const sem::SampledTexture* t) {
|
||||
return GenerateTypeIfNeeded(t->type());
|
||||
},
|
||||
[&](const sem::MultisampledTexture* t) {
|
||||
return GenerateTypeIfNeeded(t->type());
|
||||
},
|
||||
[&](const sem::StorageTexture* t) {
|
||||
return GenerateTypeIfNeeded(t->type());
|
||||
},
|
||||
[&](Default) -> uint32_t { //
|
||||
return 0u;
|
||||
});
|
||||
if (type_id == 0u) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -68,23 +68,17 @@ GeneratorImpl::~GeneratorImpl() = default;
|
|||
bool GeneratorImpl::Generate() {
|
||||
// Generate global declarations in the order they appear in the module.
|
||||
for (auto* decl : program_->AST().GlobalDeclarations()) {
|
||||
if (auto* td = decl->As<ast::TypeDecl>()) {
|
||||
if (!EmitTypeDecl(td)) {
|
||||
return false;
|
||||
}
|
||||
} else if (auto* func = decl->As<ast::Function>()) {
|
||||
if (!EmitFunction(func)) {
|
||||
return false;
|
||||
}
|
||||
} else if (auto* var = decl->As<ast::Variable>()) {
|
||||
if (!EmitVariable(line(), var)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!Switch(
|
||||
decl, //
|
||||
[&](const ast::TypeDecl* td) { return EmitTypeDecl(td); },
|
||||
[&](const ast::Function* func) { return EmitFunction(func); },
|
||||
[&](const ast::Variable* var) { return EmitVariable(line(), var); },
|
||||
[&](Default) {
|
||||
TINT_UNREACHABLE(Writer, diagnostics_);
|
||||
return false;
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decl != program_->AST().GlobalDeclarations().back()) {
|
||||
line();
|
||||
}
|
||||
|
@ -94,59 +88,64 @@ bool GeneratorImpl::Generate() {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
|
||||
if (auto* alias = ty->As<ast::Alias>()) {
|
||||
return Switch(
|
||||
ty,
|
||||
[&](const ast::Alias* alias) { //
|
||||
auto out = line();
|
||||
out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
|
||||
if (!EmitType(out, alias->type)) {
|
||||
return false;
|
||||
}
|
||||
out << ";";
|
||||
} else if (auto* str = ty->As<ast::Struct>()) {
|
||||
if (!EmitStructType(str)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Struct* str) { //
|
||||
return EmitStructType(str);
|
||||
},
|
||||
[&](Default) { //
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown declared type: " + std::string(ty->TypeInfo().name));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitExpression(std::ostream& out,
|
||||
const ast::Expression* expr) {
|
||||
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
|
||||
return Switch(
|
||||
expr,
|
||||
[&](const ast::IndexAccessorExpression* a) { //
|
||||
return EmitIndexAccessor(out, a);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BinaryExpression>()) {
|
||||
},
|
||||
[&](const ast::BinaryExpression* b) { //
|
||||
return EmitBinary(out, b);
|
||||
}
|
||||
if (auto* b = expr->As<ast::BitcastExpression>()) {
|
||||
},
|
||||
[&](const ast::BitcastExpression* b) { //
|
||||
return EmitBitcast(out, b);
|
||||
}
|
||||
if (auto* c = expr->As<ast::CallExpression>()) {
|
||||
},
|
||||
[&](const ast::CallExpression* c) { //
|
||||
return EmitCall(out, c);
|
||||
}
|
||||
if (auto* i = expr->As<ast::IdentifierExpression>()) {
|
||||
},
|
||||
[&](const ast::IdentifierExpression* i) { //
|
||||
return EmitIdentifier(out, i);
|
||||
}
|
||||
if (auto* l = expr->As<ast::LiteralExpression>()) {
|
||||
},
|
||||
[&](const ast::LiteralExpression* l) { //
|
||||
return EmitLiteral(out, l);
|
||||
}
|
||||
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
|
||||
},
|
||||
[&](const ast::MemberAccessorExpression* m) { //
|
||||
return EmitMemberAccessor(out, m);
|
||||
}
|
||||
if (expr->Is<ast::PhonyExpression>()) {
|
||||
},
|
||||
[&](const ast::PhonyExpression*) { //
|
||||
out << "_";
|
||||
return true;
|
||||
}
|
||||
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
|
||||
},
|
||||
[&](const ast::UnaryOpExpression* u) { //
|
||||
return EmitUnaryOp(out, u);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown expression type");
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitIndexAccessor(
|
||||
|
@ -250,19 +249,28 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
|
|||
|
||||
bool GeneratorImpl::EmitLiteral(std::ostream& out,
|
||||
const ast::LiteralExpression* lit) {
|
||||
if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
|
||||
return Switch(
|
||||
lit,
|
||||
[&](const ast::BoolLiteralExpression* bl) { //
|
||||
out << (bl->value ? "true" : "false");
|
||||
} else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* fl) { //
|
||||
out << FloatToBitPreservingString(fl->value);
|
||||
} else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::SintLiteralExpression* sl) { //
|
||||
out << sl->value;
|
||||
} else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::UintLiteralExpression* ul) { //
|
||||
out << ul->value << "u";
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](Default) { //
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown literal type");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitIdentifier(std::ostream& out,
|
||||
|
@ -366,7 +374,9 @@ bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
||||
if (auto* ary = ty->As<ast::Array>()) {
|
||||
return Switch(
|
||||
ty,
|
||||
[&](const ast::Array* ary) {
|
||||
for (auto* attr : ary->attributes) {
|
||||
if (auto* stride = attr->As<ast::StrideAttribute>()) {
|
||||
out << "@stride(" << stride->stride << ") ";
|
||||
|
@ -386,13 +396,21 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
|||
}
|
||||
|
||||
out << ">";
|
||||
} else if (ty->Is<ast::Bool>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Bool*) {
|
||||
out << "bool";
|
||||
} else if (ty->Is<ast::F32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::F32*) {
|
||||
out << "f32";
|
||||
} else if (ty->Is<ast::I32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::I32*) {
|
||||
out << "i32";
|
||||
} else if (auto* mat = ty->As<ast::Matrix>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Matrix* mat) {
|
||||
out << "mat" << mat->columns << "x" << mat->rows;
|
||||
if (auto* el_ty = mat->type) {
|
||||
out << "<";
|
||||
|
@ -401,7 +419,9 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
|||
}
|
||||
out << ">";
|
||||
}
|
||||
} else if (auto* ptr = ty->As<ast::Pointer>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Pointer* ptr) {
|
||||
out << "ptr<" << ptr->storage_class << ", ";
|
||||
if (!EmitType(out, ptr->type)) {
|
||||
return false;
|
||||
|
@ -413,34 +433,58 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
|||
}
|
||||
}
|
||||
out << ">";
|
||||
} else if (auto* atomic = ty->As<ast::Atomic>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Atomic* atomic) {
|
||||
out << "atomic<";
|
||||
if (!EmitType(out, atomic->type)) {
|
||||
return false;
|
||||
}
|
||||
out << ">";
|
||||
} else if (auto* sampler = ty->As<ast::Sampler>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Sampler* sampler) {
|
||||
out << "sampler";
|
||||
|
||||
if (sampler->IsComparison()) {
|
||||
out << "_comparison";
|
||||
}
|
||||
} else if (ty->Is<ast::ExternalTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::ExternalTexture*) {
|
||||
out << "texture_external";
|
||||
} else if (auto* texture = ty->As<ast::Texture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Texture* texture) {
|
||||
out << "texture_";
|
||||
if (texture->Is<ast::DepthTexture>()) {
|
||||
bool ok = Switch(
|
||||
texture,
|
||||
[&](const ast::DepthTexture*) { //
|
||||
out << "depth_";
|
||||
} else if (texture->Is<ast::DepthMultisampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::DepthMultisampledTexture*) { //
|
||||
out << "depth_multisampled_";
|
||||
} else if (texture->Is<ast::SampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::SampledTexture*) { //
|
||||
/* nothing to emit */
|
||||
} else if (texture->Is<ast::MultisampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::MultisampledTexture*) { //
|
||||
out << "multisampled_";
|
||||
} else if (texture->Is<ast::StorageTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StorageTexture*) { //
|
||||
out << "storage_";
|
||||
} else {
|
||||
diagnostics_.add_error(diag::System::Writer, "unknown texture type");
|
||||
return true;
|
||||
},
|
||||
[&](Default) { //
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
"unknown texture type");
|
||||
return false;
|
||||
});
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -469,19 +513,25 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (auto* sampled = texture->As<ast::SampledTexture>()) {
|
||||
return Switch(
|
||||
texture,
|
||||
[&](const ast::SampledTexture* sampled) { //
|
||||
out << "<";
|
||||
if (!EmitType(out, sampled->type)) {
|
||||
return false;
|
||||
}
|
||||
out << ">";
|
||||
} else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::MultisampledTexture* ms) { //
|
||||
out << "<";
|
||||
if (!EmitType(out, ms->type)) {
|
||||
return false;
|
||||
}
|
||||
out << ">";
|
||||
} else if (auto* storage = texture->As<ast::StorageTexture>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StorageTexture* storage) { //
|
||||
out << "<";
|
||||
if (!EmitImageFormat(out, storage->format)) {
|
||||
return false;
|
||||
|
@ -491,11 +541,17 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
|||
return false;
|
||||
}
|
||||
out << ">";
|
||||
}
|
||||
|
||||
} else if (ty->Is<ast::U32>()) {
|
||||
return true;
|
||||
},
|
||||
[&](Default) { //
|
||||
return true;
|
||||
});
|
||||
},
|
||||
[&](const ast::U32*) {
|
||||
out << "u32";
|
||||
} else if (auto* vec = ty->As<ast::Vector>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Vector* vec) {
|
||||
out << "vec" << vec->width;
|
||||
if (auto* el_ty = vec->type) {
|
||||
out << "<";
|
||||
|
@ -504,17 +560,22 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
|
|||
}
|
||||
out << ">";
|
||||
}
|
||||
} else if (ty->Is<ast::Void>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::Void*) {
|
||||
out << "void";
|
||||
} else if (auto* tn = ty->As<ast::TypeName>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::TypeName* tn) {
|
||||
out << program_->Symbols().NameFor(tn->name);
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown type in EmitType: " + std::string(ty->TypeInfo().name));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
|
||||
|
@ -632,7 +693,9 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
|
|||
}
|
||||
first = false;
|
||||
out << "@";
|
||||
if (auto* workgroup = attr->As<ast::WorkgroupAttribute>()) {
|
||||
bool ok = Switch(
|
||||
attr,
|
||||
[&](const ast::WorkgroupAttribute* workgroup) {
|
||||
auto values = workgroup->Values();
|
||||
out << "workgroup_size(";
|
||||
for (int i = 0; i < 3; i++) {
|
||||
|
@ -646,43 +709,75 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
|
|||
}
|
||||
}
|
||||
out << ")";
|
||||
} else if (attr->Is<ast::StructBlockAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StructBlockAttribute*) { //
|
||||
out << "block";
|
||||
} else if (auto* stage = attr->As<ast::StageAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StageAttribute* stage) {
|
||||
out << "stage(" << stage->stage << ")";
|
||||
} else if (auto* binding = attr->As<ast::BindingAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::BindingAttribute* binding) {
|
||||
out << "binding(" << binding->value << ")";
|
||||
} else if (auto* group = attr->As<ast::GroupAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::GroupAttribute* group) {
|
||||
out << "group(" << group->value << ")";
|
||||
} else if (auto* location = attr->As<ast::LocationAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::LocationAttribute* location) {
|
||||
out << "location(" << location->value << ")";
|
||||
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::BuiltinAttribute* builtin) {
|
||||
out << "builtin(" << builtin->builtin << ")";
|
||||
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InterpolateAttribute* interpolate) {
|
||||
out << "interpolate(" << interpolate->type;
|
||||
if (interpolate->sampling != ast::InterpolationSampling::kNone) {
|
||||
out << ", " << interpolate->sampling;
|
||||
}
|
||||
out << ")";
|
||||
} else if (attr->Is<ast::InvariantAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InvariantAttribute*) {
|
||||
out << "invariant";
|
||||
} else if (auto* override_attr = attr->As<ast::OverrideAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::OverrideAttribute* override_deco) {
|
||||
out << "override";
|
||||
if (override_attr->has_value) {
|
||||
out << "(" << override_attr->value << ")";
|
||||
if (override_deco->has_value) {
|
||||
out << "(" << override_deco->value << ")";
|
||||
}
|
||||
} else if (auto* size = attr->As<ast::StructMemberSizeAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StructMemberSizeAttribute* size) {
|
||||
out << "size(" << size->size << ")";
|
||||
} else if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StructMemberAlignAttribute* align) {
|
||||
out << "align(" << align->align << ")";
|
||||
} else if (auto* stride = attr->As<ast::StrideAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::StrideAttribute* stride) {
|
||||
out << "stride(" << stride->stride << ")";
|
||||
} else if (auto* internal = attr->As<ast::InternalAttribute>()) {
|
||||
return true;
|
||||
},
|
||||
[&](const ast::InternalAttribute* internal) {
|
||||
out << "internal(" << internal->InternalName() << ")";
|
||||
} else {
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "Unsupported attribute '" << attr->TypeInfo().name << "'";
|
||||
return false;
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -809,55 +904,36 @@ bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
|
||||
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
|
||||
return EmitAssign(a);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BlockStatement>()) {
|
||||
return EmitBlock(b);
|
||||
}
|
||||
if (auto* b = stmt->As<ast::BreakStatement>()) {
|
||||
return EmitBreak(b);
|
||||
}
|
||||
if (auto* c = stmt->As<ast::CallStatement>()) {
|
||||
return Switch(
|
||||
stmt, //
|
||||
[&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
|
||||
[&](const ast::BlockStatement* b) { return EmitBlock(b); },
|
||||
[&](const ast::BreakStatement* b) { return EmitBreak(b); },
|
||||
[&](const ast::CallStatement* c) {
|
||||
auto out = line();
|
||||
if (!EmitCall(out, c->expr)) {
|
||||
return false;
|
||||
}
|
||||
out << ";";
|
||||
return true;
|
||||
}
|
||||
if (auto* c = stmt->As<ast::ContinueStatement>()) {
|
||||
return EmitContinue(c);
|
||||
}
|
||||
if (auto* d = stmt->As<ast::DiscardStatement>()) {
|
||||
return EmitDiscard(d);
|
||||
}
|
||||
if (auto* f = stmt->As<ast::FallthroughStatement>()) {
|
||||
return EmitFallthrough(f);
|
||||
}
|
||||
if (auto* i = stmt->As<ast::IfStatement>()) {
|
||||
return EmitIf(i);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::LoopStatement>()) {
|
||||
return EmitLoop(l);
|
||||
}
|
||||
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
|
||||
return EmitForLoop(l);
|
||||
}
|
||||
if (auto* r = stmt->As<ast::ReturnStatement>()) {
|
||||
return EmitReturn(r);
|
||||
}
|
||||
if (auto* s = stmt->As<ast::SwitchStatement>()) {
|
||||
return EmitSwitch(s);
|
||||
}
|
||||
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
|
||||
},
|
||||
[&](const ast::ContinueStatement* c) { return EmitContinue(c); },
|
||||
[&](const ast::DiscardStatement* d) { return EmitDiscard(d); },
|
||||
[&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); },
|
||||
[&](const ast::IfStatement* i) { return EmitIf(i); },
|
||||
[&](const ast::LoopStatement* l) { return EmitLoop(l); },
|
||||
[&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
|
||||
[&](const ast::ReturnStatement* r) { return EmitReturn(r); },
|
||||
[&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
|
||||
[&](const ast::VariableDeclStatement* v) {
|
||||
return EmitVariable(line(), v->variable);
|
||||
}
|
||||
|
||||
},
|
||||
[&](Default) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"unknown statement type: " + std::string(stmt->TypeInfo().name));
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
|
||||
|
|
Loading…
Reference in New Issue