Add std::hash<tint::Symbol> specialization

Allows symbols to be used as keys for std::unordered_map and std::unordered_set.
Replace all map / set use of uint32_t for Symbol, where applicable.

Change-Id: If142b4ad1f0ee65bc62209ae2f277e7746be19bb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37262
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2021-01-11 22:02:42 +00:00 committed by Commit Bot service account
parent 4527a512eb
commit f6866a2ffc
13 changed files with 71 additions and 57 deletions

View File

@ -48,14 +48,12 @@ class ScopeStack {
/// Set a global variable in the stack
/// @param symbol the symbol of the variable
/// @param val the value
void set_global(const Symbol& symbol, T val) {
stack_[0][symbol.value()] = val;
}
void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; }
/// Sets variable into the top most scope of the stack
/// @param symbol the symbol of the variable
/// @param val the value
void set(const Symbol& symbol, T val) { stack_.back()[symbol.value()] = val; }
void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
/// Checks for the given `symbol` in the stack
/// @param symbol the symbol to look for
@ -79,7 +77,7 @@ class ScopeStack {
bool get(const Symbol& symbol, T* ret, bool* is_global) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter;
auto val = map.find(symbol.value());
auto val = map.find(symbol);
if (val != map.end()) {
if (ret) {
@ -95,7 +93,7 @@ class ScopeStack {
}
private:
std::vector<std::unordered_map<uint32_t, T>> stack_;
std::vector<std::unordered_map<Symbol, T>> stack_;
};
} // namespace tint

View File

@ -67,4 +67,20 @@ class Symbol {
} // namespace tint
namespace std {
/// Custom std::hash specialization for tint::Symbol so symbols can be used as
/// keys for std::unordered_map and std::unordered_set.
template <>
class hash<tint::Symbol> {
public:
/// @param sym the symbol to return
/// @return the Symbol internal value
inline std::size_t operator()(const tint::Symbol& sym) const {
return static_cast<std::size_t>(sym.value());
}
};
} // namespace std
#endif // SRC_SYMBOL_H_

View File

@ -40,7 +40,7 @@ Symbol SymbolTable::Register(const std::string& name) {
++next_symbol_;
name_to_symbol_[name] = sym;
symbol_to_name_[sym.value()] = name;
symbol_to_name_[sym] = name;
return sym;
}
@ -51,7 +51,7 @@ Symbol SymbolTable::GetSymbol(const std::string& name) const {
}
std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol.value());
auto it = symbol_to_name_.find(symbol);
if (it == symbol_to_name_.end())
return "";

View File

@ -62,7 +62,7 @@ class SymbolTable {
// The value to be associated to the next registered symbol table entry.
uint32_t next_symbol_ = 1;
std::unordered_map<uint32_t, std::string> symbol_to_name_;
std::unordered_map<Symbol, std::string> symbol_to_name_;
std::unordered_map<std::string, Symbol> name_to_symbol_;
};

View File

@ -124,7 +124,7 @@ bool TypeDeterminer::Determine() {
if (!func->IsEntryPoint()) {
continue;
}
for (const auto& callee : caller_to_callee_[func->symbol().value()]) {
for (const auto& callee : caller_to_callee_[func->symbol()]) {
set_entry_points(callee, func->symbol());
}
}
@ -133,9 +133,9 @@ bool TypeDeterminer::Determine() {
}
void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
symbol_to_function_[fn_sym.value()]->add_ancestor_entry_point(ep_sym);
symbol_to_function_[fn_sym]->add_ancestor_entry_point(ep_sym);
for (const auto& callee : caller_to_callee_[fn_sym.value()]) {
for (const auto& callee : caller_to_callee_[fn_sym]) {
set_entry_points(callee, ep_sym);
}
}
@ -150,7 +150,7 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
}
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
symbol_to_function_[func->symbol().value()] = func;
symbol_to_function_[func->symbol()] = func;
current_function_ = func;
@ -389,7 +389,7 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
}
} else {
if (current_function_) {
caller_to_callee_[current_function_->symbol().value()].push_back(
caller_to_callee_[current_function_->symbol()].push_back(
ident->symbol());
auto* callee_func = mod_->FindFunctionBySymbol(ident->symbol());
@ -906,7 +906,7 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
return true;
}
auto iter = symbol_to_function_.find(symbol.value());
auto iter = symbol_to_function_.find(symbol);
if (iter != symbol_to_function_.end()) {
expr->set_result_type(iter->second->return_type());
return true;

View File

@ -129,11 +129,11 @@ class TypeDeterminer {
ast::Module* mod_;
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
std::unordered_map<uint32_t, ast::Function*> symbol_to_function_;
std::unordered_map<Symbol, ast::Function*> symbol_to_function_;
ast::Function* current_function_ = nullptr;
// Map from caller functions to callee functions.
std::unordered_map<uint32_t, std::vector<Symbol>> caller_to_callee_;
std::unordered_map<Symbol, std::vector<Symbol>> caller_to_callee_;
};
} // namespace tint

View File

@ -160,7 +160,7 @@ bool GeneratorImpl::Generate(std::ostream& out) {
}
}
std::unordered_set<uint32_t> emitted_globals;
std::unordered_set<Symbol> emitted_globals;
// Make sure all entry point data is emitted before the entry point functions
for (auto* func : module_->functions()) {
if (!func->IsEntryPoint()) {
@ -198,14 +198,14 @@ Symbol GeneratorImpl::current_ep_var_symbol(VarType type) {
Symbol sym;
switch (type) {
case VarType::kIn: {
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_it != ep_sym_to_in_data_.end()) {
sym = in_it->second.var_symbol;
}
break;
}
case VarType::kOut: {
auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value());
auto outit = ep_sym_to_out_data_.find(current_ep_sym_);
if (outit != ep_sym_to_out_data_.end()) {
sym = outit->second.var_symbol;
}
@ -1279,14 +1279,14 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
auto in_it = ep_sym_to_in_data_.find(ep_sym);
if (in_it != ep_sym_to_in_data_.end()) {
out << "in " << namer_->NameFor(in_it->second.struct_symbol) << " "
<< namer_->NameFor(in_it->second.var_symbol);
first = false;
}
auto outit = ep_sym_to_out_data_.find(ep_sym.value());
auto outit = ep_sym_to_out_data_.find(ep_sym);
if (outit != ep_sym_to_out_data_.end()) {
if (!first) {
out << ", ";
@ -1328,7 +1328,7 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
bool GeneratorImpl::EmitEntryPointData(
std::ostream& out,
ast::Function* func,
std::unordered_set<uint32_t>& emitted_globals) {
std::unordered_set<Symbol>& emitted_globals) {
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> outvariables;
for (auto data : func->referenced_location_variables()) {
@ -1369,10 +1369,10 @@ bool GeneratorImpl::EmitEntryPointData(
// If the global has already been emitted we skip it, it's been emitted by
// a previous entry point.
if (emitted_globals.count(var->symbol().value()) != 0) {
if (emitted_globals.count(var->symbol()) != 0) {
continue;
}
emitted_globals.insert(var->symbol().value());
emitted_globals.insert(var->symbol());
auto* type = var->type()->UnwrapIfNeeded();
if (auto* strct = type->As<ast::type::Struct>()) {
@ -1413,10 +1413,10 @@ bool GeneratorImpl::EmitEntryPointData(
// If the global has already been emitted we skip it, it's been emitted by
// a previous entry point.
if (emitted_globals.count(var->symbol().value()) != 0) {
if (emitted_globals.count(var->symbol()) != 0) {
continue;
}
emitted_globals.insert(var->symbol().value());
emitted_globals.insert(var->symbol());
auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac == nullptr) {
@ -1439,8 +1439,8 @@ bool GeneratorImpl::EmitEntryPointData(
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
ep_sym_to_in_data_[func->symbol().value()] = {
in_struct_sym, module_->RegisterSymbol(in_var_name)};
ep_sym_to_in_data_[func->symbol()] = {in_struct_sym,
module_->RegisterSymbol(in_var_name)};
make_indent(out);
out << "struct " << namer_->NameFor(in_struct_sym) << " {" << std::endl;
@ -1486,7 +1486,7 @@ bool GeneratorImpl::EmitEntryPointData(
auto outstruct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
auto outvar_name = namer_->GenerateName(kTintStructOutVarPrefix);
ep_sym_to_out_data_[func->symbol().value()] = {
ep_sym_to_out_data_[func->symbol()] = {
outstruct_sym, module_->RegisterSymbol(outvar_name)};
make_indent(out);
@ -1577,7 +1577,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
make_indent(out);
}
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) {
out << namer_->NameFor(outdata->second.struct_symbol);
@ -1586,7 +1586,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
}
out << " " << namer_->NameFor(current_ep_sym_) << "(";
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_data != ep_sym_to_in_data_.end()) {
out << namer_->NameFor(in_data->second.struct_symbol) << " "
<< namer_->NameFor(in_data->second.var_symbol);
@ -2023,7 +2023,7 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) {
if (generating_entry_point_) {
out << "return";
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
if (outdata != ep_sym_to_out_data_.end()) {
out << " " << namer_->NameFor(outdata->second.var_symbol);
}

View File

@ -225,7 +225,7 @@ class GeneratorImpl {
/// @returns true if the entry point data was emitted
bool EmitEntryPointData(std::ostream& out,
ast::Function* func,
std::unordered_set<uint32_t>& emitted_globals);
std::unordered_set<Symbol>& emitted_globals);
/// Handles emitting the entry point function
/// @param out the output stream
/// @param func the entry point
@ -395,8 +395,8 @@ class GeneratorImpl {
bool generating_entry_point_ = false;
uint32_t loop_emission_counter_ = 0;
ScopeStack<ast::Variable*> global_variables_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did

View File

@ -72,7 +72,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@ -122,7 +122,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@ -172,7 +172,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@ -222,7 +222,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@ -269,7 +269,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@ -311,7 +311,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@ -361,7 +361,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
mod->AddFunction(func);
std::unordered_set<uint32_t> globals;
std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();

View File

@ -400,14 +400,14 @@ Symbol GeneratorImpl::current_ep_var_symbol(VarType type) {
Symbol sym;
switch (type) {
case VarType::kIn: {
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_it != ep_sym_to_in_data_.end()) {
sym = in_it->second.var_symbol;
}
break;
}
case VarType::kOut: {
auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value());
auto out_it = ep_sym_to_out_data_.find(current_ep_sym_);
if (out_it != ep_sym_to_out_data_.end()) {
sym = out_it->second.var_symbol;
}
@ -1061,7 +1061,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
ep_sym_to_in_data_[func->symbol().value()] = {
ep_sym_to_in_data_[func->symbol()] = {
in_struct_sym, module_->RegisterSymbol(in_var_name)};
make_indent();
@ -1099,7 +1099,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto out_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
auto out_var_name = namer_->GenerateName(kTintStructOutVarPrefix);
ep_sym_to_out_data_[func->symbol().value()] = {
ep_sym_to_out_data_[func->symbol()] = {
out_struct_sym, module_->RegisterSymbol(out_var_name)};
make_indent();
@ -1284,14 +1284,14 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
auto in_it = ep_sym_to_in_data_.find(ep_sym);
if (in_it != ep_sym_to_in_data_.end()) {
out_ << "thread " << namer_->NameFor(in_it->second.struct_symbol) << "& "
<< namer_->NameFor(in_it->second.var_symbol);
first = false;
}
auto out_it = ep_sym_to_out_data_.find(ep_sym.value());
auto out_it = ep_sym_to_out_data_.find(ep_sym);
if (out_it != ep_sym_to_out_data_.end()) {
if (!first) {
out_ << ", ";
@ -1421,7 +1421,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
// This is an entry point, the return type is the entry point output structure
// if one exists, or void otherwise.
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) {
out_ << namer_->NameFor(out_data->second.struct_symbol);
@ -1431,7 +1431,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
out_ << " " << namer_->NameFor(func->symbol()) << "(";
bool first = true;
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_data != ep_sym_to_in_data_.end()) {
out_ << namer_->NameFor(in_data->second.struct_symbol) << " "
<< namer_->NameFor(in_data->second.var_symbol) << " [[stage_in]]";
@ -1734,7 +1734,7 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
out_ << "return";
if (generating_entry_point_) {
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
if (out_data != ep_sym_to_out_data_.end()) {
out_ << " " << namer_->NameFor(out_data->second.var_symbol);
}

View File

@ -281,8 +281,8 @@ class GeneratorImpl : public TextGenerator {
uint32_t loop_emission_counter_ = 0;
Namer* namer_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did

View File

@ -587,7 +587,7 @@ bool Builder::GenerateFunction(ast::Function* func) {
scope_stack_.pop_scope();
func_symbol_to_id_[func->symbol().value()] = func_id;
func_symbol_to_id_[func->symbol()] = func_id;
return true;
}
@ -1814,7 +1814,7 @@ uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) {
OperandList ops = {Operand::Int(type_id), result};
auto func_id = func_symbol_to_id_[ident->symbol().value()];
auto func_id = func_symbol_to_id_[ident->symbol()];
if (func_id == 0) {
error_ = "unable to find called function: " +
mod_->SymbolToName(ident->symbol());

View File

@ -508,7 +508,7 @@ class Builder {
std::vector<Function> functions_;
std::unordered_map<std::string, uint32_t> import_name_to_id_;
std::unordered_map<uint32_t, uint32_t> func_symbol_to_id_;
std::unordered_map<Symbol, uint32_t> func_symbol_to_id_;
std::unordered_map<std::string, uint32_t> type_name_to_id_;
std::unordered_map<std::string, uint32_t> const_to_id_;
std::unordered_map<std::string, uint32_t>