tint/resolver: Move from STL to tint::utils containers

Change-Id: I883168a1a84457138de85decb921c5c430c32bd8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108702
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-11-09 20:55:33 +00:00 committed by Dawn LUCI CQ
parent 865c3f8e94
commit 9418152d08
17 changed files with 259 additions and 264 deletions

View File

@ -306,7 +306,6 @@ endif()
message(STATUS "Using python3")
find_package(PythonInterp 3 REQUIRED)
################################################################################
# common_compile_options - sets compiler and linker options common for dawn and
# tint on the given target
@ -347,6 +346,12 @@ endif()
# Dawn's public and internal "configs"
################################################################################
set(IS_DEBUG_BUILD 0)
string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type)
if ((NOT ${build_type} STREQUAL "RELEASE") AND (NOT ${build_type} STREQUAL "RELWITHDEBINFO"))
set(IS_DEBUG_BUILD 1)
endif()
# The public config contains only the include paths for the Dawn headers.
add_library(dawn_public_config INTERFACE)
target_include_directories(dawn_public_config INTERFACE
@ -363,7 +368,7 @@ target_include_directories(dawn_internal_config INTERFACE
target_link_libraries(dawn_internal_config INTERFACE dawn_public_config)
# Compile definitions for the internal config
if (DAWN_ALWAYS_ASSERT OR $<CONFIG:Debug>)
if (DAWN_ALWAYS_ASSERT OR IS_DEBUG_BUILD)
target_compile_definitions(dawn_internal_config INTERFACE "DAWN_ENABLE_ASSERTS")
endif()
if (DAWN_ENABLE_D3D12)

View File

@ -1357,9 +1357,8 @@ if(TINT_BUILD_TESTS)
# overflows when resolving deeply nested expression chains or statements.
# Production builds neither use MSVC nor debug, so just bump the stack size
# for this build combination.
string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type)
if ((NOT ${build_type} STREQUAL "RELEASE") AND (NOT ${build_type} STREQUAL "RELWITHDEBINFO"))
target_link_options(tint_unittests PRIVATE "/STACK 2097152") # 2MB, default is 1MB
if (IS_DEBUG_BUILD)
target_link_options(tint_unittests PRIVATE "/STACK:4194304") # 4MB, default is 1MB
endif()
else()
target_compile_options(tint_unittests PRIVATE

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unordered_set>
#include "src/tint/ast/builtin_texture_helper_test.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/type_initializer.h"

View File

@ -20,7 +20,6 @@
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include "src/tint/program_builder.h"
@ -463,18 +462,18 @@ const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return nullptr;
},
[&](const sem::Struct* s) -> const ImplConstant* {
std::unordered_map<const sem::Type*, const ImplConstant*> zero_by_type;
utils::Hashmap<const sem::Type*, const ImplConstant*, 8> zero_by_type;
utils::Vector<const sem::Constant*, 4> zeros;
zeros.Reserve(s->Members().size());
for (auto* member : s->Members()) {
auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
[&] { return ZeroValue(builder, member->Type()); });
auto* zero = zero_by_type.GetOrCreate(
member->Type(), [&] { return ZeroValue(builder, member->Type()); });
if (!zero) {
return nullptr;
}
zeros.Push(zero);
}
if (zero_by_type.size() == 1) {
if (zero_by_type.Count() == 1) {
// All members were of the same type, so the zero value is the same for all members.
return builder.create<Splat>(type, zeros[0], s->Members().size());
}

View File

@ -15,7 +15,6 @@
#include "src/tint/resolver/dependency_graph.h"
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@ -117,7 +116,7 @@ struct DependencyEdgeCmp {
/// A map of DependencyEdge to DependencyInfo
using DependencyEdges =
std::unordered_map<DependencyEdge, DependencyInfo, DependencyEdgeCmp, DependencyEdgeCmp>;
utils::Hashmap<DependencyEdge, DependencyInfo, 64, DependencyEdgeCmp, DependencyEdgeCmp>;
/// Global describes a module-scope variable, type or function.
struct Global {
@ -126,11 +125,11 @@ struct Global {
/// The declaration ast::Node
const ast::Node* node;
/// A list of dependencies that this global depends on
std::vector<Global*> deps;
utils::Vector<Global*, 8> deps;
};
/// A map of global name to Global
using GlobalMap = std::unordered_map<Symbol, Global*>;
using GlobalMap = utils::Hashmap<Symbol, Global*, 16>;
/// Raises an ICE that a global ast::Node type was not handled by this system.
void UnhandledNode(diag::List& diagnostics, const ast::Node* node) {
@ -170,7 +169,7 @@ class DependencyScanner {
dependency_edges_(edges) {
// Register all the globals at global-scope
for (auto it : globals_by_name) {
scope_stack_.Set(it.first, it.second->node);
scope_stack_.Set(it.key, it.value->node);
}
}
@ -232,7 +231,7 @@ class DependencyScanner {
for (auto* param : func->params) {
if (auto* shadows = scope_stack_.Get(param->symbol)) {
graph_.shadows.emplace(param, shadows);
graph_.shadows.Add(param, shadows);
}
Declare(param->symbol, param);
}
@ -306,7 +305,7 @@ class DependencyScanner {
},
[&](const ast::VariableDeclStatement* v) {
if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
graph_.shadows.emplace(v->variable, shadows);
graph_.shadows.Add(v->variable, shadows);
}
TraverseType(v->variable->type);
TraverseExpression(v->variable->initializer);
@ -473,16 +472,14 @@ class DependencyScanner {
}
}
if (auto* global = utils::Lookup(globals_, to); global && global->node == resolved) {
if (dependency_edges_
.emplace(DependencyEdge{current_global_, global},
DependencyInfo{from->source, action})
.second) {
current_global_->deps.emplace_back(global);
if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
DependencyInfo{from->source, action})) {
current_global_->deps.Push(*global);
}
}
graph_.resolved_symbols.emplace(from, resolved);
graph_.resolved_symbols.Add(from, resolved);
}
/// @returns true if `name` is the name of a builtin function
@ -497,7 +494,7 @@ class DependencyScanner {
source);
}
using VariableMap = std::unordered_map<Symbol, const ast::Variable*>;
using VariableMap = utils::Hashmap<Symbol, const ast::Variable*, 32>;
const SymbolTable& symbols_;
const GlobalMap& globals_;
diag::List& diagnostics_;
@ -520,7 +517,7 @@ struct DependencyAnalysis {
/// @returns true if analysis found no errors, otherwise false.
bool Run(const ast::Module& module) {
// Reserve container memory
graph_.resolved_symbols.reserve(module.GlobalDeclarations().Length());
graph_.resolved_symbols.Reserve(module.GlobalDeclarations().Length());
sorted_.Reserve(module.GlobalDeclarations().Length());
// Collect all the named globals from the AST module
@ -589,9 +586,9 @@ struct DependencyAnalysis {
for (auto* node : module.GlobalDeclarations()) {
auto* global = allocator_.Create(node);
if (auto symbol = SymbolOf(node); symbol.IsValid()) {
globals_.emplace(symbol, global);
globals_.Add(symbol, global);
}
declaration_order_.emplace_back(global);
declaration_order_.Push(global);
}
}
@ -625,16 +622,16 @@ struct DependencyAnalysis {
return;
}
std::vector<Entry> stack{Entry{root, 0}};
utils::Vector<Entry, 16> stack{Entry{root, 0}};
while (true) {
auto& entry = stack.back();
auto& entry = stack.Back();
// Have we exhausted the dependencies of entry.global?
if (entry.dep_idx < entry.global->deps.size()) {
if (entry.dep_idx < entry.global->deps.Length()) {
// No, there's more dependencies to traverse.
auto& dep = entry.global->deps[entry.dep_idx];
// Does the caller want to enter this dependency?
if (enter(dep)) { // Yes.
stack.push_back(Entry{dep, 0}); // Enter the dependency.
stack.Push(Entry{dep, 0}); // Enter the dependency.
} else {
entry.dep_idx++; // No. Skip this node.
}
@ -643,11 +640,11 @@ struct DependencyAnalysis {
// Exit this global, pop the stack, and if there's another parent node,
// increment its dependency index, and loop again.
exit(entry.global);
stack.pop_back();
if (stack.empty()) {
stack.Pop();
if (stack.IsEmpty()) {
return; // All done.
}
stack.back().dep_idx++;
stack.Back().dep_idx++;
}
}
}
@ -707,9 +704,8 @@ struct DependencyAnalysis {
/// of global `from` depending on `to`.
/// @note will raise an ICE if the edge is not found.
DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
auto it = dependency_edges_.find(DependencyEdge{from, to});
if (it != dependency_edges_.end()) {
return it->second;
if (auto info = dependency_edges_.Find(DependencyEdge{from, to})) {
return *info;
}
TINT_ICE(Resolver, diagnostics_)
<< "failed to find dependency info for edge: '" << NameOf(from->node) << "' -> '"
@ -762,7 +758,7 @@ struct DependencyAnalysis {
printf("------ dependencies ------ \n");
for (auto* node : sorted_) {
auto symbol = SymbolOf(node);
auto* global = globals_.at(symbol);
auto* global = *globals_.Find(symbol);
printf("%s depends on:\n", symbols_.NameFor(symbol).c_str());
for (auto* dep : global->deps) {
printf(" %s\n", NameOf(dep->node).c_str());
@ -791,7 +787,7 @@ struct DependencyAnalysis {
DependencyEdges dependency_edges_;
/// Globals in declaration order. Populated by GatherGlobals().
std::vector<Global*> declaration_order_;
utils::Vector<Global*, 64> declaration_order_;
/// Globals in sorted dependency order. Populated by SortGlobals().
utils::UniqueVector<const ast::Node*, 64> sorted_;

View File

@ -15,11 +15,11 @@
#ifndef SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
#define SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
#include <unordered_map>
#include <vector>
#include "src/tint/ast/module.h"
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/utils/hashmap.h"
namespace tint::resolver {
@ -50,13 +50,13 @@ struct DependencyGraph {
/// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
/// variable that declares the symbol.
std::unordered_map<const ast::Node*, const ast::Node*> resolved_symbols;
utils::Hashmap<const ast::Node*, const ast::Node*, 64> resolved_symbols;
/// Map of ast::Variable to a type, function, or variable that is shadowed by
/// the variable key. A declaration (X) shadows another (Y) if X and Y use
/// the same symbol, and X is declared in a sub-scope of the scope that
/// declares Y.
std::unordered_map<const ast::Variable*, const ast::Node*> shadows;
utils::Hashmap<const ast::Variable*, const ast::Node*, 16> shadows;
};
} // namespace tint::resolver

View File

@ -1128,9 +1128,10 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
if (expect_pass) {
// Check that the use resolves to the declaration
auto* resolved_symbol = graph.resolved_symbols[use];
EXPECT_EQ(resolved_symbol, decl)
<< "resolved: " << (resolved_symbol ? resolved_symbol->TypeInfo().name : "<null>")
auto* resolved_symbol = graph.resolved_symbols.Find(use);
ASSERT_NE(resolved_symbol, nullptr);
EXPECT_EQ(*resolved_symbol, decl)
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n"
<< "decl: " << decl->TypeInfo().name;
}
@ -1177,7 +1178,10 @@ TEST_P(ResolverDependencyShadowTest, Test) {
: helper.parameters[0];
helper.Build();
EXPECT_EQ(Build().shadows[inner_var], outer);
auto shadows = Build().shadows;
auto* shadow = shadows.Find(inner_var);
ASSERT_NE(shadow, nullptr);
EXPECT_EQ(*shadow, outer);
}
INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal,
@ -1308,8 +1312,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
auto graph = Build();
for (auto use : symbol_uses) {
auto* resolved_symbol = graph.resolved_symbols[use.use];
EXPECT_EQ(resolved_symbol, use.decl) << use.where;
auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
ASSERT_NE(resolved_symbol, nullptr) << use.where;
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
}
}

View File

@ -16,7 +16,6 @@
#include <algorithm>
#include <limits>
#include <unordered_map>
#include <utility>
#include "src/tint/ast/binary_expression.h"
@ -36,7 +35,7 @@
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/type_initializer.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/math.h"
#include "src/tint/utils/scoped_assignment.h"
@ -1114,10 +1113,10 @@ class Impl : public IntrinsicTable {
ProgramBuilder& builder;
Matchers matchers;
std::unordered_map<IntrinsicPrototype, sem::Builtin*, IntrinsicPrototype::Hasher> builtins;
std::unordered_map<IntrinsicPrototype, sem::TypeInitializer*, IntrinsicPrototype::Hasher>
utils::Hashmap<IntrinsicPrototype, sem::Builtin*, 64, IntrinsicPrototype::Hasher> builtins;
utils::Hashmap<IntrinsicPrototype, sem::TypeInitializer*, 16, IntrinsicPrototype::Hasher>
initializers;
std::unordered_map<IntrinsicPrototype, sem::TypeConversion*, IntrinsicPrototype::Hasher>
utils::Hashmap<IntrinsicPrototype, sem::TypeConversion*, 16, IntrinsicPrototype::Hasher>
converters;
};
@ -1185,7 +1184,7 @@ Impl::Builtin Impl::Lookup(sem::BuiltinType builtin_type,
}
// De-duplicate builtins that are identical.
auto* sem = utils::GetOrCreate(builtins, match, [&] {
auto* sem = builtins.GetOrCreate(match, [&] {
utils::Vector<sem::Parameter*, kNumFixedParams> params;
params.Reserve(match.parameters.Length());
for (auto& p : match.parameters) {
@ -1396,7 +1395,7 @@ IntrinsicTable::InitOrConv Impl::Lookup(InitConvIntrinsic type,
}
auto eval_stage = match.overload->const_eval_fn ? sem::EvaluationStage::kConstant
: sem::EvaluationStage::kRuntime;
auto* target = utils::GetOrCreate(initializers, match, [&]() {
auto* target = initializers.GetOrCreate(match, [&]() {
return builder.create<sem::TypeInitializer>(match.return_type, std::move(params),
eval_stage);
});
@ -1404,7 +1403,7 @@ IntrinsicTable::InitOrConv Impl::Lookup(InitConvIntrinsic type,
}
// Conversion.
auto* target = utils::GetOrCreate(converters, match, [&]() {
auto* target = converters.GetOrCreate(match, [&]() {
auto param = builder.create<sem::Parameter>(
nullptr, 0u, match.parameters[0].type, ast::AddressSpace::kNone,
ast::Access::kUndefined, match.parameters[0].usage);

View File

@ -482,7 +482,7 @@ sem::Variable* Resolver::Override(const ast::Override* v) {
sem->SetOverrideId(o);
// Track the constant IDs that are specified in the shader.
override_ids_.emplace(o, sem);
override_ids_.Add(o, sem);
}
builder_->Sem().Add(v, sem);
@ -842,7 +842,7 @@ bool Resolver::AllocateOverridableConstantIds() {
id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId();
} else {
// No ID was specified, so allocate the next available ID.
while (!ids_exhausted && override_ids_.count(next_id)) {
while (!ids_exhausted && override_ids_.Contains(next_id)) {
increment_next_id();
}
if (ids_exhausted) {
@ -864,9 +864,9 @@ bool Resolver::AllocateOverridableConstantIds() {
void Resolver::SetShadows() {
for (auto it : dependencies_.shadows) {
Switch(
sem_.Get(it.first), //
[&](sem::LocalVariable* local) { local->SetShadows(sem_.Get(it.second)); },
[&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.second)); });
sem_.Get(it.key), //
[&](sem::LocalVariable* local) { local->SetShadows(sem_.Get(it.value)); },
[&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.value)); });
}
}
@ -923,7 +923,7 @@ sem::Statement* Resolver::StaticAssert(const ast::StaticAssert* assertion) {
sem::Function* Resolver::Function(const ast::Function* decl) {
uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names;
utils::Hashmap<Symbol, Source, 8> parameter_names;
utils::Vector<sem::Parameter*, 8> parameters;
// Resolve all the parameters
@ -931,11 +931,10 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
Mark(param);
{ // Check the parameter name is unique for the function
auto emplaced = parameter_names.emplace(param->symbol, param->source);
if (!emplaced.second) {
if (auto added = parameter_names.Add(param->symbol, param->source); !added) {
auto name = builder_->Symbols().NameFor(param->symbol);
AddError("redefinition of parameter '" + name + "'", param->source);
AddNote("previous definition is here", emplaced.first->second);
AddNote("previous definition is here", *added.value);
return nullptr;
}
}
@ -1031,7 +1030,7 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
}
if (decl->IsEntryPoint()) {
entry_points_.emplace_back(func);
entry_points_.Push(func);
}
if (decl->body) {
@ -1850,8 +1849,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
[&](const sem::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); },
[&](const sem::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); },
[&](const sem::Array* arr) -> sem::Call* {
auto* call_target = utils::GetOrCreate(
array_inits_, ArrayInitializerSig{{arr, args.Length(), args_stage}},
auto* call_target = array_inits_.GetOrCreate(
ArrayInitializerSig{{arr, args.Length(), args_stage}},
[&]() -> sem::TypeInitializer* {
auto params = utils::Transform(args, [&](auto, size_t i) {
return builder_->create<sem::Parameter>(
@ -1877,8 +1876,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return call;
},
[&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate(
struct_inits_, StructInitializerSig{{str, args.Length(), args_stage}},
auto* call_target = struct_inits_.GetOrCreate(
StructInitializerSig{{str, args.Length(), args_stage}},
[&]() -> sem::TypeInitializer* {
utils::Vector<const sem::Parameter*, 8> params;
params.Resize(std::min(args.Length(), str->Members().size()));
@ -1981,9 +1980,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
AddError(
"cannot infer common array element type from initializer arguments",
expr->source);
std::unordered_set<const sem::Type*> types;
utils::Hashset<const sem::Type*, 8> types;
for (size_t i = 0; i < args.Length(); i++) {
if (types.emplace(args[i]->Type()).second) {
if (types.Add(args[i]->Type())) {
AddNote("argument " + std::to_string(i) + " is of type '" +
sem_.TypeNameOf(args[i]->Type()) + "'",
args[i]->Declaration()->source);
@ -2687,11 +2686,10 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
}
if (el_ty->Is<sem::Atomic>()) {
atomic_composite_info_.emplace(out, arr->type->source);
atomic_composite_info_.Add(out, &arr->type->source);
} else {
auto found = atomic_composite_info_.find(el_ty);
if (found != atomic_composite_info_.end()) {
atomic_composite_info_.emplace(out, found->second);
if (auto* found = atomic_composite_info_.Find(el_ty)) {
atomic_composite_info_.Add(out, *found);
}
}
@ -2832,15 +2830,14 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
// validation.
uint64_t struct_size = 0;
uint64_t struct_align = 1;
std::unordered_map<Symbol, const ast::StructMember*> member_map;
utils::Hashmap<Symbol, const ast::StructMember*, 8> member_map;
for (auto* member : str->members) {
Mark(member);
auto result = member_map.emplace(member->symbol, member);
if (!result.second) {
if (auto added = member_map.Add(member->symbol, member); !added) {
AddError("redefinition of '" + builder_->Symbols().NameFor(member->symbol) + "'",
member->source);
AddNote("previous definition is here", result.first->second->source);
AddNote("previous definition is here", (*added.value)->source);
return nullptr;
}
@ -3027,12 +3024,11 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
for (size_t i = 0; i < sem_members.size(); i++) {
auto* mem_type = sem_members[i]->Type();
if (mem_type->Is<sem::Atomic>()) {
atomic_composite_info_.emplace(out, sem_members[i]->Declaration()->source);
atomic_composite_info_.Add(out, &sem_members[i]->Declaration()->source);
break;
} else {
auto found = atomic_composite_info_.find(mem_type);
if (found != atomic_composite_info_.end()) {
atomic_composite_info_.emplace(out, found->second);
if (auto* found = atomic_composite_info_.Find(mem_type)) {
atomic_composite_info_.Add(out, *found);
break;
}
}

View File

@ -18,8 +18,6 @@
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -434,13 +432,13 @@ class Resolver {
SemHelper sem_;
Validator validator_;
ast::Extensions enabled_extensions_;
std::vector<sem::Function*> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
utils::Vector<sem::Function*, 8> entry_points_;
utils::Hashmap<const sem::Type*, const Source*, 8> atomic_composite_info_;
utils::Bitset<0> marked_;
ExprEvalStageConstraint expr_eval_stage_constraint_;
std::unordered_map<OverrideId, const sem::Variable*> override_ids_;
std::unordered_map<ArrayInitializerSig, sem::CallTarget*> array_inits_;
std::unordered_map<StructInitializerSig, sem::CallTarget*> struct_inits_;
utils::Hashmap<OverrideId, const sem::Variable*, 8> override_ids_;
utils::Hashmap<ArrayInitializerSig, sem::CallTarget*, 8> array_inits_;
utils::Hashmap<StructInitializerSig, sem::CallTarget*, 8> struct_inits_;
sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr;

View File

@ -54,8 +54,8 @@ class SemHelper {
/// @param node the node to retrieve
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved)) : nullptr;
auto* resolved = dependencies_.resolved_symbols.Find(node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
}
/// @returns the resolved type of the ast::Expression `expr`

View File

@ -16,8 +16,6 @@
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -139,7 +137,7 @@ struct ParameterInfo {
bool pointer_may_become_non_uniform = false;
/// The parameters that are required to be uniform for the contents of this pointer parameter to
/// be uniform at function exit.
std::vector<const sem::Parameter*> pointer_param_output_sources;
utils::Vector<const sem::Parameter*, 8> pointer_param_output_sources;
/// The node in the graph that corresponds to this parameter's initial value.
Node* init_value;
/// The node in the graph that corresponds to this parameter's output value (or nullptr).
@ -166,7 +164,7 @@ struct FunctionInfo {
}
// Create nodes for parameters.
parameters.resize(func->params.Length());
parameters.Resize(func->params.Length());
for (size_t i = 0; i < func->params.Length(); i++) {
auto* param = func->params[i];
auto param_name = builder->Symbols().NameFor(param->symbol);
@ -177,7 +175,7 @@ struct FunctionInfo {
if (sem->Type()->Is<sem::Pointer>()) {
node_init = CreateNode("ptrparam_" + name + "_init");
parameters[i].pointer_return_value = CreateNode("ptrparam_" + name + "_return");
local_var_decls.insert(sem);
local_var_decls.Add(sem);
} else {
node_init = CreateNode("param_" + name);
}
@ -194,7 +192,7 @@ struct FunctionInfo {
/// The function's uniformity effects.
FunctionTag function_tag;
/// The uniformity requirements of the function's parameters.
std::vector<ParameterInfo> parameters;
utils::Vector<ParameterInfo, 8> parameters;
/// The control flow graph.
utils::BlockAllocator<Node> nodes;
@ -213,24 +211,31 @@ struct FunctionInfo {
/// The set of a local read-write vars that are in scope at any given point in the process.
/// Includes pointer parameters.
std::unordered_set<const sem::Variable*> local_var_decls;
utils::Hashset<const sem::Variable*, 8> local_var_decls;
/// The set of partial pointer variables - pointers that point to a subobject (into an array or
/// struct).
std::unordered_set<const sem::Variable*> partial_ptrs;
utils::Hashset<const sem::Variable*, 4> partial_ptrs;
/// LoopSwitchInfo tracks information about the value of variables for a control flow construct.
struct LoopSwitchInfo {
/// The type of this control flow construct.
std::string type;
/// The input values for local variables at the start of this construct.
std::unordered_map<const sem::Variable*, Node*> var_in_nodes;
utils::Hashmap<const sem::Variable*, Node*, 8> var_in_nodes;
/// The exit values for local variables at the end of this construct.
std::unordered_map<const sem::Variable*, Node*> var_exit_nodes;
utils::Hashmap<const sem::Variable*, Node*, 8> var_exit_nodes;
};
/// Map from control flow statements to the corresponding LoopSwitchInfo structure.
std::unordered_map<const sem::Statement*, LoopSwitchInfo> loop_switch_infos;
/// @returns a LoopSwitchInfo for the given statement, allocating the LoopSwitchInfo if this is
/// the first call with the given statement.
LoopSwitchInfo& LoopSwitchInfoFor(const sem::Statement* stmt) {
return *loop_switch_infos.GetOrCreate(stmt,
[&] { return loop_switch_info_allocator.Create(); });
}
/// Disassociates the LoopSwitchInfo for the given statement.
void RemoveLoopSwitchInfoFor(const sem::Statement* stmt) { loop_switch_infos.Remove(stmt); }
/// Create a new node.
/// @param tag a tag used to identify the node for debugging purposes
@ -263,7 +268,13 @@ struct FunctionInfo {
private:
/// A list of tags that have already been used within the current function.
std::unordered_set<std::string> tags_;
utils::Hashset<std::string, 8> tags_;
/// Map from control flow statements to the corresponding LoopSwitchInfo structure.
utils::Hashmap<const sem::Statement*, LoopSwitchInfo*, 8> loop_switch_infos;
/// Allocator of LoopSwitchInfos
utils::BlockAllocator<LoopSwitchInfo> loop_switch_info_allocator;
};
/// UniformityGraph is used to analyze the uniformity requirements and effects of functions in a
@ -312,7 +323,7 @@ class UniformityGraph {
diag::List& diagnostics_;
/// Map of analyzed function results.
std::unordered_map<const ast::Function*, FunctionInfo> functions_;
utils::Hashmap<const ast::Function*, FunctionInfo, 8> functions_;
/// The function currently being analyzed.
FunctionInfo* current_function_;
@ -329,8 +340,7 @@ class UniformityGraph {
/// @param func the function to process
/// @returns true if there are no uniformity issues, false otherwise
bool ProcessFunction(const ast::Function* func) {
functions_.emplace(func, FunctionInfo(func, builder_));
current_function_ = &functions_.at(func);
current_function_ = functions_.Add(func, FunctionInfo(func, builder_)).value;
// Process function body.
if (func->body) {
@ -410,7 +420,7 @@ class UniformityGraph {
for (size_t j = 0; j < func->params.Length(); j++) {
auto* param_source = sem_.Get<sem::Parameter>(func->params[j]);
if (reachable.Contains(current_function_->parameters[j].init_value)) {
current_function_->parameters[i].pointer_param_output_sources.push_back(
current_function_->parameters[i].pointer_param_output_sources.Push(
param_source);
}
}
@ -439,7 +449,7 @@ class UniformityGraph {
},
[&](const ast::BlockStatement* b) {
std::unordered_map<const sem::Variable*, Node*> scoped_assignments;
utils::Hashmap<const sem::Variable*, Node*, 8> scoped_assignments;
{
// Push a new scope for variable assignments in the block.
current_function_->variables.Push();
@ -472,13 +482,13 @@ class UniformityGraph {
if (behaviors.Contains(sem::Behavior::kNext) ||
behaviors.Contains(sem::Behavior::kFallthrough)) {
for (auto var : scoped_assignments) {
current_function_->variables.Set(var.first, var.second);
current_function_->variables.Set(var.key, var.value);
}
}
// Remove any variables declared in this scope from the set of in-scope variables.
for (auto decl : sem_.Get<sem::BlockStatement>(b)->Decls()) {
current_function_->local_var_decls.erase(decl.value.variable);
current_function_->local_var_decls.Remove(decl.value.variable);
}
return cf;
@ -489,8 +499,8 @@ class UniformityGraph {
auto* parent = sem_.Get(b)
->FindFirstParent<sem::SwitchStatement, sem::LoopStatement,
sem::ForLoopStatement, sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop/switch exit nodes.
for (auto* var : current_function_->local_var_decls) {
@ -502,7 +512,7 @@ class UniformityGraph {
}
// Add an edge from the variable exit node to its value at this point.
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@ -526,8 +536,7 @@ class UniformityGraph {
{
auto* parent = sem_.Get(b)->FindFirstParent<sem::LoopStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
@ -539,7 +548,7 @@ class UniformityGraph {
}
// Add an edge from the variable exit node to its value at this point.
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@ -580,8 +589,7 @@ class UniformityGraph {
auto* parent = sem_.Get(c)
->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement,
sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate assignments to the loop input nodes.
for (auto* var : current_function_->local_var_decls) {
@ -593,11 +601,11 @@ class UniformityGraph {
}
// Add an edge from the variable's loop input node to its value at this point.
TINT_ASSERT(Resolver, info.var_in_nodes.count(var));
auto* in_node = info.var_in_nodes.at(var);
auto** in_node = info.var_in_nodes.Find(var);
TINT_ASSERT(Resolver, in_node != nullptr);
auto* out_node = current_function_->variables.Get(var);
if (out_node != in_node) {
in_node->AddEdge(out_node);
if (out_node != *in_node) {
(*in_node)->AddEdge(out_node);
}
}
return cf;
@ -618,7 +626,7 @@ class UniformityGraph {
}
auto* cf_start = cf_init;
auto& info = current_function_->loop_switch_infos[sem_loop];
auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "forloop";
// Create input nodes for any variables declared before this loop.
@ -626,7 +634,7 @@ class UniformityGraph {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node;
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
@ -640,7 +648,7 @@ class UniformityGraph {
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@ -660,19 +668,19 @@ class UniformityGraph {
// Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) {
auto* in_node = v.second;
auto* out_node = current_function_->variables.Get(v.first);
auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
}
// Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second);
for (auto& v : info.var_exit_nodes) {
current_function_->variables.Set(v.key, v.value);
}
current_function_->loop_switch_infos.erase(sem_loop);
current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
@ -687,7 +695,7 @@ class UniformityGraph {
auto* cf_start = cf;
auto& info = current_function_->loop_switch_infos[sem_loop];
auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "whileloop";
// Create input nodes for any variables declared before this loop.
@ -695,7 +703,7 @@ class UniformityGraph {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node;
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
@ -710,7 +718,7 @@ class UniformityGraph {
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@ -721,9 +729,9 @@ class UniformityGraph {
cfx->AddEdge(cf);
// Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) {
auto* in_node = v.second;
auto* out_node = current_function_->variables.Get(v.first);
for (auto& v : info.var_in_nodes) {
auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
@ -731,10 +739,10 @@ class UniformityGraph {
// Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second);
current_function_->variables.Set(v.key, v.value);
}
current_function_->loop_switch_infos.erase(sem_loop);
current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
@ -752,15 +760,15 @@ class UniformityGraph {
v->affects_control_flow = true;
v->AddEdge(v_cond);
std::unordered_map<const sem::Variable*, Node*> true_vars;
std::unordered_map<const sem::Variable*, Node*> false_vars;
utils::Hashmap<const sem::Variable*, Node*, 8> true_vars;
utils::Hashmap<const sem::Variable*, Node*, 8> false_vars;
// Helper to process a statement with a new scope for variable assignments.
// Populates `assigned_vars` with new nodes for any variables that are assigned in
// this statement.
auto process_in_scope =
[&](Node* cf_in, const ast::Statement* s,
std::unordered_map<const sem::Variable*, Node*>& assigned_vars) {
utils::Hashmap<const sem::Variable*, Node*, 8>& assigned_vars) {
// Push a new scope for variable assignments.
current_function_->variables.Push();
@ -790,7 +798,7 @@ class UniformityGraph {
// Update values for any variables assigned in the if or else blocks.
for (auto* var : current_function_->local_var_decls) {
// Skip variables not assigned in either block.
if (true_vars.count(var) == 0 && false_vars.count(var) == 0) {
if (!true_vars.Contains(var) && !false_vars.Contains(var)) {
continue;
}
@ -801,15 +809,15 @@ class UniformityGraph {
// Add edges to the assigned value or the initial value.
// Only add edges if the behavior for that block contains 'Next'.
if (true_has_next) {
if (true_vars.count(var)) {
out_node->AddEdge(true_vars.at(var));
if (true_vars.Contains(var)) {
out_node->AddEdge(*true_vars.Find(var));
} else {
out_node->AddEdge(current_function_->variables.Get(var));
}
}
if (false_has_next) {
if (false_vars.count(var)) {
out_node->AddEdge(false_vars.at(var));
if (false_vars.Contains(var)) {
out_node->AddEdge(*false_vars.Find(var));
} else {
out_node->AddEdge(current_function_->variables.Get(var));
}
@ -845,7 +853,7 @@ class UniformityGraph {
auto* sem_loop = sem_.Get(l);
auto* cfx = CreateNode("loop_start");
auto& info = current_function_->loop_switch_infos[sem_loop];
auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "loop";
// Create input nodes for any variables declared before this loop.
@ -853,7 +861,7 @@ class UniformityGraph {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_loop_in", v->Declaration());
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node;
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
@ -868,8 +876,8 @@ class UniformityGraph {
// Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) {
auto* in_node = v.second;
auto* out_node = current_function_->variables.Get(v.first);
auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
@ -877,10 +885,10 @@ class UniformityGraph {
// Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second);
current_function_->variables.Set(v.key, v.value);
}
current_function_->loop_switch_infos.erase(sem_loop);
current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
@ -925,7 +933,7 @@ class UniformityGraph {
cf_end = CreateNode("switch_CFend");
}
auto& info = current_function_->loop_switch_infos[sem_switch];
auto& info = current_function_->LoopSwitchInfoFor(sem_switch);
info.type = "switch";
auto* cf_n = v;
@ -958,8 +966,7 @@ class UniformityGraph {
}
// Add an edge from the variable exit node to its new value.
auto* exit_node =
utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name =
builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
@ -974,7 +981,7 @@ class UniformityGraph {
// Update nodes for any variables assigned in the switch statement.
for (auto var : info.var_exit_nodes) {
current_function_->variables.Set(var.first, var.second);
current_function_->variables.Set(var.key, var.value);
}
return cf_end ? cf_end : cf;
@ -995,7 +1002,7 @@ class UniformityGraph {
auto* e = UnwrapIndirectAndAddressOfChain(unary_init);
if (e->IsAnyOf<ast::IndexAccessorExpression,
ast::MemberAccessorExpression>()) {
current_function_->partial_ptrs.insert(sem_var);
current_function_->partial_ptrs.Add(sem_var);
}
}
}
@ -1005,7 +1012,7 @@ class UniformityGraph {
current_function_->variables.Set(sem_var, node);
if (decl->variable->Is<ast::Var>()) {
current_function_->local_var_decls.insert(
current_function_->local_var_decls.Add(
sem_.Get<sem::LocalVariable>(decl->variable));
}
@ -1183,10 +1190,10 @@ class UniformityGraph {
// To determine if we're dereferencing a partial pointer, unwrap *&
// chains; if the final expression is an identifier, see if it's a
// partial pointer. If it's not an identifier, then it must be an
// index/acessor expression, and thus a partial pointer.
// index/accessor expression, and thus a partial pointer.
auto* e = UnwrapIndirectAndAddressOfChain(u);
if (auto* var_user = sem_.Get<sem::VariableUser>(e)) {
if (current_function_->partial_ptrs.count(var_user->Variable())) {
if (current_function_->partial_ptrs.Contains(var_user->Variable())) {
return true;
}
} else {
@ -1290,7 +1297,7 @@ class UniformityGraph {
// Process call arguments
Node* cf_last_arg = cf;
std::vector<Node*> args;
utils::Vector<Node*, 8> args;
for (size_t i = 0; i < call->args.Length(); i++) {
auto [cf_i, arg_i] = ProcessExpression(cf_last_arg, call->args[i]);
@ -1303,7 +1310,7 @@ class UniformityGraph {
arg_node->AddEdge(arg_i);
cf_last_arg = cf_i;
args.push_back(arg_node);
args.Push(arg_node);
}
// Note: This is an additional node that isn't described in the specification, for the
@ -1341,11 +1348,11 @@ class UniformityGraph {
[&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process
// functions in dependency order.
TINT_ASSERT(Resolver, functions_.count(func->Declaration()));
auto& info = functions_.at(func->Declaration());
callsite_tag = info.callsite_tag;
function_tag = info.function_tag;
func_info = &info;
auto* info = functions_.Find(func->Declaration());
TINT_ASSERT(Resolver, info != nullptr);
callsite_tag = info->callsite_tag;
function_tag = info->function_tag;
func_info = info;
},
[&](const sem::TypeInitializer*) {
callsite_tag = CallSiteNoRestriction;
@ -1371,7 +1378,7 @@ class UniformityGraph {
result->AddEdge(cf_after);
// For each argument, add edges based on parameter tags.
for (size_t i = 0; i < args.size(); i++) {
for (size_t i = 0; i < args.Length(); i++) {
if (func_info) {
switch (func_info->parameters[i].tag) {
case ParameterRequiredToBeUniform:
@ -1429,11 +1436,11 @@ class UniformityGraph {
/// @param source the starting node
/// @param reachable the set of reachable nodes to populate, if required
void Traverse(Node* source, utils::UniqueVector<Node*, 4>* reachable = nullptr) {
std::vector<Node*> to_visit{source};
utils::Vector<Node*, 8> to_visit{source};
while (!to_visit.empty()) {
auto* node = to_visit.back();
to_visit.pop_back();
while (!to_visit.IsEmpty()) {
auto* node = to_visit.Back();
to_visit.Pop();
if (reachable) {
reachable->Add(node);
@ -1441,7 +1448,7 @@ class UniformityGraph {
for (auto* to : node->edges) {
if (to->visited_from == nullptr) {
to->visited_from = node;
to_visit.push_back(to);
to_visit.Push(to);
}
}
}
@ -1473,8 +1480,8 @@ class UniformityGraph {
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
auto& target_info = functions_.at(user->Declaration());
for (auto* call_node : target_info.required_to_be_uniform->edges) {
auto* target_info = functions_.Find(user->Declaration());
for (auto* call_node : target_info->required_to_be_uniform->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
return FindBuiltinThatRequiresUniformity(child_call);
@ -1643,9 +1650,9 @@ class UniformityGraph {
// If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
auto& next_function = functions_.at(user->Declaration());
Node* next_cause = next_function.parameters[cause->arg_index].init_value;
MakeError(next_function, next_cause, true);
auto* next_function = functions_.Find(user->Declaration());
Node* next_cause = next_function->parameters[cause->arg_index].init_value;
MakeError(*next_function, next_cause, true);
}
} else {
// The requirement was on a function callsite.

View File

@ -580,8 +580,8 @@ bool Validator::LocalVariable(const sem::Variable* local) const {
bool Validator::GlobalVariable(
const sem::GlobalVariable* global,
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_ids,
const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const {
auto* decl = global->Declaration();
if (global->AddressSpace() != ast::AddressSpace::kWorkgroup &&
IsArrayWithOverrideCount(global->Type())) {
@ -702,7 +702,7 @@ bool Validator::GlobalVariable(
// buffer variables with a read_write access mode.
bool Validator::AtomicVariable(
const sem::Variable* var,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const {
const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const {
auto address_space = var->AddressSpace();
auto* decl = var->Declaration();
auto access = var->Access();
@ -716,14 +716,13 @@ bool Validator::AtomicVariable(
return false;
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
auto found = atomic_composite_info.find(type);
if (found != atomic_composite_info.end()) {
if (auto* found = atomic_composite_info.Find(type)) {
if (address_space != ast::AddressSpace::kStorage &&
address_space != ast::AddressSpace::kWorkgroup) {
AddError("atomic variables must have <storage> or <workgroup> address space",
source);
AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here",
found->second);
**found);
return false;
} else if (address_space == ast::AddressSpace::kStorage &&
access != ast::Access::kReadWrite) {
@ -732,7 +731,7 @@ bool Validator::AtomicVariable(
"access mode",
source);
AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here",
found->second);
**found);
return false;
}
}
@ -783,7 +782,7 @@ bool Validator::Let(const sem::Variable* v) const {
bool Validator::Override(
const sem::GlobalVariable* v,
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids) const {
const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_ids) const {
auto* decl = v->Declaration();
auto* storage_ty = v->Type()->UnwrapRef();
@ -796,12 +795,12 @@ bool Validator::Override(
for (auto* attr : decl->attributes) {
if (attr->Is<ast::IdAttribute>()) {
auto id = v->OverrideId();
if (auto it = override_ids.find(id); it != override_ids.end() && it->second != v) {
if (auto* var = override_ids.Find(id); var && *var != v) {
AddError("@id values must be unique", attr->source);
AddNote("a override with an ID of " + std::to_string(id.value) +
AddNote(
"a override with an ID of " + std::to_string(id.value) +
" was previously declared here:",
ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes)
->source);
ast::GetAttribute<ast::IdAttribute>((*var)->Declaration()->attributes)->source);
return false;
}
} else {
@ -1093,8 +1092,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// order to catch conflicts.
// TODO(jrprice): This state could be stored in sem::Function instead, and then passed to
// sem::Function since it would be useful there too.
std::unordered_set<ast::BuiltinValue> builtins;
std::unordered_set<uint32_t> locations;
utils::Hashset<ast::BuiltinValue, 4> builtins;
utils::Hashset<uint32_t, 8> locations;
enum class ParamOrRetType {
kParameter,
kReturnType,
@ -1130,7 +1129,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
}
pipeline_io_attribute = attr;
if (builtins.count(builtin->builtin)) {
if (builtins.Contains(builtin->builtin)) {
AddError(attr_to_str(builtin) +
" attribute appears multiple times as pipeline " +
(param_or_ret == ParamOrRetType::kParameter ? "input" : "output"),
@ -1142,7 +1141,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
/* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
return false;
}
builtins.emplace(builtin->builtin);
builtins.Add(builtin->builtin);
} else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", attr->source);
@ -1287,8 +1286,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// Clear IO sets after parameter validation. Builtin and location attributes in return types
// should be validated independently from those used in parameters.
builtins.clear();
locations.clear();
builtins.Clear();
locations.Clear();
if (!func->ReturnType()->Is<sem::Void>()) {
if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(),
@ -1299,7 +1298,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
}
if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
builtins.count(ast::BuiltinValue::kPosition) == 0) {
!builtins.Contains(ast::BuiltinValue::kPosition)) {
// Check module-scope variables, as the SPIR-V sanitizer generates these.
bool found = false;
for (auto* global : func->TransitivelyReferencedGlobals()) {
@ -1327,18 +1326,18 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
}
// Validate there are no resource variable binding collisions
std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
utils::Hashmap<sem::BindingPoint, const ast::Variable*, 8> binding_points;
for (auto* global : func->TransitivelyReferencedGlobals()) {
auto* var_decl = global->Declaration()->As<ast::Var>();
if (!var_decl || !var_decl->HasBindingPoint()) {
continue;
}
auto bp = global->BindingPoint();
auto res = binding_points.emplace(bp, var_decl);
if (!res.second &&
if (auto added = binding_points.Add(bp, var_decl);
!added &&
IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kBindingPointCollision) &&
IsValidationEnabled(res.first->second->attributes,
IsValidationEnabled((*added.value)->attributes,
ast::DisabledValidation::kBindingPointCollision)) {
// https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
// Bindings must not alias within a shader stage: two different variables in the
@ -1350,7 +1349,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
"' references multiple variables that use the same resource binding @group(" +
std::to_string(bp.group) + "), @binding(" + std::to_string(bp.binding) + ")",
var_decl->source);
AddNote("first resource binding usage declared here", res.first->second->source);
AddNote("first resource binding usage declared here", (*added.value)->source);
return false;
}
}
@ -1917,7 +1916,7 @@ bool Validator::Matrix(const sem::Matrix* ty, const Source& source) const {
return true;
}
bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points) const {
bool Validator::PipelineStages(const utils::VectorRef<sem::Function*> entry_points) const {
auto backtrace = [&](const sem::Function* func, const sem::Function* entry_point) {
if (func != entry_point) {
TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) {
@ -2012,7 +2011,7 @@ bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points)
return true;
}
bool Validator::PushConstants(const std::vector<sem::Function*>& entry_points) const {
bool Validator::PushConstants(const utils::VectorRef<sem::Function*> entry_points) const {
for (auto* entry_point : entry_points) {
// State checked and modified by check_push_constant so that it remembers previously seen
// push_constant variables for an entry-point.
@ -2130,7 +2129,7 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
return false;
}
std::unordered_set<uint32_t> locations;
utils::Hashset<uint32_t, 8> locations;
for (auto* member : str->Members()) {
if (auto* r = member->Type()->As<sem::Array>()) {
if (r->IsRuntimeSized()) {
@ -2248,7 +2247,7 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input) const {
@ -2269,12 +2268,11 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
return false;
}
if (locations.count(location)) {
if (!locations.Add(location)) {
AddError(attr_to_str(loc_attr, location) + " attribute appears multiple times",
loc_attr->source);
return false;
}
locations.emplace(location);
return true;
}
@ -2311,7 +2309,7 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
}
const sem::CaseSelector* default_selector = nullptr;
std::unordered_map<int64_t, Source> selectors;
utils::Hashmap<int64_t, Source, 4> selectors;
for (auto* case_stmt : s->body) {
auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
@ -2338,18 +2336,16 @@ bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
}
auto value = selector->Value()->As<uint32_t>();
auto it = selectors.find(value);
if (it != selectors.end()) {
if (auto added = selectors.Add(value, selector->Declaration()->source); !added) {
AddError("duplicate switch case '" +
(decl_ty->IsAnyOf<sem::I32, sem::AbstractNumeric>()
? std::to_string(i32(value))
: std::to_string(value)) +
"'",
selector->Declaration()->source);
AddNote("previous case declared here", it->second);
AddNote("previous case declared here", *added.value);
return false;
}
selectors.emplace(value, selector->Declaration()->source);
}
}
@ -2477,12 +2473,12 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme
}
bool Validator::NoDuplicateAttributes(utils::VectorRef<const ast::Attribute*> attributes) const {
std::unordered_map<const TypeInfo*, Source> seen;
utils::Hashmap<const TypeInfo*, Source, 8> seen;
for (auto* d : attributes) {
auto res = seen.emplace(&d->TypeInfo(), d->source);
if (!res.second && !d->Is<ast::InternalAttribute>()) {
auto added = seen.Add(&d->TypeInfo(), d->source);
if (!added && !d->Is<ast::InternalAttribute>()) {
AddError("duplicate " + d->Name() + " attribute", d->source);
AddNote("first attribute declared here", res.first->second);
AddNote("first attribute declared here", *added.value);
return false;
}
}

View File

@ -17,16 +17,15 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/tint/ast/pipeline_stage.h"
#include "src/tint/program_builder.h"
#include "src/tint/resolver/sem_helper.h"
#include "src/tint/sem/evaluation_stage.h"
#include "src/tint/source.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h"
// Forward declarations
namespace tint::ast {
@ -116,12 +115,12 @@ class Validator {
/// Validates pipeline stages
/// @param entry_points the entry points to the module
/// @returns true on success, false otherwise.
bool PipelineStages(const std::vector<sem::Function*>& entry_points) const;
bool PipelineStages(const utils::VectorRef<sem::Function*> entry_points) const;
/// Validates push_constant variables
/// @param entry_points the entry points to the module
/// @returns true on success, false otherwise.
bool PushConstants(const std::vector<sem::Function*>& entry_points) const;
bool PushConstants(const utils::VectorRef<sem::Function*> entry_points) const;
/// Validates aliases
/// @param alias the alias to validate
@ -156,7 +155,7 @@ class Validator {
/// @returns true on success, false otherwise.
bool AtomicVariable(
const sem::Variable* var,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const;
const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const;
/// Validates an assignment
/// @param a the assignment statement
@ -248,8 +247,8 @@ class Validator {
/// @returns true on success, false otherwise
bool GlobalVariable(
const sem::GlobalVariable* var,
const std::unordered_map<OverrideId, const sem::Variable*>& override_id,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const;
const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_id,
const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const;
/// Validates a break-if statement
/// @param stmt the statement to validate
@ -297,7 +296,7 @@ class Validator {
bool LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false) const;
@ -392,7 +391,7 @@ class Validator {
/// @param override_id the set of override ids in the module
/// @returns true on success, false otherwise.
bool Override(const sem::GlobalVariable* v,
const std::unordered_map<OverrideId, const sem::Variable*>& override_id) const;
const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_id) const;
/// Validates a 'const' variable declaration
/// @param v the variable to validate

View File

@ -14,11 +14,11 @@
#ifndef SRC_TINT_SCOPE_STACK_H_
#define SRC_TINT_SCOPE_STACK_H_
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/symbol.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h"
namespace tint {
@ -27,22 +27,13 @@ namespace tint {
template <class K, class V>
class ScopeStack {
public:
/// Constructor
ScopeStack() {
// Push global bucket
stack_.push_back({});
}
/// Copy Constructor
ScopeStack(const ScopeStack&) = default;
~ScopeStack() = default;
/// Push a new scope on to the stack
void Push() { stack_.push_back({}); }
void Push() { stack_.Push({}); }
/// Pop the scope off the top of the stack
void Pop() {
if (stack_.size() > 1) {
stack_.pop_back();
if (stack_.Length() > 1) {
stack_.Pop();
}
}
@ -52,19 +43,22 @@ class ScopeStack {
/// @returns the old value if there was an existing key at the top of the
/// stack, otherwise the zero initializer for type T.
V Set(const K& key, V val) {
std::swap(val, stack_.back()[key]);
auto& back = stack_.Back();
if (auto* el = back.Find(key)) {
std::swap(val, *el);
return val;
}
back.Add(key, val);
return {};
}
/// Retrieves a value from the stack
/// @param key the key to look for
/// @returns the value, or the zero initializer if the value was not found
V Get(const K& key) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter;
auto val = map.find(key);
if (val != map.end()) {
return val->second;
if (auto* val = iter->Find(key)) {
return *val;
}
}
@ -73,16 +67,16 @@ class ScopeStack {
/// Return the top scope of the stack.
/// @returns the top scope of the stack
const std::unordered_map<K, V>& Top() const { return stack_.back(); }
const utils::Hashmap<K, V, 8>& Top() const { return stack_.Back(); }
/// Clear the scope stack.
void Clear() {
stack_.clear();
stack_.push_back({});
stack_.Clear();
stack_.Push({});
}
private:
std::vector<std::unordered_map<K, V>> stack_;
utils::Vector<utils::Hashmap<K, V, 8>, 8> stack_ = {{}};
};
} // namespace tint

View File

@ -157,9 +157,9 @@ size_t HashCombine(size_t hash, const ARGS&... values) {
template <typename T>
struct UnorderedKeyWrapper {
/// The wrapped value
const T value;
T value;
/// The hash of value
const size_t hash;
size_t hash;
/// Constructor
/// @param v the value to wrap

View File

@ -524,7 +524,7 @@ class HashmapBase {
/// Shuffles slots for an insertion that has been placed one slot before `start`.
/// @param start the index of the first slot to start shuffling.
/// @param evicted the slot content that was evicted for the insertion.
void InsertShuffle(size_t start, Slot evicted) {
void InsertShuffle(size_t start, Slot&& evicted) {
Scan(start, [&](size_t, size_t index) {
auto& slot = slots_[index];