tint/resolver: Move from STL to tint::utils containers

Change-Id: I883168a1a84457138de85decb921c5c430c32bd8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108702
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-11-09 20:55:33 +00:00 committed by Dawn LUCI CQ
parent 865c3f8e94
commit 9418152d08
17 changed files with 259 additions and 264 deletions

View File

@ -306,7 +306,6 @@ endif()
message(STATUS "Using python3") message(STATUS "Using python3")
find_package(PythonInterp 3 REQUIRED) find_package(PythonInterp 3 REQUIRED)
################################################################################ ################################################################################
# common_compile_options - sets compiler and linker options common for dawn and # common_compile_options - sets compiler and linker options common for dawn and
# tint on the given target # tint on the given target
@ -347,6 +346,12 @@ endif()
# Dawn's public and internal "configs" # Dawn's public and internal "configs"
################################################################################ ################################################################################
set(IS_DEBUG_BUILD 0)
string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type)
if ((NOT ${build_type} STREQUAL "RELEASE") AND (NOT ${build_type} STREQUAL "RELWITHDEBINFO"))
set(IS_DEBUG_BUILD 1)
endif()
# The public config contains only the include paths for the Dawn headers. # The public config contains only the include paths for the Dawn headers.
add_library(dawn_public_config INTERFACE) add_library(dawn_public_config INTERFACE)
target_include_directories(dawn_public_config INTERFACE target_include_directories(dawn_public_config INTERFACE
@ -363,7 +368,7 @@ target_include_directories(dawn_internal_config INTERFACE
target_link_libraries(dawn_internal_config INTERFACE dawn_public_config) target_link_libraries(dawn_internal_config INTERFACE dawn_public_config)
# Compile definitions for the internal config # Compile definitions for the internal config
if (DAWN_ALWAYS_ASSERT OR $<CONFIG:Debug>) if (DAWN_ALWAYS_ASSERT OR IS_DEBUG_BUILD)
target_compile_definitions(dawn_internal_config INTERFACE "DAWN_ENABLE_ASSERTS") target_compile_definitions(dawn_internal_config INTERFACE "DAWN_ENABLE_ASSERTS")
endif() endif()
if (DAWN_ENABLE_D3D12) if (DAWN_ENABLE_D3D12)

View File

@ -1357,9 +1357,8 @@ if(TINT_BUILD_TESTS)
# overflows when resolving deeply nested expression chains or statements. # overflows when resolving deeply nested expression chains or statements.
# Production builds neither use MSVC nor debug, so just bump the stack size # Production builds neither use MSVC nor debug, so just bump the stack size
# for this build combination. # for this build combination.
string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type) if (IS_DEBUG_BUILD)
if ((NOT ${build_type} STREQUAL "RELEASE") AND (NOT ${build_type} STREQUAL "RELWITHDEBINFO")) target_link_options(tint_unittests PRIVATE "/STACK:4194304") # 4MB, default is 1MB
target_link_options(tint_unittests PRIVATE "/STACK 2097152") # 2MB, default is 1MB
endif() endif()
else() else()
target_compile_options(tint_unittests PRIVATE target_compile_options(tint_unittests PRIVATE

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <unordered_set>
#include "src/tint/ast/builtin_texture_helper_test.h" #include "src/tint/ast/builtin_texture_helper_test.h"
#include "src/tint/resolver/resolver_test_helper.h" #include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/type_initializer.h" #include "src/tint/sem/type_initializer.h"

View File

@ -20,7 +20,6 @@
#include <optional> #include <optional>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <unordered_map>
#include <utility> #include <utility>
#include "src/tint/program_builder.h" #include "src/tint/program_builder.h"
@ -463,18 +462,18 @@ const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return nullptr; return nullptr;
}, },
[&](const sem::Struct* s) -> const ImplConstant* { [&](const sem::Struct* s) -> const ImplConstant* {
std::unordered_map<const sem::Type*, const ImplConstant*> zero_by_type; utils::Hashmap<const sem::Type*, const ImplConstant*, 8> zero_by_type;
utils::Vector<const sem::Constant*, 4> zeros; utils::Vector<const sem::Constant*, 4> zeros;
zeros.Reserve(s->Members().size()); zeros.Reserve(s->Members().size());
for (auto* member : s->Members()) { for (auto* member : s->Members()) {
auto* zero = utils::GetOrCreate(zero_by_type, member->Type(), auto* zero = zero_by_type.GetOrCreate(
[&] { return ZeroValue(builder, member->Type()); }); member->Type(), [&] { return ZeroValue(builder, member->Type()); });
if (!zero) { if (!zero) {
return nullptr; return nullptr;
} }
zeros.Push(zero); zeros.Push(zero);
} }
if (zero_by_type.size() == 1) { if (zero_by_type.Count() == 1) {
// All members were of the same type, so the zero value is the same for all members. // All members were of the same type, so the zero value is the same for all members.
return builder.create<Splat>(type, zeros[0], s->Members().size()); return builder.create<Splat>(type, zeros[0], s->Members().size());
} }

View File

@ -15,7 +15,6 @@
#include "src/tint/resolver/dependency_graph.h" #include "src/tint/resolver/dependency_graph.h"
#include <string> #include <string>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -117,7 +116,7 @@ struct DependencyEdgeCmp {
/// A map of DependencyEdge to DependencyInfo /// A map of DependencyEdge to DependencyInfo
using DependencyEdges = using DependencyEdges =
std::unordered_map<DependencyEdge, DependencyInfo, DependencyEdgeCmp, DependencyEdgeCmp>; utils::Hashmap<DependencyEdge, DependencyInfo, 64, DependencyEdgeCmp, DependencyEdgeCmp>;
/// Global describes a module-scope variable, type or function. /// Global describes a module-scope variable, type or function.
struct Global { struct Global {
@ -126,11 +125,11 @@ struct Global {
/// The declaration ast::Node /// The declaration ast::Node
const ast::Node* node; const ast::Node* node;
/// A list of dependencies that this global depends on /// A list of dependencies that this global depends on
std::vector<Global*> deps; utils::Vector<Global*, 8> deps;
}; };
/// A map of global name to Global /// A map of global name to Global
using GlobalMap = std::unordered_map<Symbol, Global*>; using GlobalMap = utils::Hashmap<Symbol, Global*, 16>;
/// Raises an ICE that a global ast::Node type was not handled by this system. /// Raises an ICE that a global ast::Node type was not handled by this system.
void UnhandledNode(diag::List& diagnostics, const ast::Node* node) { void UnhandledNode(diag::List& diagnostics, const ast::Node* node) {
@ -170,7 +169,7 @@ class DependencyScanner {
dependency_edges_(edges) { dependency_edges_(edges) {
// Register all the globals at global-scope // Register all the globals at global-scope
for (auto it : globals_by_name) { for (auto it : globals_by_name) {
scope_stack_.Set(it.first, it.second->node); scope_stack_.Set(it.key, it.value->node);
} }
} }
@ -232,7 +231,7 @@ class DependencyScanner {
for (auto* param : func->params) { for (auto* param : func->params) {
if (auto* shadows = scope_stack_.Get(param->symbol)) { if (auto* shadows = scope_stack_.Get(param->symbol)) {
graph_.shadows.emplace(param, shadows); graph_.shadows.Add(param, shadows);
} }
Declare(param->symbol, param); Declare(param->symbol, param);
} }
@ -306,7 +305,7 @@ class DependencyScanner {
}, },
[&](const ast::VariableDeclStatement* v) { [&](const ast::VariableDeclStatement* v) {
if (auto* shadows = scope_stack_.Get(v->variable->symbol)) { if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
graph_.shadows.emplace(v->variable, shadows); graph_.shadows.Add(v->variable, shadows);
} }
TraverseType(v->variable->type); TraverseType(v->variable->type);
TraverseExpression(v->variable->initializer); TraverseExpression(v->variable->initializer);
@ -473,16 +472,14 @@ class DependencyScanner {
} }
} }
if (auto* global = utils::Lookup(globals_, to); global && global->node == resolved) { if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_ if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
.emplace(DependencyEdge{current_global_, global}, DependencyInfo{from->source, action})) {
DependencyInfo{from->source, action}) current_global_->deps.Push(*global);
.second) {
current_global_->deps.emplace_back(global);
} }
} }
graph_.resolved_symbols.emplace(from, resolved); graph_.resolved_symbols.Add(from, resolved);
} }
/// @returns true if `name` is the name of a builtin function /// @returns true if `name` is the name of a builtin function
@ -497,7 +494,7 @@ class DependencyScanner {
source); source);
} }
using VariableMap = std::unordered_map<Symbol, const ast::Variable*>; using VariableMap = utils::Hashmap<Symbol, const ast::Variable*, 32>;
const SymbolTable& symbols_; const SymbolTable& symbols_;
const GlobalMap& globals_; const GlobalMap& globals_;
diag::List& diagnostics_; diag::List& diagnostics_;
@ -520,7 +517,7 @@ struct DependencyAnalysis {
/// @returns true if analysis found no errors, otherwise false. /// @returns true if analysis found no errors, otherwise false.
bool Run(const ast::Module& module) { bool Run(const ast::Module& module) {
// Reserve container memory // Reserve container memory
graph_.resolved_symbols.reserve(module.GlobalDeclarations().Length()); graph_.resolved_symbols.Reserve(module.GlobalDeclarations().Length());
sorted_.Reserve(module.GlobalDeclarations().Length()); sorted_.Reserve(module.GlobalDeclarations().Length());
// Collect all the named globals from the AST module // Collect all the named globals from the AST module
@ -589,9 +586,9 @@ struct DependencyAnalysis {
for (auto* node : module.GlobalDeclarations()) { for (auto* node : module.GlobalDeclarations()) {
auto* global = allocator_.Create(node); auto* global = allocator_.Create(node);
if (auto symbol = SymbolOf(node); symbol.IsValid()) { if (auto symbol = SymbolOf(node); symbol.IsValid()) {
globals_.emplace(symbol, global); globals_.Add(symbol, global);
} }
declaration_order_.emplace_back(global); declaration_order_.Push(global);
} }
} }
@ -625,16 +622,16 @@ struct DependencyAnalysis {
return; return;
} }
std::vector<Entry> stack{Entry{root, 0}}; utils::Vector<Entry, 16> stack{Entry{root, 0}};
while (true) { while (true) {
auto& entry = stack.back(); auto& entry = stack.Back();
// Have we exhausted the dependencies of entry.global? // Have we exhausted the dependencies of entry.global?
if (entry.dep_idx < entry.global->deps.size()) { if (entry.dep_idx < entry.global->deps.Length()) {
// No, there's more dependencies to traverse. // No, there's more dependencies to traverse.
auto& dep = entry.global->deps[entry.dep_idx]; auto& dep = entry.global->deps[entry.dep_idx];
// Does the caller want to enter this dependency? // Does the caller want to enter this dependency?
if (enter(dep)) { // Yes. if (enter(dep)) { // Yes.
stack.push_back(Entry{dep, 0}); // Enter the dependency. stack.Push(Entry{dep, 0}); // Enter the dependency.
} else { } else {
entry.dep_idx++; // No. Skip this node. entry.dep_idx++; // No. Skip this node.
} }
@ -643,11 +640,11 @@ struct DependencyAnalysis {
// Exit this global, pop the stack, and if there's another parent node, // Exit this global, pop the stack, and if there's another parent node,
// increment its dependency index, and loop again. // increment its dependency index, and loop again.
exit(entry.global); exit(entry.global);
stack.pop_back(); stack.Pop();
if (stack.empty()) { if (stack.IsEmpty()) {
return; // All done. return; // All done.
} }
stack.back().dep_idx++; stack.Back().dep_idx++;
} }
} }
} }
@ -707,9 +704,8 @@ struct DependencyAnalysis {
/// of global `from` depending on `to`. /// of global `from` depending on `to`.
/// @note will raise an ICE if the edge is not found. /// @note will raise an ICE if the edge is not found.
DependencyInfo DepInfoFor(const Global* from, const Global* to) const { DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
auto it = dependency_edges_.find(DependencyEdge{from, to}); if (auto info = dependency_edges_.Find(DependencyEdge{from, to})) {
if (it != dependency_edges_.end()) { return *info;
return it->second;
} }
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "failed to find dependency info for edge: '" << NameOf(from->node) << "' -> '" << "failed to find dependency info for edge: '" << NameOf(from->node) << "' -> '"
@ -762,7 +758,7 @@ struct DependencyAnalysis {
printf("------ dependencies ------ \n"); printf("------ dependencies ------ \n");
for (auto* node : sorted_) { for (auto* node : sorted_) {
auto symbol = SymbolOf(node); auto symbol = SymbolOf(node);
auto* global = globals_.at(symbol); auto* global = *globals_.Find(symbol);
printf("%s depends on:\n", symbols_.NameFor(symbol).c_str()); printf("%s depends on:\n", symbols_.NameFor(symbol).c_str());
for (auto* dep : global->deps) { for (auto* dep : global->deps) {
printf(" %s\n", NameOf(dep->node).c_str()); printf(" %s\n", NameOf(dep->node).c_str());
@ -791,7 +787,7 @@ struct DependencyAnalysis {
DependencyEdges dependency_edges_; DependencyEdges dependency_edges_;
/// Globals in declaration order. Populated by GatherGlobals(). /// Globals in declaration order. Populated by GatherGlobals().
std::vector<Global*> declaration_order_; utils::Vector<Global*, 64> declaration_order_;
/// Globals in sorted dependency order. Populated by SortGlobals(). /// Globals in sorted dependency order. Populated by SortGlobals().
utils::UniqueVector<const ast::Node*, 64> sorted_; utils::UniqueVector<const ast::Node*, 64> sorted_;

View File

@ -15,11 +15,11 @@
#ifndef SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_ #ifndef SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
#define SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_ #define SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
#include <unordered_map>
#include <vector> #include <vector>
#include "src/tint/ast/module.h" #include "src/tint/ast/module.h"
#include "src/tint/diagnostic/diagnostic.h" #include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/utils/hashmap.h"
namespace tint::resolver { namespace tint::resolver {
@ -50,13 +50,13 @@ struct DependencyGraph {
/// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or /// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
/// variable that declares the symbol. /// variable that declares the symbol.
std::unordered_map<const ast::Node*, const ast::Node*> resolved_symbols; utils::Hashmap<const ast::Node*, const ast::Node*, 64> resolved_symbols;
/// Map of ast::Variable to a type, function, or variable that is shadowed by /// Map of ast::Variable to a type, function, or variable that is shadowed by
/// the variable key. A declaration (X) shadows another (Y) if X and Y use /// the variable key. A declaration (X) shadows another (Y) if X and Y use
/// the same symbol, and X is declared in a sub-scope of the scope that /// the same symbol, and X is declared in a sub-scope of the scope that
/// declares Y. /// declares Y.
std::unordered_map<const ast::Variable*, const ast::Node*> shadows; utils::Hashmap<const ast::Variable*, const ast::Node*, 16> shadows;
}; };
} // namespace tint::resolver } // namespace tint::resolver

View File

@ -1128,9 +1128,10 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
if (expect_pass) { if (expect_pass) {
// Check that the use resolves to the declaration // Check that the use resolves to the declaration
auto* resolved_symbol = graph.resolved_symbols[use]; auto* resolved_symbol = graph.resolved_symbols.Find(use);
EXPECT_EQ(resolved_symbol, decl) ASSERT_NE(resolved_symbol, nullptr);
<< "resolved: " << (resolved_symbol ? resolved_symbol->TypeInfo().name : "<null>") EXPECT_EQ(*resolved_symbol, decl)
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n" << "\n"
<< "decl: " << decl->TypeInfo().name; << "decl: " << decl->TypeInfo().name;
} }
@ -1177,7 +1178,10 @@ TEST_P(ResolverDependencyShadowTest, Test) {
: helper.parameters[0]; : helper.parameters[0];
helper.Build(); helper.Build();
EXPECT_EQ(Build().shadows[inner_var], outer); auto shadows = Build().shadows;
auto* shadow = shadows.Find(inner_var);
ASSERT_NE(shadow, nullptr);
EXPECT_EQ(*shadow, outer);
} }
INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal, INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal,
@ -1308,8 +1312,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
auto graph = Build(); auto graph = Build();
for (auto use : symbol_uses) { for (auto use : symbol_uses) {
auto* resolved_symbol = graph.resolved_symbols[use.use]; auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
EXPECT_EQ(resolved_symbol, use.decl) << use.where; ASSERT_NE(resolved_symbol, nullptr) << use.where;
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
} }
} }

View File

@ -16,7 +16,6 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <unordered_map>
#include <utility> #include <utility>
#include "src/tint/ast/binary_expression.h" #include "src/tint/ast/binary_expression.h"
@ -36,7 +35,7 @@
#include "src/tint/sem/type_conversion.h" #include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/type_initializer.h" #include "src/tint/sem/type_initializer.h"
#include "src/tint/utils/hash.h" #include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h" #include "src/tint/utils/hashmap.h"
#include "src/tint/utils/math.h" #include "src/tint/utils/math.h"
#include "src/tint/utils/scoped_assignment.h" #include "src/tint/utils/scoped_assignment.h"
@ -1114,10 +1113,10 @@ class Impl : public IntrinsicTable {
ProgramBuilder& builder; ProgramBuilder& builder;
Matchers matchers; Matchers matchers;
std::unordered_map<IntrinsicPrototype, sem::Builtin*, IntrinsicPrototype::Hasher> builtins; utils::Hashmap<IntrinsicPrototype, sem::Builtin*, 64, IntrinsicPrototype::Hasher> builtins;
std::unordered_map<IntrinsicPrototype, sem::TypeInitializer*, IntrinsicPrototype::Hasher> utils::Hashmap<IntrinsicPrototype, sem::TypeInitializer*, 16, IntrinsicPrototype::Hasher>
initializers; initializers;
std::unordered_map<IntrinsicPrototype, sem::TypeConversion*, IntrinsicPrototype::Hasher> utils::Hashmap<IntrinsicPrototype, sem::TypeConversion*, 16, IntrinsicPrototype::Hasher>
converters; converters;
}; };
@ -1185,7 +1184,7 @@ Impl::Builtin Impl::Lookup(sem::BuiltinType builtin_type,
} }
// De-duplicate builtins that are identical. // De-duplicate builtins that are identical.
auto* sem = utils::GetOrCreate(builtins, match, [&] { auto* sem = builtins.GetOrCreate(match, [&] {
utils::Vector<sem::Parameter*, kNumFixedParams> params; utils::Vector<sem::Parameter*, kNumFixedParams> params;
params.Reserve(match.parameters.Length()); params.Reserve(match.parameters.Length());
for (auto& p : match.parameters) { for (auto& p : match.parameters) {
@ -1396,7 +1395,7 @@ IntrinsicTable::InitOrConv Impl::Lookup(InitConvIntrinsic type,
} }
auto eval_stage = match.overload->const_eval_fn ? sem::EvaluationStage::kConstant auto eval_stage = match.overload->const_eval_fn ? sem::EvaluationStage::kConstant
: sem::EvaluationStage::kRuntime; : sem::EvaluationStage::kRuntime;
auto* target = utils::GetOrCreate(initializers, match, [&]() { auto* target = initializers.GetOrCreate(match, [&]() {
return builder.create<sem::TypeInitializer>(match.return_type, std::move(params), return builder.create<sem::TypeInitializer>(match.return_type, std::move(params),
eval_stage); eval_stage);
}); });
@ -1404,7 +1403,7 @@ IntrinsicTable::InitOrConv Impl::Lookup(InitConvIntrinsic type,
} }
// Conversion. // Conversion.
auto* target = utils::GetOrCreate(converters, match, [&]() { auto* target = converters.GetOrCreate(match, [&]() {
auto param = builder.create<sem::Parameter>( auto param = builder.create<sem::Parameter>(
nullptr, 0u, match.parameters[0].type, ast::AddressSpace::kNone, nullptr, 0u, match.parameters[0].type, ast::AddressSpace::kNone,
ast::Access::kUndefined, match.parameters[0].usage); ast::Access::kUndefined, match.parameters[0].usage);

View File

@ -482,7 +482,7 @@ sem::Variable* Resolver::Override(const ast::Override* v) {
sem->SetOverrideId(o); sem->SetOverrideId(o);
// Track the constant IDs that are specified in the shader. // Track the constant IDs that are specified in the shader.
override_ids_.emplace(o, sem); override_ids_.Add(o, sem);
} }
builder_->Sem().Add(v, sem); builder_->Sem().Add(v, sem);
@ -842,7 +842,7 @@ bool Resolver::AllocateOverridableConstantIds() {
id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId(); id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId();
} else { } else {
// No ID was specified, so allocate the next available ID. // No ID was specified, so allocate the next available ID.
while (!ids_exhausted && override_ids_.count(next_id)) { while (!ids_exhausted && override_ids_.Contains(next_id)) {
increment_next_id(); increment_next_id();
} }
if (ids_exhausted) { if (ids_exhausted) {
@ -864,9 +864,9 @@ bool Resolver::AllocateOverridableConstantIds() {
void Resolver::SetShadows() { void Resolver::SetShadows() {
for (auto it : dependencies_.shadows) { for (auto it : dependencies_.shadows) {
Switch( Switch(
sem_.Get(it.first), // sem_.Get(it.key), //
[&](sem::LocalVariable* local) { local->SetShadows(sem_.Get(it.second)); }, [&](sem::LocalVariable* local) { local->SetShadows(sem_.Get(it.value)); },
[&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.second)); }); [&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.value)); });
} }
} }
@ -923,7 +923,7 @@ sem::Statement* Resolver::StaticAssert(const ast::StaticAssert* assertion) {
sem::Function* Resolver::Function(const ast::Function* decl) { sem::Function* Resolver::Function(const ast::Function* decl) {
uint32_t parameter_index = 0; uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names; utils::Hashmap<Symbol, Source, 8> parameter_names;
utils::Vector<sem::Parameter*, 8> parameters; utils::Vector<sem::Parameter*, 8> parameters;
// Resolve all the parameters // Resolve all the parameters
@ -931,11 +931,10 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
Mark(param); Mark(param);
{ // Check the parameter name is unique for the function { // Check the parameter name is unique for the function
auto emplaced = parameter_names.emplace(param->symbol, param->source); if (auto added = parameter_names.Add(param->symbol, param->source); !added) {
if (!emplaced.second) {
auto name = builder_->Symbols().NameFor(param->symbol); auto name = builder_->Symbols().NameFor(param->symbol);
AddError("redefinition of parameter '" + name + "'", param->source); AddError("redefinition of parameter '" + name + "'", param->source);
AddNote("previous definition is here", emplaced.first->second); AddNote("previous definition is here", *added.value);
return nullptr; return nullptr;
} }
} }
@ -1031,7 +1030,7 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
} }
if (decl->IsEntryPoint()) { if (decl->IsEntryPoint()) {
entry_points_.emplace_back(func); entry_points_.Push(func);
} }
if (decl->body) { if (decl->body) {
@ -1850,8 +1849,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
[&](const sem::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); }, [&](const sem::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); },
[&](const sem::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); }, [&](const sem::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); },
[&](const sem::Array* arr) -> sem::Call* { [&](const sem::Array* arr) -> sem::Call* {
auto* call_target = utils::GetOrCreate( auto* call_target = array_inits_.GetOrCreate(
array_inits_, ArrayInitializerSig{{arr, args.Length(), args_stage}}, ArrayInitializerSig{{arr, args.Length(), args_stage}},
[&]() -> sem::TypeInitializer* { [&]() -> sem::TypeInitializer* {
auto params = utils::Transform(args, [&](auto, size_t i) { auto params = utils::Transform(args, [&](auto, size_t i) {
return builder_->create<sem::Parameter>( return builder_->create<sem::Parameter>(
@ -1877,8 +1876,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return call; return call;
}, },
[&](const sem::Struct* str) -> sem::Call* { [&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate( auto* call_target = struct_inits_.GetOrCreate(
struct_inits_, StructInitializerSig{{str, args.Length(), args_stage}}, StructInitializerSig{{str, args.Length(), args_stage}},
[&]() -> sem::TypeInitializer* { [&]() -> sem::TypeInitializer* {
utils::Vector<const sem::Parameter*, 8> params; utils::Vector<const sem::Parameter*, 8> params;
params.Resize(std::min(args.Length(), str->Members().size())); params.Resize(std::min(args.Length(), str->Members().size()));
@ -1981,9 +1980,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
AddError( AddError(
"cannot infer common array element type from initializer arguments", "cannot infer common array element type from initializer arguments",
expr->source); expr->source);
std::unordered_set<const sem::Type*> types; utils::Hashset<const sem::Type*, 8> types;
for (size_t i = 0; i < args.Length(); i++) { for (size_t i = 0; i < args.Length(); i++) {
if (types.emplace(args[i]->Type()).second) { if (types.Add(args[i]->Type())) {
AddNote("argument " + std::to_string(i) + " is of type '" + AddNote("argument " + std::to_string(i) + " is of type '" +
sem_.TypeNameOf(args[i]->Type()) + "'", sem_.TypeNameOf(args[i]->Type()) + "'",
args[i]->Declaration()->source); args[i]->Declaration()->source);
@ -2687,11 +2686,10 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
} }
if (el_ty->Is<sem::Atomic>()) { if (el_ty->Is<sem::Atomic>()) {
atomic_composite_info_.emplace(out, arr->type->source); atomic_composite_info_.Add(out, &arr->type->source);
} else { } else {
auto found = atomic_composite_info_.find(el_ty); if (auto* found = atomic_composite_info_.Find(el_ty)) {
if (found != atomic_composite_info_.end()) { atomic_composite_info_.Add(out, *found);
atomic_composite_info_.emplace(out, found->second);
} }
} }
@ -2832,15 +2830,14 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
// validation. // validation.
uint64_t struct_size = 0; uint64_t struct_size = 0;
uint64_t struct_align = 1; uint64_t struct_align = 1;
std::unordered_map<Symbol, const ast::StructMember*> member_map; utils::Hashmap<Symbol, const ast::StructMember*, 8> member_map;
for (auto* member : str->members) { for (auto* member : str->members) {
Mark(member); Mark(member);
auto result = member_map.emplace(member->symbol, member); if (auto added = member_map.Add(member->symbol, member); !added) {
if (!result.second) {
AddError("redefinition of '" + builder_->Symbols().NameFor(member->symbol) + "'", AddError("redefinition of '" + builder_->Symbols().NameFor(member->symbol) + "'",
member->source); member->source);
AddNote("previous definition is here", result.first->second->source); AddNote("previous definition is here", (*added.value)->source);
return nullptr; return nullptr;
} }
@ -3027,12 +3024,11 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
for (size_t i = 0; i < sem_members.size(); i++) { for (size_t i = 0; i < sem_members.size(); i++) {
auto* mem_type = sem_members[i]->Type(); auto* mem_type = sem_members[i]->Type();
if (mem_type->Is<sem::Atomic>()) { if (mem_type->Is<sem::Atomic>()) {
atomic_composite_info_.emplace(out, sem_members[i]->Declaration()->source); atomic_composite_info_.Add(out, &sem_members[i]->Declaration()->source);
break; break;
} else { } else {
auto found = atomic_composite_info_.find(mem_type); if (auto* found = atomic_composite_info_.Find(mem_type)) {
if (found != atomic_composite_info_.end()) { atomic_composite_info_.Add(out, *found);
atomic_composite_info_.emplace(out, found->second);
break; break;
} }
} }

View File

@ -18,8 +18,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -434,13 +432,13 @@ class Resolver {
SemHelper sem_; SemHelper sem_;
Validator validator_; Validator validator_;
ast::Extensions enabled_extensions_; ast::Extensions enabled_extensions_;
std::vector<sem::Function*> entry_points_; utils::Vector<sem::Function*, 8> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_; utils::Hashmap<const sem::Type*, const Source*, 8> atomic_composite_info_;
utils::Bitset<0> marked_; utils::Bitset<0> marked_;
ExprEvalStageConstraint expr_eval_stage_constraint_; ExprEvalStageConstraint expr_eval_stage_constraint_;
std::unordered_map<OverrideId, const sem::Variable*> override_ids_; utils::Hashmap<OverrideId, const sem::Variable*, 8> override_ids_;
std::unordered_map<ArrayInitializerSig, sem::CallTarget*> array_inits_; utils::Hashmap<ArrayInitializerSig, sem::CallTarget*, 8> array_inits_;
std::unordered_map<StructInitializerSig, sem::CallTarget*> struct_inits_; utils::Hashmap<StructInitializerSig, sem::CallTarget*, 8> struct_inits_;
sem::Function* current_function_ = nullptr; sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr; sem::Statement* current_statement_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr; sem::CompoundStatement* current_compound_statement_ = nullptr;

View File

@ -54,8 +54,8 @@ class SemHelper {
/// @param node the node to retrieve /// @param node the node to retrieve
template <typename SEM = sem::Node> template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const { SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node); auto* resolved = dependencies_.resolved_symbols.Find(node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved)) : nullptr; return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
} }
/// @returns the resolved type of the ast::Expression `expr` /// @returns the resolved type of the ast::Expression `expr`

View File

@ -16,8 +16,6 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -139,7 +137,7 @@ struct ParameterInfo {
bool pointer_may_become_non_uniform = false; bool pointer_may_become_non_uniform = false;
/// The parameters that are required to be uniform for the contents of this pointer parameter to /// The parameters that are required to be uniform for the contents of this pointer parameter to
/// be uniform at function exit. /// be uniform at function exit.
std::vector<const sem::Parameter*> pointer_param_output_sources; utils::Vector<const sem::Parameter*, 8> pointer_param_output_sources;
/// The node in the graph that corresponds to this parameter's initial value. /// The node in the graph that corresponds to this parameter's initial value.
Node* init_value; Node* init_value;
/// The node in the graph that corresponds to this parameter's output value (or nullptr). /// The node in the graph that corresponds to this parameter's output value (or nullptr).
@ -166,7 +164,7 @@ struct FunctionInfo {
} }
// Create nodes for parameters. // Create nodes for parameters.
parameters.resize(func->params.Length()); parameters.Resize(func->params.Length());
for (size_t i = 0; i < func->params.Length(); i++) { for (size_t i = 0; i < func->params.Length(); i++) {
auto* param = func->params[i]; auto* param = func->params[i];
auto param_name = builder->Symbols().NameFor(param->symbol); auto param_name = builder->Symbols().NameFor(param->symbol);
@ -177,7 +175,7 @@ struct FunctionInfo {
if (sem->Type()->Is<sem::Pointer>()) { if (sem->Type()->Is<sem::Pointer>()) {
node_init = CreateNode("ptrparam_" + name + "_init"); node_init = CreateNode("ptrparam_" + name + "_init");
parameters[i].pointer_return_value = CreateNode("ptrparam_" + name + "_return"); parameters[i].pointer_return_value = CreateNode("ptrparam_" + name + "_return");
local_var_decls.insert(sem); local_var_decls.Add(sem);
} else { } else {
node_init = CreateNode("param_" + name); node_init = CreateNode("param_" + name);
} }
@ -194,7 +192,7 @@ struct FunctionInfo {
/// The function's uniformity effects. /// The function's uniformity effects.
FunctionTag function_tag; FunctionTag function_tag;
/// The uniformity requirements of the function's parameters. /// The uniformity requirements of the function's parameters.
std::vector<ParameterInfo> parameters; utils::Vector<ParameterInfo, 8> parameters;
/// The control flow graph. /// The control flow graph.
utils::BlockAllocator<Node> nodes; utils::BlockAllocator<Node> nodes;
@ -213,24 +211,31 @@ struct FunctionInfo {
/// The set of a local read-write vars that are in scope at any given point in the process. /// The set of a local read-write vars that are in scope at any given point in the process.
/// Includes pointer parameters. /// Includes pointer parameters.
std::unordered_set<const sem::Variable*> local_var_decls; utils::Hashset<const sem::Variable*, 8> local_var_decls;
/// The set of partial pointer variables - pointers that point to a subobject (into an array or /// The set of partial pointer variables - pointers that point to a subobject (into an array or
/// struct). /// struct).
std::unordered_set<const sem::Variable*> partial_ptrs; utils::Hashset<const sem::Variable*, 4> partial_ptrs;
/// LoopSwitchInfo tracks information about the value of variables for a control flow construct. /// LoopSwitchInfo tracks information about the value of variables for a control flow construct.
struct LoopSwitchInfo { struct LoopSwitchInfo {
/// The type of this control flow construct. /// The type of this control flow construct.
std::string type; std::string type;
/// The input values for local variables at the start of this construct. /// The input values for local variables at the start of this construct.
std::unordered_map<const sem::Variable*, Node*> var_in_nodes; utils::Hashmap<const sem::Variable*, Node*, 8> var_in_nodes;
/// The exit values for local variables at the end of this construct. /// The exit values for local variables at the end of this construct.
std::unordered_map<const sem::Variable*, Node*> var_exit_nodes; utils::Hashmap<const sem::Variable*, Node*, 8> var_exit_nodes;
}; };
/// Map from control flow statements to the corresponding LoopSwitchInfo structure. /// @returns a LoopSwitchInfo for the given statement, allocating the LoopSwitchInfo if this is
std::unordered_map<const sem::Statement*, LoopSwitchInfo> loop_switch_infos; /// the first call with the given statement.
LoopSwitchInfo& LoopSwitchInfoFor(const sem::Statement* stmt) {
return *loop_switch_infos.GetOrCreate(stmt,
[&] { return loop_switch_info_allocator.Create(); });
}
/// Disassociates the LoopSwitchInfo for the given statement.
void RemoveLoopSwitchInfoFor(const sem::Statement* stmt) { loop_switch_infos.Remove(stmt); }
/// Create a new node. /// Create a new node.
/// @param tag a tag used to identify the node for debugging purposes /// @param tag a tag used to identify the node for debugging purposes
@ -263,7 +268,13 @@ struct FunctionInfo {
private: private:
/// A list of tags that have already been used within the current function. /// A list of tags that have already been used within the current function.
std::unordered_set<std::string> tags_; utils::Hashset<std::string, 8> tags_;
/// Map from control flow statements to the corresponding LoopSwitchInfo structure.
utils::Hashmap<const sem::Statement*, LoopSwitchInfo*, 8> loop_switch_infos;
/// Allocator of LoopSwitchInfos
utils::BlockAllocator<LoopSwitchInfo> loop_switch_info_allocator;
}; };
/// UniformityGraph is used to analyze the uniformity requirements and effects of functions in a /// UniformityGraph is used to analyze the uniformity requirements and effects of functions in a
@ -312,7 +323,7 @@ class UniformityGraph {
diag::List& diagnostics_; diag::List& diagnostics_;
/// Map of analyzed function results. /// Map of analyzed function results.
std::unordered_map<const ast::Function*, FunctionInfo> functions_; utils::Hashmap<const ast::Function*, FunctionInfo, 8> functions_;
/// The function currently being analyzed. /// The function currently being analyzed.
FunctionInfo* current_function_; FunctionInfo* current_function_;
@ -329,8 +340,7 @@ class UniformityGraph {
/// @param func the function to process /// @param func the function to process
/// @returns true if there are no uniformity issues, false otherwise /// @returns true if there are no uniformity issues, false otherwise
bool ProcessFunction(const ast::Function* func) { bool ProcessFunction(const ast::Function* func) {
functions_.emplace(func, FunctionInfo(func, builder_)); current_function_ = functions_.Add(func, FunctionInfo(func, builder_)).value;
current_function_ = &functions_.at(func);
// Process function body. // Process function body.
if (func->body) { if (func->body) {
@ -410,7 +420,7 @@ class UniformityGraph {
for (size_t j = 0; j < func->params.Length(); j++) { for (size_t j = 0; j < func->params.Length(); j++) {
auto* param_source = sem_.Get<sem::Parameter>(func->params[j]); auto* param_source = sem_.Get<sem::Parameter>(func->params[j]);
if (reachable.Contains(current_function_->parameters[j].init_value)) { if (reachable.Contains(current_function_->parameters[j].init_value)) {
current_function_->parameters[i].pointer_param_output_sources.push_back( current_function_->parameters[i].pointer_param_output_sources.Push(
param_source); param_source);
} }
} }
@ -439,7 +449,7 @@ class UniformityGraph {
}, },
[&](const ast::BlockStatement* b) { [&](const ast::BlockStatement* b) {
std::unordered_map<const sem::Variable*, Node*> scoped_assignments; utils::Hashmap<const sem::Variable*, Node*, 8> scoped_assignments;
{ {
// Push a new scope for variable assignments in the block. // Push a new scope for variable assignments in the block.
current_function_->variables.Push(); current_function_->variables.Push();
@ -472,13 +482,13 @@ class UniformityGraph {
if (behaviors.Contains(sem::Behavior::kNext) || if (behaviors.Contains(sem::Behavior::kNext) ||
behaviors.Contains(sem::Behavior::kFallthrough)) { behaviors.Contains(sem::Behavior::kFallthrough)) {
for (auto var : scoped_assignments) { for (auto var : scoped_assignments) {
current_function_->variables.Set(var.first, var.second); current_function_->variables.Set(var.key, var.value);
} }
} }
// Remove any variables declared in this scope from the set of in-scope variables. // Remove any variables declared in this scope from the set of in-scope variables.
for (auto decl : sem_.Get<sem::BlockStatement>(b)->Decls()) { for (auto decl : sem_.Get<sem::BlockStatement>(b)->Decls()) {
current_function_->local_var_decls.erase(decl.value.variable); current_function_->local_var_decls.Remove(decl.value.variable);
} }
return cf; return cf;
@ -489,8 +499,8 @@ class UniformityGraph {
auto* parent = sem_.Get(b) auto* parent = sem_.Get(b)
->FindFirstParent<sem::SwitchStatement, sem::LoopStatement, ->FindFirstParent<sem::SwitchStatement, sem::LoopStatement,
sem::ForLoopStatement, sem::WhileStatement>(); sem::ForLoopStatement, sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent); auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop/switch exit nodes. // Propagate variable values to the loop/switch exit nodes.
for (auto* var : current_function_->local_var_decls) { for (auto* var : current_function_->local_var_decls) {
@ -502,7 +512,7 @@ class UniformityGraph {
} }
// Add an edge from the variable exit node to its value at this point. // Add an edge from the variable exit node to its value at this point.
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit"); return CreateNode(name + "_value_" + info.type + "_exit");
}); });
@ -526,8 +536,7 @@ class UniformityGraph {
{ {
auto* parent = sem_.Get(b)->FindFirstParent<sem::LoopStatement>(); auto* parent = sem_.Get(b)->FindFirstParent<sem::LoopStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent)); auto& info = current_function_->LoopSwitchInfoFor(parent);
auto& info = current_function_->loop_switch_infos.at(parent);
// Propagate variable values to the loop exit nodes. // Propagate variable values to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) { for (auto* var : current_function_->local_var_decls) {
@ -539,7 +548,7 @@ class UniformityGraph {
} }
// Add an edge from the variable exit node to its value at this point. // Add an edge from the variable exit node to its value at this point.
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit"); return CreateNode(name + "_value_" + info.type + "_exit");
}); });
@ -580,8 +589,7 @@ class UniformityGraph {
auto* parent = sem_.Get(c) auto* parent = sem_.Get(c)
->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement, ->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement,
sem::WhileStatement>(); sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent)); auto& info = current_function_->LoopSwitchInfoFor(parent);
auto& info = current_function_->loop_switch_infos.at(parent);
// Propagate assignments to the loop input nodes. // Propagate assignments to the loop input nodes.
for (auto* var : current_function_->local_var_decls) { for (auto* var : current_function_->local_var_decls) {
@ -593,11 +601,11 @@ class UniformityGraph {
} }
// Add an edge from the variable's loop input node to its value at this point. // Add an edge from the variable's loop input node to its value at this point.
TINT_ASSERT(Resolver, info.var_in_nodes.count(var)); auto** in_node = info.var_in_nodes.Find(var);
auto* in_node = info.var_in_nodes.at(var); TINT_ASSERT(Resolver, in_node != nullptr);
auto* out_node = current_function_->variables.Get(var); auto* out_node = current_function_->variables.Get(var);
if (out_node != in_node) { if (out_node != *in_node) {
in_node->AddEdge(out_node); (*in_node)->AddEdge(out_node);
} }
} }
return cf; return cf;
@ -618,7 +626,7 @@ class UniformityGraph {
} }
auto* cf_start = cf_init; auto* cf_start = cf_init;
auto& info = current_function_->loop_switch_infos[sem_loop]; auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "forloop"; info.type = "forloop";
// Create input nodes for any variables declared before this loop. // Create input nodes for any variables declared before this loop.
@ -626,7 +634,7 @@ class UniformityGraph {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol); auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in"); auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v)); in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node; info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node); current_function_->variables.Set(v, in_node);
} }
@ -640,7 +648,7 @@ class UniformityGraph {
// Propagate assignments to the loop exit nodes. // Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) { for (auto* var : current_function_->local_var_decls) {
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit"); return CreateNode(name + "_value_" + info.type + "_exit");
}); });
@ -660,19 +668,19 @@ class UniformityGraph {
// Add edges from variable loop input nodes to their values at the end of the loop. // Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) { for (auto v : info.var_in_nodes) {
auto* in_node = v.second; auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.first); auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) { if (out_node != in_node) {
in_node->AddEdge(out_node); in_node->AddEdge(out_node);
} }
} }
// Set each variable's exit node as its value in the outer scope. // Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) { for (auto& v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second); current_function_->variables.Set(v.key, v.value);
} }
current_function_->loop_switch_infos.erase(sem_loop); current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) { if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf; return cf;
@ -687,7 +695,7 @@ class UniformityGraph {
auto* cf_start = cf; auto* cf_start = cf;
auto& info = current_function_->loop_switch_infos[sem_loop]; auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "whileloop"; info.type = "whileloop";
// Create input nodes for any variables declared before this loop. // Create input nodes for any variables declared before this loop.
@ -695,7 +703,7 @@ class UniformityGraph {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol); auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in"); auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v)); in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node; info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node); current_function_->variables.Set(v, in_node);
} }
@ -710,7 +718,7 @@ class UniformityGraph {
// Propagate assignments to the loop exit nodes. // Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) { for (auto* var : current_function_->local_var_decls) {
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit"); return CreateNode(name + "_value_" + info.type + "_exit");
}); });
@ -721,9 +729,9 @@ class UniformityGraph {
cfx->AddEdge(cf); cfx->AddEdge(cf);
// Add edges from variable loop input nodes to their values at the end of the loop. // Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) { for (auto& v : info.var_in_nodes) {
auto* in_node = v.second; auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.first); auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) { if (out_node != in_node) {
in_node->AddEdge(out_node); in_node->AddEdge(out_node);
} }
@ -731,10 +739,10 @@ class UniformityGraph {
// Set each variable's exit node as its value in the outer scope. // Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) { for (auto v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second); current_function_->variables.Set(v.key, v.value);
} }
current_function_->loop_switch_infos.erase(sem_loop); current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) { if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf; return cf;
@ -752,15 +760,15 @@ class UniformityGraph {
v->affects_control_flow = true; v->affects_control_flow = true;
v->AddEdge(v_cond); v->AddEdge(v_cond);
std::unordered_map<const sem::Variable*, Node*> true_vars; utils::Hashmap<const sem::Variable*, Node*, 8> true_vars;
std::unordered_map<const sem::Variable*, Node*> false_vars; utils::Hashmap<const sem::Variable*, Node*, 8> false_vars;
// Helper to process a statement with a new scope for variable assignments. // Helper to process a statement with a new scope for variable assignments.
// Populates `assigned_vars` with new nodes for any variables that are assigned in // Populates `assigned_vars` with new nodes for any variables that are assigned in
// this statement. // this statement.
auto process_in_scope = auto process_in_scope =
[&](Node* cf_in, const ast::Statement* s, [&](Node* cf_in, const ast::Statement* s,
std::unordered_map<const sem::Variable*, Node*>& assigned_vars) { utils::Hashmap<const sem::Variable*, Node*, 8>& assigned_vars) {
// Push a new scope for variable assignments. // Push a new scope for variable assignments.
current_function_->variables.Push(); current_function_->variables.Push();
@ -790,7 +798,7 @@ class UniformityGraph {
// Update values for any variables assigned in the if or else blocks. // Update values for any variables assigned in the if or else blocks.
for (auto* var : current_function_->local_var_decls) { for (auto* var : current_function_->local_var_decls) {
// Skip variables not assigned in either block. // Skip variables not assigned in either block.
if (true_vars.count(var) == 0 && false_vars.count(var) == 0) { if (!true_vars.Contains(var) && !false_vars.Contains(var)) {
continue; continue;
} }
@ -801,15 +809,15 @@ class UniformityGraph {
// Add edges to the assigned value or the initial value. // Add edges to the assigned value or the initial value.
// Only add edges if the behavior for that block contains 'Next'. // Only add edges if the behavior for that block contains 'Next'.
if (true_has_next) { if (true_has_next) {
if (true_vars.count(var)) { if (true_vars.Contains(var)) {
out_node->AddEdge(true_vars.at(var)); out_node->AddEdge(*true_vars.Find(var));
} else { } else {
out_node->AddEdge(current_function_->variables.Get(var)); out_node->AddEdge(current_function_->variables.Get(var));
} }
} }
if (false_has_next) { if (false_has_next) {
if (false_vars.count(var)) { if (false_vars.Contains(var)) {
out_node->AddEdge(false_vars.at(var)); out_node->AddEdge(*false_vars.Find(var));
} else { } else {
out_node->AddEdge(current_function_->variables.Get(var)); out_node->AddEdge(current_function_->variables.Get(var));
} }
@ -845,7 +853,7 @@ class UniformityGraph {
auto* sem_loop = sem_.Get(l); auto* sem_loop = sem_.Get(l);
auto* cfx = CreateNode("loop_start"); auto* cfx = CreateNode("loop_start");
auto& info = current_function_->loop_switch_infos[sem_loop]; auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "loop"; info.type = "loop";
// Create input nodes for any variables declared before this loop. // Create input nodes for any variables declared before this loop.
@ -853,7 +861,7 @@ class UniformityGraph {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol); auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_loop_in", v->Declaration()); auto* in_node = CreateNode(name + "_value_loop_in", v->Declaration());
in_node->AddEdge(current_function_->variables.Get(v)); in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node; info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node); current_function_->variables.Set(v, in_node);
} }
@ -868,8 +876,8 @@ class UniformityGraph {
// Add edges from variable loop input nodes to their values at the end of the loop. // Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) { for (auto v : info.var_in_nodes) {
auto* in_node = v.second; auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.first); auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) { if (out_node != in_node) {
in_node->AddEdge(out_node); in_node->AddEdge(out_node);
} }
@ -877,10 +885,10 @@ class UniformityGraph {
// Set each variable's exit node as its value in the outer scope. // Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) { for (auto v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second); current_function_->variables.Set(v.key, v.value);
} }
current_function_->loop_switch_infos.erase(sem_loop); current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) { if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf; return cf;
@ -925,7 +933,7 @@ class UniformityGraph {
cf_end = CreateNode("switch_CFend"); cf_end = CreateNode("switch_CFend");
} }
auto& info = current_function_->loop_switch_infos[sem_switch]; auto& info = current_function_->LoopSwitchInfoFor(sem_switch);
info.type = "switch"; info.type = "switch";
auto* cf_n = v; auto* cf_n = v;
@ -958,12 +966,11 @@ class UniformityGraph {
} }
// Add an edge from the variable exit node to its new value. // Add an edge from the variable exit node to its new value.
auto* exit_node = auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto name =
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
builder_->Symbols().NameFor(var->Declaration()->symbol); return CreateNode(name + "_value_" + info.type + "_exit");
return CreateNode(name + "_value_" + info.type + "_exit"); });
});
exit_node->AddEdge(current_function_->variables.Get(var)); exit_node->AddEdge(current_function_->variables.Get(var));
} }
} }
@ -974,7 +981,7 @@ class UniformityGraph {
// Update nodes for any variables assigned in the switch statement. // Update nodes for any variables assigned in the switch statement.
for (auto var : info.var_exit_nodes) { for (auto var : info.var_exit_nodes) {
current_function_->variables.Set(var.first, var.second); current_function_->variables.Set(var.key, var.value);
} }
return cf_end ? cf_end : cf; return cf_end ? cf_end : cf;
@ -995,7 +1002,7 @@ class UniformityGraph {
auto* e = UnwrapIndirectAndAddressOfChain(unary_init); auto* e = UnwrapIndirectAndAddressOfChain(unary_init);
if (e->IsAnyOf<ast::IndexAccessorExpression, if (e->IsAnyOf<ast::IndexAccessorExpression,
ast::MemberAccessorExpression>()) { ast::MemberAccessorExpression>()) {
current_function_->partial_ptrs.insert(sem_var); current_function_->partial_ptrs.Add(sem_var);
} }
} }
} }
@ -1005,7 +1012,7 @@ class UniformityGraph {
current_function_->variables.Set(sem_var, node); current_function_->variables.Set(sem_var, node);
if (decl->variable->Is<ast::Var>()) { if (decl->variable->Is<ast::Var>()) {
current_function_->local_var_decls.insert( current_function_->local_var_decls.Add(
sem_.Get<sem::LocalVariable>(decl->variable)); sem_.Get<sem::LocalVariable>(decl->variable));
} }
@ -1183,10 +1190,10 @@ class UniformityGraph {
// To determine if we're dereferencing a partial pointer, unwrap *& // To determine if we're dereferencing a partial pointer, unwrap *&
// chains; if the final expression is an identifier, see if it's a // chains; if the final expression is an identifier, see if it's a
// partial pointer. If it's not an identifier, then it must be an // partial pointer. If it's not an identifier, then it must be an
// index/acessor expression, and thus a partial pointer. // index/accessor expression, and thus a partial pointer.
auto* e = UnwrapIndirectAndAddressOfChain(u); auto* e = UnwrapIndirectAndAddressOfChain(u);
if (auto* var_user = sem_.Get<sem::VariableUser>(e)) { if (auto* var_user = sem_.Get<sem::VariableUser>(e)) {
if (current_function_->partial_ptrs.count(var_user->Variable())) { if (current_function_->partial_ptrs.Contains(var_user->Variable())) {
return true; return true;
} }
} else { } else {
@ -1290,7 +1297,7 @@ class UniformityGraph {
// Process call arguments // Process call arguments
Node* cf_last_arg = cf; Node* cf_last_arg = cf;
std::vector<Node*> args; utils::Vector<Node*, 8> args;
for (size_t i = 0; i < call->args.Length(); i++) { for (size_t i = 0; i < call->args.Length(); i++) {
auto [cf_i, arg_i] = ProcessExpression(cf_last_arg, call->args[i]); auto [cf_i, arg_i] = ProcessExpression(cf_last_arg, call->args[i]);
@ -1303,7 +1310,7 @@ class UniformityGraph {
arg_node->AddEdge(arg_i); arg_node->AddEdge(arg_i);
cf_last_arg = cf_i; cf_last_arg = cf_i;
args.push_back(arg_node); args.Push(arg_node);
} }
// Note: This is an additional node that isn't described in the specification, for the // Note: This is an additional node that isn't described in the specification, for the
@ -1341,11 +1348,11 @@ class UniformityGraph {
[&](const sem::Function* func) { [&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process // We must have already analyzed the user-defined function since we process
// functions in dependency order. // functions in dependency order.
TINT_ASSERT(Resolver, functions_.count(func->Declaration())); auto* info = functions_.Find(func->Declaration());
auto& info = functions_.at(func->Declaration()); TINT_ASSERT(Resolver, info != nullptr);
callsite_tag = info.callsite_tag; callsite_tag = info->callsite_tag;
function_tag = info.function_tag; function_tag = info->function_tag;
func_info = &info; func_info = info;
}, },
[&](const sem::TypeInitializer*) { [&](const sem::TypeInitializer*) {
callsite_tag = CallSiteNoRestriction; callsite_tag = CallSiteNoRestriction;
@ -1371,7 +1378,7 @@ class UniformityGraph {
result->AddEdge(cf_after); result->AddEdge(cf_after);
// For each argument, add edges based on parameter tags. // For each argument, add edges based on parameter tags.
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.Length(); i++) {
if (func_info) { if (func_info) {
switch (func_info->parameters[i].tag) { switch (func_info->parameters[i].tag) {
case ParameterRequiredToBeUniform: case ParameterRequiredToBeUniform:
@ -1429,11 +1436,11 @@ class UniformityGraph {
/// @param source the starting node /// @param source the starting node
/// @param reachable the set of reachable nodes to populate, if required /// @param reachable the set of reachable nodes to populate, if required
void Traverse(Node* source, utils::UniqueVector<Node*, 4>* reachable = nullptr) { void Traverse(Node* source, utils::UniqueVector<Node*, 4>* reachable = nullptr) {
std::vector<Node*> to_visit{source}; utils::Vector<Node*, 8> to_visit{source};
while (!to_visit.empty()) { while (!to_visit.IsEmpty()) {
auto* node = to_visit.back(); auto* node = to_visit.Back();
to_visit.pop_back(); to_visit.Pop();
if (reachable) { if (reachable) {
reachable->Add(node); reachable->Add(node);
@ -1441,7 +1448,7 @@ class UniformityGraph {
for (auto* to : node->edges) { for (auto* to : node->edges) {
if (to->visited_from == nullptr) { if (to->visited_from == nullptr) {
to->visited_from = node; to->visited_from = node;
to_visit.push_back(to); to_visit.Push(to);
} }
} }
} }
@ -1473,8 +1480,8 @@ class UniformityGraph {
} else if (auto* user = target->As<sem::Function>()) { } else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that // This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node. // function and look for one whose node has an edge from the RequiredToBeUniform node.
auto& target_info = functions_.at(user->Declaration()); auto* target_info = functions_.Find(user->Declaration());
for (auto* call_node : target_info.required_to_be_uniform->edges) { for (auto* call_node : target_info->required_to_be_uniform->edges) {
if (call_node->type == Node::kRegular) { if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>(); auto* child_call = call_node->ast->As<ast::CallExpression>();
return FindBuiltinThatRequiresUniformity(child_call); return FindBuiltinThatRequiresUniformity(child_call);
@ -1643,9 +1650,9 @@ class UniformityGraph {
// If this is a call to a user-defined function, add a note to show the reason that the // If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform. // parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) { if (auto* user = target->As<sem::Function>()) {
auto& next_function = functions_.at(user->Declaration()); auto* next_function = functions_.Find(user->Declaration());
Node* next_cause = next_function.parameters[cause->arg_index].init_value; Node* next_cause = next_function->parameters[cause->arg_index].init_value;
MakeError(next_function, next_cause, true); MakeError(*next_function, next_cause, true);
} }
} else { } else {
// The requirement was on a function callsite. // The requirement was on a function callsite.

View File

@ -580,8 +580,8 @@ bool Validator::LocalVariable(const sem::Variable* local) const {
bool Validator::GlobalVariable( bool Validator::GlobalVariable(
const sem::GlobalVariable* global, const sem::GlobalVariable* global,
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids, const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_ids,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const { const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const {
auto* decl = global->Declaration(); auto* decl = global->Declaration();
if (global->AddressSpace() != ast::AddressSpace::kWorkgroup && if (global->AddressSpace() != ast::AddressSpace::kWorkgroup &&
IsArrayWithOverrideCount(global->Type())) { IsArrayWithOverrideCount(global->Type())) {
@ -702,7 +702,7 @@ bool Validator::GlobalVariable(
// buffer variables with a read_write access mode. // buffer variables with a read_write access mode.
bool Validator::AtomicVariable( bool Validator::AtomicVariable(
const sem::Variable* var, const sem::Variable* var,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const { const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const {
auto address_space = var->AddressSpace(); auto address_space = var->AddressSpace();
auto* decl = var->Declaration(); auto* decl = var->Declaration();
auto access = var->Access(); auto access = var->Access();
@ -716,14 +716,13 @@ bool Validator::AtomicVariable(
return false; return false;
} }
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) { } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
auto found = atomic_composite_info.find(type); if (auto* found = atomic_composite_info.Find(type)) {
if (found != atomic_composite_info.end()) {
if (address_space != ast::AddressSpace::kStorage && if (address_space != ast::AddressSpace::kStorage &&
address_space != ast::AddressSpace::kWorkgroup) { address_space != ast::AddressSpace::kWorkgroup) {
AddError("atomic variables must have <storage> or <workgroup> address space", AddError("atomic variables must have <storage> or <workgroup> address space",
source); source);
AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here", AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here",
found->second); **found);
return false; return false;
} else if (address_space == ast::AddressSpace::kStorage && } else if (address_space == ast::AddressSpace::kStorage &&
access != ast::Access::kReadWrite) { access != ast::Access::kReadWrite) {
@ -732,7 +731,7 @@ bool Validator::AtomicVariable(
"access mode", "access mode",
source); source);
AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here", AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here",
found->second); **found);
return false; return false;
} }
} }
@ -783,7 +782,7 @@ bool Validator::Let(const sem::Variable* v) const {
bool Validator::Override( bool Validator::Override(
const sem::GlobalVariable* v, const sem::GlobalVariable* v,
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids) const { const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_ids) const {
auto* decl = v->Declaration(); auto* decl = v->Declaration();
auto* storage_ty = v->Type()->UnwrapRef(); auto* storage_ty = v->Type()->UnwrapRef();
@ -796,12 +795,12 @@ bool Validator::Override(
for (auto* attr : decl->attributes) { for (auto* attr : decl->attributes) {
if (attr->Is<ast::IdAttribute>()) { if (attr->Is<ast::IdAttribute>()) {
auto id = v->OverrideId(); auto id = v->OverrideId();
if (auto it = override_ids.find(id); it != override_ids.end() && it->second != v) { if (auto* var = override_ids.Find(id); var && *var != v) {
AddError("@id values must be unique", attr->source); AddError("@id values must be unique", attr->source);
AddNote("a override with an ID of " + std::to_string(id.value) + AddNote(
" was previously declared here:", "a override with an ID of " + std::to_string(id.value) +
ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes) " was previously declared here:",
->source); ast::GetAttribute<ast::IdAttribute>((*var)->Declaration()->attributes)->source);
return false; return false;
} }
} else { } else {
@ -1093,8 +1092,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// order to catch conflicts. // order to catch conflicts.
// TODO(jrprice): This state could be stored in sem::Function instead, and then passed to // TODO(jrprice): This state could be stored in sem::Function instead, and then passed to
// sem::Function since it would be useful there too. // sem::Function since it would be useful there too.
std::unordered_set<ast::BuiltinValue> builtins; utils::Hashset<ast::BuiltinValue, 4> builtins;
std::unordered_set<uint32_t> locations; utils::Hashset<uint32_t, 8> locations;
enum class ParamOrRetType { enum class ParamOrRetType {
kParameter, kParameter,
kReturnType, kReturnType,
@ -1130,7 +1129,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
} }
pipeline_io_attribute = attr; pipeline_io_attribute = attr;
if (builtins.count(builtin->builtin)) { if (builtins.Contains(builtin->builtin)) {
AddError(attr_to_str(builtin) + AddError(attr_to_str(builtin) +
" attribute appears multiple times as pipeline " + " attribute appears multiple times as pipeline " +
(param_or_ret == ParamOrRetType::kParameter ? "input" : "output"), (param_or_ret == ParamOrRetType::kParameter ? "input" : "output"),
@ -1142,7 +1141,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
/* is_input */ param_or_ret == ParamOrRetType::kParameter)) { /* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
return false; return false;
} }
builtins.emplace(builtin->builtin); builtins.Add(builtin->builtin);
} else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) { } else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
if (pipeline_io_attribute) { if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", attr->source); AddError("multiple entry point IO attributes", attr->source);
@ -1287,8 +1286,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// Clear IO sets after parameter validation. Builtin and location attributes in return types // Clear IO sets after parameter validation. Builtin and location attributes in return types
// should be validated independently from those used in parameters. // should be validated independently from those used in parameters.
builtins.clear(); builtins.Clear();
locations.clear(); locations.Clear();
if (!func->ReturnType()->Is<sem::Void>()) { if (!func->ReturnType()->Is<sem::Void>()) {
if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(), if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(),
@ -1299,7 +1298,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
} }
if (decl->PipelineStage() == ast::PipelineStage::kVertex && if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
builtins.count(ast::BuiltinValue::kPosition) == 0) { !builtins.Contains(ast::BuiltinValue::kPosition)) {
// Check module-scope variables, as the SPIR-V sanitizer generates these. // Check module-scope variables, as the SPIR-V sanitizer generates these.
bool found = false; bool found = false;
for (auto* global : func->TransitivelyReferencedGlobals()) { for (auto* global : func->TransitivelyReferencedGlobals()) {
@ -1327,18 +1326,18 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
} }
// Validate there are no resource variable binding collisions // Validate there are no resource variable binding collisions
std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points; utils::Hashmap<sem::BindingPoint, const ast::Variable*, 8> binding_points;
for (auto* global : func->TransitivelyReferencedGlobals()) { for (auto* global : func->TransitivelyReferencedGlobals()) {
auto* var_decl = global->Declaration()->As<ast::Var>(); auto* var_decl = global->Declaration()->As<ast::Var>();
if (!var_decl || !var_decl->HasBindingPoint()) { if (!var_decl || !var_decl->HasBindingPoint()) {
continue; continue;
} }
auto bp = global->BindingPoint(); auto bp = global->BindingPoint();
auto res = binding_points.emplace(bp, var_decl); if (auto added = binding_points.Add(bp, var_decl);
if (!res.second && !added &&
IsValidationEnabled(decl->attributes, IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kBindingPointCollision) && ast::DisabledValidation::kBindingPointCollision) &&
IsValidationEnabled(res.first->second->attributes, IsValidationEnabled((*added.value)->attributes,
ast::DisabledValidation::kBindingPointCollision)) { ast::DisabledValidation::kBindingPointCollision)) {
// https://gpuweb.github.io/gpuweb/wgsl/#resource-interface // https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
// Bindings must not alias within a shader stage: two different variables in the // Bindings must not alias within a shader stage: two different variables in the
@ -1350,7 +1349,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
"' references multiple variables that use the same resource binding @group(" + "' references multiple variables that use the same resource binding @group(" +
std::to_string(bp.group) + "), @binding(" + std::to_string(bp.binding) + ")", std::to_string(bp.group) + "), @binding(" + std::to_string(bp.binding) + ")",
var_decl->source); var_decl->source);
AddNote("first resource binding usage declared here", res.first->second->source); AddNote("first resource binding usage declared here", (*added.value)->source);
return false; return false;
} }
} }
@ -1917,7 +1916,7 @@ bool Validator::Matrix(const sem::Matrix* ty, const Source& source) const {
return true; return true;
} }
bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points) const { bool Validator::PipelineStages(const utils::VectorRef<sem::Function*> entry_points) const {
auto backtrace = [&](const sem::Function* func, const sem::Function* entry_point) { auto backtrace = [&](const sem::Function* func, const sem::Function* entry_point) {
if (func != entry_point) { if (func != entry_point) {
TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) { TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) {
@ -2012,7 +2011,7 @@ bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points)
return true; return true;
} }
bool Validator::PushConstants(const std::vector<sem::Function*>& entry_points) const { bool Validator::PushConstants(const utils::VectorRef<sem::Function*> entry_points) const {
for (auto* entry_point : entry_points) { for (auto* entry_point : entry_points) {
// State checked and modified by check_push_constant so that it remembers previously seen // State checked and modified by check_push_constant so that it remembers previously seen
// push_constant variables for an entry-point. // push_constant variables for an entry-point.
@ -2130,7 +2129,7 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
return false; return false;
} }
std::unordered_set<uint32_t> locations; utils::Hashset<uint32_t, 8> locations;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
if (auto* r = member->Type()->As<sem::Array>()) { if (auto* r = member->Type()->As<sem::Array>()) {
if (r->IsRuntimeSized()) { if (r->IsRuntimeSized()) {
@ -2248,7 +2247,7 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr, bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location, uint32_t location,
const sem::Type* type, const sem::Type* type,
std::unordered_set<uint32_t>& locations, utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage, ast::PipelineStage stage,
const Source& source, const Source& source,
const bool is_input) const { const bool is_input) const {
@ -2269,12 +2268,11 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
return false; return false;
} }
if (locations.count(location)) { if (!locations.Add(location)) {
AddError(attr_to_str(loc_attr, location) + " attribute appears multiple times", AddError(attr_to_str(loc_attr, location) + " attribute appears multiple times",
loc_attr->source); loc_attr->source);
return false; return false;
} }
locations.emplace(location);
return true; return true;
} }
@ -2311,7 +2309,7 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
} }
const sem::CaseSelector* default_selector = nullptr; const sem::CaseSelector* default_selector = nullptr;
std::unordered_map<int64_t, Source> selectors; utils::Hashmap<int64_t, Source, 4> selectors;
for (auto* case_stmt : s->body) { for (auto* case_stmt : s->body) {
auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt); auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
@ -2338,18 +2336,16 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
} }
auto value = selector->Value()->As<uint32_t>(); auto value = selector->Value()->As<uint32_t>();
auto it = selectors.find(value); if (auto added = selectors.Add(value, selector->Declaration()->source); !added) {
if (it != selectors.end()) {
AddError("duplicate switch case '" + AddError("duplicate switch case '" +
(decl_ty->IsAnyOf<sem::I32, sem::AbstractNumeric>() (decl_ty->IsAnyOf<sem::I32, sem::AbstractNumeric>()
? std::to_string(i32(value)) ? std::to_string(i32(value))
: std::to_string(value)) + : std::to_string(value)) +
"'", "'",
selector->Declaration()->source); selector->Declaration()->source);
AddNote("previous case declared here", it->second); AddNote("previous case declared here", *added.value);
return false; return false;
} }
selectors.emplace(value, selector->Declaration()->source);
} }
} }
@ -2477,12 +2473,12 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme
} }
bool Validator::NoDuplicateAttributes(utils::VectorRef<const ast::Attribute*> attributes) const { bool Validator::NoDuplicateAttributes(utils::VectorRef<const ast::Attribute*> attributes) const {
std::unordered_map<const TypeInfo*, Source> seen; utils::Hashmap<const TypeInfo*, Source, 8> seen;
for (auto* d : attributes) { for (auto* d : attributes) {
auto res = seen.emplace(&d->TypeInfo(), d->source); auto added = seen.Add(&d->TypeInfo(), d->source);
if (!res.second && !d->Is<ast::InternalAttribute>()) { if (!added && !d->Is<ast::InternalAttribute>()) {
AddError("duplicate " + d->Name() + " attribute", d->source); AddError("duplicate " + d->Name() + " attribute", d->source);
AddNote("first attribute declared here", res.first->second); AddNote("first attribute declared here", *added.value);
return false; return false;
} }
} }

View File

@ -17,16 +17,15 @@
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector>
#include "src/tint/ast/pipeline_stage.h" #include "src/tint/ast/pipeline_stage.h"
#include "src/tint/program_builder.h" #include "src/tint/program_builder.h"
#include "src/tint/resolver/sem_helper.h" #include "src/tint/resolver/sem_helper.h"
#include "src/tint/sem/evaluation_stage.h" #include "src/tint/sem/evaluation_stage.h"
#include "src/tint/source.h" #include "src/tint/source.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h"
// Forward declarations // Forward declarations
namespace tint::ast { namespace tint::ast {
@ -116,12 +115,12 @@ class Validator {
/// Validates pipeline stages /// Validates pipeline stages
/// @param entry_points the entry points to the module /// @param entry_points the entry points to the module
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool PipelineStages(const std::vector<sem::Function*>& entry_points) const; bool PipelineStages(const utils::VectorRef<sem::Function*> entry_points) const;
/// Validates push_constant variables /// Validates push_constant variables
/// @param entry_points the entry points to the module /// @param entry_points the entry points to the module
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool PushConstants(const std::vector<sem::Function*>& entry_points) const; bool PushConstants(const utils::VectorRef<sem::Function*> entry_points) const;
/// Validates aliases /// Validates aliases
/// @param alias the alias to validate /// @param alias the alias to validate
@ -156,7 +155,7 @@ class Validator {
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool AtomicVariable( bool AtomicVariable(
const sem::Variable* var, const sem::Variable* var,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const; const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const;
/// Validates an assignment /// Validates an assignment
/// @param a the assignment statement /// @param a the assignment statement
@ -248,8 +247,8 @@ class Validator {
/// @returns true on success, false otherwise /// @returns true on success, false otherwise
bool GlobalVariable( bool GlobalVariable(
const sem::GlobalVariable* var, const sem::GlobalVariable* var,
const std::unordered_map<OverrideId, const sem::Variable*>& override_id, const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_id,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const; const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const;
/// Validates a break-if statement /// Validates a break-if statement
/// @param stmt the statement to validate /// @param stmt the statement to validate
@ -297,7 +296,7 @@ class Validator {
bool LocationAttribute(const ast::LocationAttribute* loc_attr, bool LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location, uint32_t location,
const sem::Type* type, const sem::Type* type,
std::unordered_set<uint32_t>& locations, utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage, ast::PipelineStage stage,
const Source& source, const Source& source,
const bool is_input = false) const; const bool is_input = false) const;
@ -392,7 +391,7 @@ class Validator {
/// @param override_id the set of override ids in the module /// @param override_id the set of override ids in the module
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool Override(const sem::GlobalVariable* v, bool Override(const sem::GlobalVariable* v,
const std::unordered_map<OverrideId, const sem::Variable*>& override_id) const; const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_id) const;
/// Validates a 'const' variable declaration /// Validates a 'const' variable declaration
/// @param v the variable to validate /// @param v the variable to validate

View File

@ -14,11 +14,11 @@
#ifndef SRC_TINT_SCOPE_STACK_H_ #ifndef SRC_TINT_SCOPE_STACK_H_
#define SRC_TINT_SCOPE_STACK_H_ #define SRC_TINT_SCOPE_STACK_H_
#include <unordered_map>
#include <utility> #include <utility>
#include <vector>
#include "src/tint/symbol.h" #include "src/tint/symbol.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h"
namespace tint { namespace tint {
@ -27,22 +27,13 @@ namespace tint {
template <class K, class V> template <class K, class V>
class ScopeStack { class ScopeStack {
public: public:
/// Constructor
ScopeStack() {
// Push global bucket
stack_.push_back({});
}
/// Copy Constructor
ScopeStack(const ScopeStack&) = default;
~ScopeStack() = default;
/// Push a new scope on to the stack /// Push a new scope on to the stack
void Push() { stack_.push_back({}); } void Push() { stack_.Push({}); }
/// Pop the scope off the top of the stack /// Pop the scope off the top of the stack
void Pop() { void Pop() {
if (stack_.size() > 1) { if (stack_.Length() > 1) {
stack_.pop_back(); stack_.Pop();
} }
} }
@ -52,8 +43,13 @@ class ScopeStack {
/// @returns the old value if there was an existing key at the top of the /// @returns the old value if there was an existing key at the top of the
/// stack, otherwise the zero initializer for type T. /// stack, otherwise the zero initializer for type T.
V Set(const K& key, V val) { V Set(const K& key, V val) {
std::swap(val, stack_.back()[key]); auto& back = stack_.Back();
return val; if (auto* el = back.Find(key)) {
std::swap(val, *el);
return val;
}
back.Add(key, val);
return {};
} }
/// Retrieves a value from the stack /// Retrieves a value from the stack
@ -61,10 +57,8 @@ class ScopeStack {
/// @returns the value, or the zero initializer if the value was not found /// @returns the value, or the zero initializer if the value was not found
V Get(const K& key) const { V Get(const K& key) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) { for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter; if (auto* val = iter->Find(key)) {
auto val = map.find(key); return *val;
if (val != map.end()) {
return val->second;
} }
} }
@ -73,16 +67,16 @@ class ScopeStack {
/// Return the top scope of the stack. /// Return the top scope of the stack.
/// @returns the top scope of the stack /// @returns the top scope of the stack
const std::unordered_map<K, V>& Top() const { return stack_.back(); } const utils::Hashmap<K, V, 8>& Top() const { return stack_.Back(); }
/// Clear the scope stack. /// Clear the scope stack.
void Clear() { void Clear() {
stack_.clear(); stack_.Clear();
stack_.push_back({}); stack_.Push({});
} }
private: private:
std::vector<std::unordered_map<K, V>> stack_; utils::Vector<utils::Hashmap<K, V, 8>, 8> stack_ = {{}};
}; };
} // namespace tint } // namespace tint

View File

@ -157,9 +157,9 @@ size_t HashCombine(size_t hash, const ARGS&... values) {
template <typename T> template <typename T>
struct UnorderedKeyWrapper { struct UnorderedKeyWrapper {
/// The wrapped value /// The wrapped value
const T value; T value;
/// The hash of value /// The hash of value
const size_t hash; size_t hash;
/// Constructor /// Constructor
/// @param v the value to wrap /// @param v the value to wrap

View File

@ -524,7 +524,7 @@ class HashmapBase {
/// Shuffles slots for an insertion that has been placed one slot before `start`. /// Shuffles slots for an insertion that has been placed one slot before `start`.
/// @param start the index of the first slot to start shuffling. /// @param start the index of the first slot to start shuffling.
/// @param evicted the slot content that was evicted for the insertion. /// @param evicted the slot content that was evicted for the insertion.
void InsertShuffle(size_t start, Slot evicted) { void InsertShuffle(size_t start, Slot&& evicted) {
Scan(start, [&](size_t, size_t index) { Scan(start, [&](size_t, size_t index) {
auto& slot = slots_[index]; auto& slot = slots_[index];