TD: Fix O(2^n) of reachable-by-entry-point

Re-jig the code so that this can be performed in O(n).

Fixed: tint:245
Change-Id: I6dc341c0313e3a1c808f15c66e0c70a7339640e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/43641
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton 2021-03-03 19:54:44 +00:00 committed by Commit Bot service account
parent 04d93c88a0
commit 5a13258981
3 changed files with 80 additions and 30 deletions

View File

@ -137,29 +137,9 @@ bool TypeDeterminer::DetermineInternal() {
return false; return false;
} }
// Walk over the caller to callee information and update functions with
// which entry points call those functions.
for (auto* func : builder_->AST().Functions()) {
if (!func->IsEntryPoint()) {
continue;
}
for (const auto& callee : caller_to_callee_[func->symbol()]) {
set_entry_points(callee, func->symbol());
}
}
return true; return true;
} }
void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
auto* info = symbol_to_function_.at(fn_sym);
info->ancestor_entry_points.add(ep_sym);
for (const auto& callee : caller_to_callee_[fn_sym]) {
set_entry_points(callee, ep_sym);
}
}
bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) { bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
for (auto* func : funcs) { for (auto* func : funcs) {
if (!DetermineFunction(func)) { if (!DetermineFunction(func)) {
@ -439,9 +419,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* call) {
} }
} else { } else {
if (current_function_) { if (current_function_) {
caller_to_callee_[current_function_->declaration->symbol()].push_back(
ident->symbol());
auto callee_func_it = symbol_to_function_.find(ident->symbol()); auto callee_func_it = symbol_to_function_.find(ident->symbol());
if (callee_func_it == symbol_to_function_.end()) { if (callee_func_it == symbol_to_function_.end()) {
if (current_function_->declaration->symbol() == ident->symbol()) { if (current_function_->declaration->symbol() == ident->symbol()) {
@ -457,6 +434,13 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* call) {
} }
auto* callee_func = callee_func_it->second; auto* callee_func = callee_func_it->second;
// Note: Requires called functions to be resolved first.
// This is currently guaranteed as functions must be declared before use.
current_function_->transitive_calls.add(callee_func);
for (auto* transitive_call : callee_func->transitive_calls) {
current_function_->transitive_calls.add(transitive_call);
}
// We inherit any referenced variables from the callee. // We inherit any referenced variables from the callee.
for (auto* var : callee_func->referenced_module_vars) { for (auto* var : callee_func->referenced_module_vars) {
set_referenced_from_function_if_needed(var, false); set_referenced_from_function_if_needed(var, false);
@ -1004,6 +988,25 @@ void TypeDeterminer::SetType(ast::Expression* expr, type::Type* type) {
void TypeDeterminer::CreateSemanticNodes() const { void TypeDeterminer::CreateSemanticNodes() const {
auto& sem = builder_->Sem(); auto& sem = builder_->Sem();
// Collate all the 'ancestor_entry_points' - this is a map of function symbol
// to all the entry points that transitively call the function.
std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
for (auto* func : builder_->AST().Functions()) {
auto it = function_to_info_.find(func);
if (it == function_to_info_.end()) {
continue; // Type determination has likely errored. Process what we can.
}
auto* info = it->second;
if (!func->IsEntryPoint()) {
continue;
}
for (auto* call : info->transitive_calls) {
auto& vec = ancestor_entry_points[call->declaration->symbol()];
vec.emplace_back(func->symbol());
}
}
// Create semantic nodes for all ast::Variables // Create semantic nodes for all ast::Variables
for (auto it : variable_to_info_) { for (auto it : variable_to_info_) {
auto* var = it.first; auto* var = it.first;
@ -1038,10 +1041,11 @@ void TypeDeterminer::CreateSemanticNodes() const {
for (auto it : function_to_info_) { for (auto it : function_to_info_) {
auto* func = it.first; auto* func = it.first;
auto* info = it.second; auto* info = it.second;
auto* sem_func = builder_->create<semantic::Function>( auto* sem_func = builder_->create<semantic::Function>(
info->declaration, remap_vars(info->referenced_module_vars), info->declaration, remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars), remap_vars(info->local_referenced_module_vars),
info->ancestor_entry_points); ancestor_entry_points[func->symbol()]);
func_info_to_sem_func.emplace(info, sem_func); func_info_to_sem_func.emplace(info, sem_func);
sem.Add(func, sem_func); sem.Add(func, sem_func);
} }

View File

@ -91,7 +91,9 @@ class TypeDeterminer {
ast::Function* const declaration; ast::Function* const declaration;
UniqueVector<VariableInfo*> referenced_module_vars; UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars; UniqueVector<VariableInfo*> local_referenced_module_vars;
UniqueVector<Symbol> ancestor_entry_points;
// List of transitive calls this function makes
UniqueVector<FunctionInfo*> transitive_calls;
}; };
/// Structure holding semantic information about an expression. /// Structure holding semantic information about an expression.
@ -118,7 +120,8 @@ class TypeDeterminer {
/// @param funcs the functions to check /// @param funcs the functions to check
/// @returns true if the determination was successful /// @returns true if the determination was successful
bool DetermineFunctions(const ast::FunctionList& funcs); bool DetermineFunctions(const ast::FunctionList& funcs);
/// Determines type information for a function /// Determines type information for a function. Requires all dependency
/// (callee) functions to have DetermineFunction() called on them first.
/// @param func the function to check /// @param func the function to check
/// @returns true if the determination was successful /// @returns true if the determination was successful
bool DetermineFunction(ast::Function* func); bool DetermineFunction(ast::Function* func);
@ -162,7 +165,6 @@ class TypeDeterminer {
uint32_t* id); uint32_t* id);
void set_referenced_from_function_if_needed(VariableInfo* var, bool local); void set_referenced_from_function_if_needed(VariableInfo* var, bool local);
void set_entry_points(const Symbol& fn_sym, Symbol ep_sym);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr); bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineBinary(ast::BinaryExpression* expr); bool DetermineBinary(ast::BinaryExpression* expr);
@ -202,9 +204,6 @@ class TypeDeterminer {
semantic::Statement* current_statement_ = nullptr; semantic::Statement* current_statement_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_; BlockAllocator<VariableInfo> variable_infos_;
BlockAllocator<FunctionInfo> function_infos_; BlockAllocator<FunctionInfo> function_infos_;
// Map from caller functions to callee functions.
std::unordered_map<Symbol, std::vector<Symbol>> caller_to_callee_;
}; };
} // namespace tint } // namespace tint

View File

@ -3269,6 +3269,53 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty()); EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());
} }
// Check for linear-time traversal of functions reachable from entry points.
// See: crbug.com/tint/245
TEST_F(TypeDeterminerTest, Function_EntryPoints_LinearTime) {
// fn lNa() { }
// fn lNb() { }
// ...
// fn l2a() { l3a(); l3b(); }
// fn l2b() { l3a(); l3b(); }
// fn l1a() { l2a(); l2b(); }
// fn l1b() { l2a(); l2b(); }
// fn main() { l1a(); l1b(); }
static constexpr int levels = 64;
auto fn_a = [](int level) { return "l" + std::to_string(level + 1) + "a"; };
auto fn_b = [](int level) { return "l" + std::to_string(level + 1) + "b"; };
Func(fn_a(levels), {}, ty.void_(), {}, {});
Func(fn_b(levels), {}, ty.void_(), {}, {});
for (int i = levels - 1; i >= 0; i--) {
Func(fn_a(i), {}, ty.void_(),
{
create<ast::CallStatement>(Call(fn_a(i + 1))),
create<ast::CallStatement>(Call(fn_b(i + 1))),
},
{});
Func(fn_b(i), {}, ty.void_(),
{
create<ast::CallStatement>(Call(fn_a(i + 1))),
create<ast::CallStatement>(Call(fn_b(i + 1))),
},
{});
}
Func("main", {}, ty.void_(),
{
create<ast::CallStatement>(Call(fn_a(0))),
create<ast::CallStatement>(Call(fn_b(0))),
},
{
create<ast::StageDecoration>(ast::PipelineStage::kVertex),
});
ASSERT_TRUE(td()->Determine()) << td()->error();
}
using TypeDeterminerTextureIntrinsicTest = using TypeDeterminerTextureIntrinsicTest =
TypeDeterminerTestWithParam<ast::intrinsic::test::TextureOverloadCase>; TypeDeterminerTestWithParam<ast::intrinsic::test::TextureOverloadCase>;