Add std::hash<tint::Symbol> specialization
Allows symbols to be used as keys for std::unordered_map and std::unordered_set. Replace all map / set use of uint32_t for Symbol, where applicable. Change-Id: If142b4ad1f0ee65bc62209ae2f277e7746be19bb Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37262 Reviewed-by: dan sinclair <dsinclair@chromium.org> Commit-Queue: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
4527a512eb
commit
f6866a2ffc
|
@ -48,14 +48,12 @@ class ScopeStack {
|
||||||
/// Set a global variable in the stack
|
/// Set a global variable in the stack
|
||||||
/// @param symbol the symbol of the variable
|
/// @param symbol the symbol of the variable
|
||||||
/// @param val the value
|
/// @param val the value
|
||||||
void set_global(const Symbol& symbol, T val) {
|
void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; }
|
||||||
stack_[0][symbol.value()] = val;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sets variable into the top most scope of the stack
|
/// Sets variable into the top most scope of the stack
|
||||||
/// @param symbol the symbol of the variable
|
/// @param symbol the symbol of the variable
|
||||||
/// @param val the value
|
/// @param val the value
|
||||||
void set(const Symbol& symbol, T val) { stack_.back()[symbol.value()] = val; }
|
void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
|
||||||
|
|
||||||
/// Checks for the given `symbol` in the stack
|
/// Checks for the given `symbol` in the stack
|
||||||
/// @param symbol the symbol to look for
|
/// @param symbol the symbol to look for
|
||||||
|
@ -79,7 +77,7 @@ class ScopeStack {
|
||||||
bool get(const Symbol& symbol, T* ret, bool* is_global) const {
|
bool get(const Symbol& symbol, T* ret, bool* is_global) const {
|
||||||
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
|
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
|
||||||
auto& map = *iter;
|
auto& map = *iter;
|
||||||
auto val = map.find(symbol.value());
|
auto val = map.find(symbol);
|
||||||
|
|
||||||
if (val != map.end()) {
|
if (val != map.end()) {
|
||||||
if (ret) {
|
if (ret) {
|
||||||
|
@ -95,7 +93,7 @@ class ScopeStack {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::unordered_map<uint32_t, T>> stack_;
|
std::vector<std::unordered_map<Symbol, T>> stack_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
16
src/symbol.h
16
src/symbol.h
|
@ -67,4 +67,20 @@ class Symbol {
|
||||||
|
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
|
||||||
|
/// Custom std::hash specialization for tint::Symbol so symbols can be used as
|
||||||
|
/// keys for std::unordered_map and std::unordered_set.
|
||||||
|
template <>
|
||||||
|
class hash<tint::Symbol> {
|
||||||
|
public:
|
||||||
|
/// @param sym the symbol to return
|
||||||
|
/// @return the Symbol internal value
|
||||||
|
inline std::size_t operator()(const tint::Symbol& sym) const {
|
||||||
|
return static_cast<std::size_t>(sym.value());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace std
|
||||||
|
|
||||||
#endif // SRC_SYMBOL_H_
|
#endif // SRC_SYMBOL_H_
|
||||||
|
|
|
@ -40,7 +40,7 @@ Symbol SymbolTable::Register(const std::string& name) {
|
||||||
++next_symbol_;
|
++next_symbol_;
|
||||||
|
|
||||||
name_to_symbol_[name] = sym;
|
name_to_symbol_[name] = sym;
|
||||||
symbol_to_name_[sym.value()] = name;
|
symbol_to_name_[sym] = name;
|
||||||
|
|
||||||
return sym;
|
return sym;
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ Symbol SymbolTable::GetSymbol(const std::string& name) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string SymbolTable::NameFor(const Symbol symbol) const {
|
std::string SymbolTable::NameFor(const Symbol symbol) const {
|
||||||
auto it = symbol_to_name_.find(symbol.value());
|
auto it = symbol_to_name_.find(symbol);
|
||||||
if (it == symbol_to_name_.end())
|
if (it == symbol_to_name_.end())
|
||||||
return "";
|
return "";
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ class SymbolTable {
|
||||||
// The value to be associated to the next registered symbol table entry.
|
// The value to be associated to the next registered symbol table entry.
|
||||||
uint32_t next_symbol_ = 1;
|
uint32_t next_symbol_ = 1;
|
||||||
|
|
||||||
std::unordered_map<uint32_t, std::string> symbol_to_name_;
|
std::unordered_map<Symbol, std::string> symbol_to_name_;
|
||||||
std::unordered_map<std::string, Symbol> name_to_symbol_;
|
std::unordered_map<std::string, Symbol> name_to_symbol_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -124,7 +124,7 @@ bool TypeDeterminer::Determine() {
|
||||||
if (!func->IsEntryPoint()) {
|
if (!func->IsEntryPoint()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (const auto& callee : caller_to_callee_[func->symbol().value()]) {
|
for (const auto& callee : caller_to_callee_[func->symbol()]) {
|
||||||
set_entry_points(callee, func->symbol());
|
set_entry_points(callee, func->symbol());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -133,9 +133,9 @@ bool TypeDeterminer::Determine() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
|
void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
|
||||||
symbol_to_function_[fn_sym.value()]->add_ancestor_entry_point(ep_sym);
|
symbol_to_function_[fn_sym]->add_ancestor_entry_point(ep_sym);
|
||||||
|
|
||||||
for (const auto& callee : caller_to_callee_[fn_sym.value()]) {
|
for (const auto& callee : caller_to_callee_[fn_sym]) {
|
||||||
set_entry_points(callee, ep_sym);
|
set_entry_points(callee, ep_sym);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,7 +150,7 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
|
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
|
||||||
symbol_to_function_[func->symbol().value()] = func;
|
symbol_to_function_[func->symbol()] = func;
|
||||||
|
|
||||||
current_function_ = func;
|
current_function_ = func;
|
||||||
|
|
||||||
|
@ -389,7 +389,7 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (current_function_) {
|
if (current_function_) {
|
||||||
caller_to_callee_[current_function_->symbol().value()].push_back(
|
caller_to_callee_[current_function_->symbol()].push_back(
|
||||||
ident->symbol());
|
ident->symbol());
|
||||||
|
|
||||||
auto* callee_func = mod_->FindFunctionBySymbol(ident->symbol());
|
auto* callee_func = mod_->FindFunctionBySymbol(ident->symbol());
|
||||||
|
@ -906,7 +906,7 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto iter = symbol_to_function_.find(symbol.value());
|
auto iter = symbol_to_function_.find(symbol);
|
||||||
if (iter != symbol_to_function_.end()) {
|
if (iter != symbol_to_function_.end()) {
|
||||||
expr->set_result_type(iter->second->return_type());
|
expr->set_result_type(iter->second->return_type());
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -129,11 +129,11 @@ class TypeDeterminer {
|
||||||
ast::Module* mod_;
|
ast::Module* mod_;
|
||||||
std::string error_;
|
std::string error_;
|
||||||
ScopeStack<ast::Variable*> variable_stack_;
|
ScopeStack<ast::Variable*> variable_stack_;
|
||||||
std::unordered_map<uint32_t, ast::Function*> symbol_to_function_;
|
std::unordered_map<Symbol, ast::Function*> symbol_to_function_;
|
||||||
ast::Function* current_function_ = nullptr;
|
ast::Function* current_function_ = nullptr;
|
||||||
|
|
||||||
// Map from caller functions to callee functions.
|
// Map from caller functions to callee functions.
|
||||||
std::unordered_map<uint32_t, std::vector<Symbol>> caller_to_callee_;
|
std::unordered_map<Symbol, std::vector<Symbol>> caller_to_callee_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
|
@ -160,7 +160,7 @@ bool GeneratorImpl::Generate(std::ostream& out) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_set<uint32_t> emitted_globals;
|
std::unordered_set<Symbol> emitted_globals;
|
||||||
// Make sure all entry point data is emitted before the entry point functions
|
// Make sure all entry point data is emitted before the entry point functions
|
||||||
for (auto* func : module_->functions()) {
|
for (auto* func : module_->functions()) {
|
||||||
if (!func->IsEntryPoint()) {
|
if (!func->IsEntryPoint()) {
|
||||||
|
@ -198,14 +198,14 @@ Symbol GeneratorImpl::current_ep_var_symbol(VarType type) {
|
||||||
Symbol sym;
|
Symbol sym;
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case VarType::kIn: {
|
case VarType::kIn: {
|
||||||
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
|
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
|
||||||
if (in_it != ep_sym_to_in_data_.end()) {
|
if (in_it != ep_sym_to_in_data_.end()) {
|
||||||
sym = in_it->second.var_symbol;
|
sym = in_it->second.var_symbol;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case VarType::kOut: {
|
case VarType::kOut: {
|
||||||
auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value());
|
auto outit = ep_sym_to_out_data_.find(current_ep_sym_);
|
||||||
if (outit != ep_sym_to_out_data_.end()) {
|
if (outit != ep_sym_to_out_data_.end()) {
|
||||||
sym = outit->second.var_symbol;
|
sym = outit->second.var_symbol;
|
||||||
}
|
}
|
||||||
|
@ -1279,14 +1279,14 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
|
||||||
//
|
//
|
||||||
// We emit both of them if they're there regardless of if they're both used.
|
// We emit both of them if they're there regardless of if they're both used.
|
||||||
if (emit_duplicate_functions) {
|
if (emit_duplicate_functions) {
|
||||||
auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
|
auto in_it = ep_sym_to_in_data_.find(ep_sym);
|
||||||
if (in_it != ep_sym_to_in_data_.end()) {
|
if (in_it != ep_sym_to_in_data_.end()) {
|
||||||
out << "in " << namer_->NameFor(in_it->second.struct_symbol) << " "
|
out << "in " << namer_->NameFor(in_it->second.struct_symbol) << " "
|
||||||
<< namer_->NameFor(in_it->second.var_symbol);
|
<< namer_->NameFor(in_it->second.var_symbol);
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outit = ep_sym_to_out_data_.find(ep_sym.value());
|
auto outit = ep_sym_to_out_data_.find(ep_sym);
|
||||||
if (outit != ep_sym_to_out_data_.end()) {
|
if (outit != ep_sym_to_out_data_.end()) {
|
||||||
if (!first) {
|
if (!first) {
|
||||||
out << ", ";
|
out << ", ";
|
||||||
|
@ -1328,7 +1328,7 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
|
||||||
bool GeneratorImpl::EmitEntryPointData(
|
bool GeneratorImpl::EmitEntryPointData(
|
||||||
std::ostream& out,
|
std::ostream& out,
|
||||||
ast::Function* func,
|
ast::Function* func,
|
||||||
std::unordered_set<uint32_t>& emitted_globals) {
|
std::unordered_set<Symbol>& emitted_globals) {
|
||||||
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
|
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
|
||||||
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> outvariables;
|
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> outvariables;
|
||||||
for (auto data : func->referenced_location_variables()) {
|
for (auto data : func->referenced_location_variables()) {
|
||||||
|
@ -1369,10 +1369,10 @@ bool GeneratorImpl::EmitEntryPointData(
|
||||||
|
|
||||||
// If the global has already been emitted we skip it, it's been emitted by
|
// If the global has already been emitted we skip it, it's been emitted by
|
||||||
// a previous entry point.
|
// a previous entry point.
|
||||||
if (emitted_globals.count(var->symbol().value()) != 0) {
|
if (emitted_globals.count(var->symbol()) != 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
emitted_globals.insert(var->symbol().value());
|
emitted_globals.insert(var->symbol());
|
||||||
|
|
||||||
auto* type = var->type()->UnwrapIfNeeded();
|
auto* type = var->type()->UnwrapIfNeeded();
|
||||||
if (auto* strct = type->As<ast::type::Struct>()) {
|
if (auto* strct = type->As<ast::type::Struct>()) {
|
||||||
|
@ -1413,10 +1413,10 @@ bool GeneratorImpl::EmitEntryPointData(
|
||||||
|
|
||||||
// If the global has already been emitted we skip it, it's been emitted by
|
// If the global has already been emitted we skip it, it's been emitted by
|
||||||
// a previous entry point.
|
// a previous entry point.
|
||||||
if (emitted_globals.count(var->symbol().value()) != 0) {
|
if (emitted_globals.count(var->symbol()) != 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
emitted_globals.insert(var->symbol().value());
|
emitted_globals.insert(var->symbol());
|
||||||
|
|
||||||
auto* ac = var->type()->As<ast::type::AccessControl>();
|
auto* ac = var->type()->As<ast::type::AccessControl>();
|
||||||
if (ac == nullptr) {
|
if (ac == nullptr) {
|
||||||
|
@ -1439,8 +1439,8 @@ bool GeneratorImpl::EmitEntryPointData(
|
||||||
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
||||||
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
|
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
|
||||||
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
|
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
|
||||||
ep_sym_to_in_data_[func->symbol().value()] = {
|
ep_sym_to_in_data_[func->symbol()] = {in_struct_sym,
|
||||||
in_struct_sym, module_->RegisterSymbol(in_var_name)};
|
module_->RegisterSymbol(in_var_name)};
|
||||||
|
|
||||||
make_indent(out);
|
make_indent(out);
|
||||||
out << "struct " << namer_->NameFor(in_struct_sym) << " {" << std::endl;
|
out << "struct " << namer_->NameFor(in_struct_sym) << " {" << std::endl;
|
||||||
|
@ -1486,7 +1486,7 @@ bool GeneratorImpl::EmitEntryPointData(
|
||||||
auto outstruct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
auto outstruct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
||||||
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
|
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
|
||||||
auto outvar_name = namer_->GenerateName(kTintStructOutVarPrefix);
|
auto outvar_name = namer_->GenerateName(kTintStructOutVarPrefix);
|
||||||
ep_sym_to_out_data_[func->symbol().value()] = {
|
ep_sym_to_out_data_[func->symbol()] = {
|
||||||
outstruct_sym, module_->RegisterSymbol(outvar_name)};
|
outstruct_sym, module_->RegisterSymbol(outvar_name)};
|
||||||
|
|
||||||
make_indent(out);
|
make_indent(out);
|
||||||
|
@ -1577,7 +1577,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
|
||||||
make_indent(out);
|
make_indent(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
|
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
|
||||||
bool has_outdata = outdata != ep_sym_to_out_data_.end();
|
bool has_outdata = outdata != ep_sym_to_out_data_.end();
|
||||||
if (has_outdata) {
|
if (has_outdata) {
|
||||||
out << namer_->NameFor(outdata->second.struct_symbol);
|
out << namer_->NameFor(outdata->second.struct_symbol);
|
||||||
|
@ -1586,7 +1586,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
|
||||||
}
|
}
|
||||||
out << " " << namer_->NameFor(current_ep_sym_) << "(";
|
out << " " << namer_->NameFor(current_ep_sym_) << "(";
|
||||||
|
|
||||||
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
|
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
|
||||||
if (in_data != ep_sym_to_in_data_.end()) {
|
if (in_data != ep_sym_to_in_data_.end()) {
|
||||||
out << namer_->NameFor(in_data->second.struct_symbol) << " "
|
out << namer_->NameFor(in_data->second.struct_symbol) << " "
|
||||||
<< namer_->NameFor(in_data->second.var_symbol);
|
<< namer_->NameFor(in_data->second.var_symbol);
|
||||||
|
@ -2023,7 +2023,7 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) {
|
||||||
|
|
||||||
if (generating_entry_point_) {
|
if (generating_entry_point_) {
|
||||||
out << "return";
|
out << "return";
|
||||||
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
|
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
|
||||||
if (outdata != ep_sym_to_out_data_.end()) {
|
if (outdata != ep_sym_to_out_data_.end()) {
|
||||||
out << " " << namer_->NameFor(outdata->second.var_symbol);
|
out << " " << namer_->NameFor(outdata->second.var_symbol);
|
||||||
}
|
}
|
||||||
|
|
|
@ -225,7 +225,7 @@ class GeneratorImpl {
|
||||||
/// @returns true if the entry point data was emitted
|
/// @returns true if the entry point data was emitted
|
||||||
bool EmitEntryPointData(std::ostream& out,
|
bool EmitEntryPointData(std::ostream& out,
|
||||||
ast::Function* func,
|
ast::Function* func,
|
||||||
std::unordered_set<uint32_t>& emitted_globals);
|
std::unordered_set<Symbol>& emitted_globals);
|
||||||
/// Handles emitting the entry point function
|
/// Handles emitting the entry point function
|
||||||
/// @param out the output stream
|
/// @param out the output stream
|
||||||
/// @param func the entry point
|
/// @param func the entry point
|
||||||
|
@ -395,8 +395,8 @@ class GeneratorImpl {
|
||||||
bool generating_entry_point_ = false;
|
bool generating_entry_point_ = false;
|
||||||
uint32_t loop_emission_counter_ = 0;
|
uint32_t loop_emission_counter_ = 0;
|
||||||
ScopeStack<ast::Variable*> global_variables_;
|
ScopeStack<ast::Variable*> global_variables_;
|
||||||
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
|
std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
|
||||||
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
|
std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
|
||||||
|
|
||||||
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
|
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
|
||||||
// function name. If there is no entry for a given key then function did
|
// function name. If there is no entry for a given key then function did
|
||||||
|
|
|
@ -72,7 +72,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
@ -122,7 +122,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
@ -172,7 +172,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
@ -222,7 +222,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
@ -269,7 +269,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
@ -311,7 +311,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
@ -361,7 +361,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
|
||||||
|
|
||||||
mod->AddFunction(func);
|
mod->AddFunction(func);
|
||||||
|
|
||||||
std::unordered_set<uint32_t> globals;
|
std::unordered_set<Symbol> globals;
|
||||||
|
|
||||||
ASSERT_TRUE(td.Determine()) << td.error();
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
|
||||||
|
|
|
@ -400,14 +400,14 @@ Symbol GeneratorImpl::current_ep_var_symbol(VarType type) {
|
||||||
Symbol sym;
|
Symbol sym;
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case VarType::kIn: {
|
case VarType::kIn: {
|
||||||
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
|
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
|
||||||
if (in_it != ep_sym_to_in_data_.end()) {
|
if (in_it != ep_sym_to_in_data_.end()) {
|
||||||
sym = in_it->second.var_symbol;
|
sym = in_it->second.var_symbol;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case VarType::kOut: {
|
case VarType::kOut: {
|
||||||
auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value());
|
auto out_it = ep_sym_to_out_data_.find(current_ep_sym_);
|
||||||
if (out_it != ep_sym_to_out_data_.end()) {
|
if (out_it != ep_sym_to_out_data_.end()) {
|
||||||
sym = out_it->second.var_symbol;
|
sym = out_it->second.var_symbol;
|
||||||
}
|
}
|
||||||
|
@ -1061,7 +1061,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
|
||||||
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
||||||
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
|
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
|
||||||
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
|
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
|
||||||
ep_sym_to_in_data_[func->symbol().value()] = {
|
ep_sym_to_in_data_[func->symbol()] = {
|
||||||
in_struct_sym, module_->RegisterSymbol(in_var_name)};
|
in_struct_sym, module_->RegisterSymbol(in_var_name)};
|
||||||
|
|
||||||
make_indent();
|
make_indent();
|
||||||
|
@ -1099,7 +1099,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
|
||||||
auto out_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
auto out_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
|
||||||
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
|
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
|
||||||
auto out_var_name = namer_->GenerateName(kTintStructOutVarPrefix);
|
auto out_var_name = namer_->GenerateName(kTintStructOutVarPrefix);
|
||||||
ep_sym_to_out_data_[func->symbol().value()] = {
|
ep_sym_to_out_data_[func->symbol()] = {
|
||||||
out_struct_sym, module_->RegisterSymbol(out_var_name)};
|
out_struct_sym, module_->RegisterSymbol(out_var_name)};
|
||||||
|
|
||||||
make_indent();
|
make_indent();
|
||||||
|
@ -1284,14 +1284,14 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
|
||||||
//
|
//
|
||||||
// We emit both of them if they're there regardless of if they're both used.
|
// We emit both of them if they're there regardless of if they're both used.
|
||||||
if (emit_duplicate_functions) {
|
if (emit_duplicate_functions) {
|
||||||
auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
|
auto in_it = ep_sym_to_in_data_.find(ep_sym);
|
||||||
if (in_it != ep_sym_to_in_data_.end()) {
|
if (in_it != ep_sym_to_in_data_.end()) {
|
||||||
out_ << "thread " << namer_->NameFor(in_it->second.struct_symbol) << "& "
|
out_ << "thread " << namer_->NameFor(in_it->second.struct_symbol) << "& "
|
||||||
<< namer_->NameFor(in_it->second.var_symbol);
|
<< namer_->NameFor(in_it->second.var_symbol);
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out_it = ep_sym_to_out_data_.find(ep_sym.value());
|
auto out_it = ep_sym_to_out_data_.find(ep_sym);
|
||||||
if (out_it != ep_sym_to_out_data_.end()) {
|
if (out_it != ep_sym_to_out_data_.end()) {
|
||||||
if (!first) {
|
if (!first) {
|
||||||
out_ << ", ";
|
out_ << ", ";
|
||||||
|
@ -1421,7 +1421,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
|
||||||
|
|
||||||
// This is an entry point, the return type is the entry point output structure
|
// This is an entry point, the return type is the entry point output structure
|
||||||
// if one exists, or void otherwise.
|
// if one exists, or void otherwise.
|
||||||
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
|
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
|
||||||
bool has_out_data = out_data != ep_sym_to_out_data_.end();
|
bool has_out_data = out_data != ep_sym_to_out_data_.end();
|
||||||
if (has_out_data) {
|
if (has_out_data) {
|
||||||
out_ << namer_->NameFor(out_data->second.struct_symbol);
|
out_ << namer_->NameFor(out_data->second.struct_symbol);
|
||||||
|
@ -1431,7 +1431,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
|
||||||
out_ << " " << namer_->NameFor(func->symbol()) << "(";
|
out_ << " " << namer_->NameFor(func->symbol()) << "(";
|
||||||
|
|
||||||
bool first = true;
|
bool first = true;
|
||||||
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
|
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
|
||||||
if (in_data != ep_sym_to_in_data_.end()) {
|
if (in_data != ep_sym_to_in_data_.end()) {
|
||||||
out_ << namer_->NameFor(in_data->second.struct_symbol) << " "
|
out_ << namer_->NameFor(in_data->second.struct_symbol) << " "
|
||||||
<< namer_->NameFor(in_data->second.var_symbol) << " [[stage_in]]";
|
<< namer_->NameFor(in_data->second.var_symbol) << " [[stage_in]]";
|
||||||
|
@ -1734,7 +1734,7 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
|
||||||
out_ << "return";
|
out_ << "return";
|
||||||
|
|
||||||
if (generating_entry_point_) {
|
if (generating_entry_point_) {
|
||||||
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
|
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
|
||||||
if (out_data != ep_sym_to_out_data_.end()) {
|
if (out_data != ep_sym_to_out_data_.end()) {
|
||||||
out_ << " " << namer_->NameFor(out_data->second.var_symbol);
|
out_ << " " << namer_->NameFor(out_data->second.var_symbol);
|
||||||
}
|
}
|
||||||
|
|
|
@ -281,8 +281,8 @@ class GeneratorImpl : public TextGenerator {
|
||||||
uint32_t loop_emission_counter_ = 0;
|
uint32_t loop_emission_counter_ = 0;
|
||||||
Namer* namer_;
|
Namer* namer_;
|
||||||
|
|
||||||
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
|
std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
|
||||||
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
|
std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
|
||||||
|
|
||||||
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
|
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
|
||||||
// function name. If there is no entry for a given key then function did
|
// function name. If there is no entry for a given key then function did
|
||||||
|
|
|
@ -587,7 +587,7 @@ bool Builder::GenerateFunction(ast::Function* func) {
|
||||||
|
|
||||||
scope_stack_.pop_scope();
|
scope_stack_.pop_scope();
|
||||||
|
|
||||||
func_symbol_to_id_[func->symbol().value()] = func_id;
|
func_symbol_to_id_[func->symbol()] = func_id;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1814,7 +1814,7 @@ uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) {
|
||||||
|
|
||||||
OperandList ops = {Operand::Int(type_id), result};
|
OperandList ops = {Operand::Int(type_id), result};
|
||||||
|
|
||||||
auto func_id = func_symbol_to_id_[ident->symbol().value()];
|
auto func_id = func_symbol_to_id_[ident->symbol()];
|
||||||
if (func_id == 0) {
|
if (func_id == 0) {
|
||||||
error_ = "unable to find called function: " +
|
error_ = "unable to find called function: " +
|
||||||
mod_->SymbolToName(ident->symbol());
|
mod_->SymbolToName(ident->symbol());
|
||||||
|
|
|
@ -508,7 +508,7 @@ class Builder {
|
||||||
std::vector<Function> functions_;
|
std::vector<Function> functions_;
|
||||||
|
|
||||||
std::unordered_map<std::string, uint32_t> import_name_to_id_;
|
std::unordered_map<std::string, uint32_t> import_name_to_id_;
|
||||||
std::unordered_map<uint32_t, uint32_t> func_symbol_to_id_;
|
std::unordered_map<Symbol, uint32_t> func_symbol_to_id_;
|
||||||
std::unordered_map<std::string, uint32_t> type_name_to_id_;
|
std::unordered_map<std::string, uint32_t> type_name_to_id_;
|
||||||
std::unordered_map<std::string, uint32_t> const_to_id_;
|
std::unordered_map<std::string, uint32_t> const_to_id_;
|
||||||
std::unordered_map<std::string, uint32_t>
|
std::unordered_map<std::string, uint32_t>
|
||||||
|
|
Loading…
Reference in New Issue