From 0592643782616ca75eec63e39e32126162db967d Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Mon, 21 Sep 2020 17:51:31 +0000 Subject: [PATCH] [type-determiner] Update to work with entry point and function stages. This Cl updates the type determiner to work with both styles of entry point definition. Change-Id: Ic48f1a5f0a5820821f9a74380896426a97483049 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28666 Commit-Queue: dan sinclair Reviewed-by: David Neto Reviewed-by: Sarah Mashayekhi --- src/ast/function.cc | 10 ++++ src/ast/function.h | 4 ++ src/type_determiner.cc | 11 ++++ src/type_determiner_test.cc | 104 ++++++++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+) diff --git a/src/ast/function.cc b/src/ast/function.cc index 8e8d853ccf..23bd0043d4 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -17,6 +17,7 @@ #include #include "src/ast/decorated_variable.h" +#include "src/ast/stage_decoration.h" #include "src/ast/workgroup_decoration.h" namespace tint { @@ -56,6 +57,15 @@ std::tuple Function::workgroup_size() const { return {1, 1, 1}; } +ast::PipelineStage Function::pipeline_stage() const { + for (const auto& deco : decorations_) { + if (deco->IsStage()) { + return deco->AsStage()->value(); + } + } + return ast::PipelineStage::kNone; +} + void Function::add_referenced_module_variable(Variable* var) { for (const auto* v : referenced_module_vars_) { if (v->name() == var->name()) { diff --git a/src/ast/function.h b/src/ast/function.h index 078f69293b..5cb75bafe7 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -28,6 +28,7 @@ #include "src/ast/function_decoration.h" #include "src/ast/location_decoration.h" #include "src/ast/node.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/set_decoration.h" #include "src/ast/statement.h" #include "src/ast/type/type.h" @@ -100,6 +101,9 @@ class Function : public Node { /// return if no workgroup size was set. std::tuple workgroup_size() const; + /// @returns the functions pipeline stage or None if not set + ast::PipelineStage pipeline_stage() const; + /// Adds the given variable to the list of referenced module variables if it /// is not already included. /// @param var the module variable to add diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 948b04ca85..a55ec6c3ec 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -217,6 +217,17 @@ bool TypeDeterminer::Determine() { } } + // Walk over the caller to callee information and update functions with which + // entry points call those functions. + for (const auto& func : mod_->functions()) { + if (func->pipeline_stage() == ast::PipelineStage::kNone) { + continue; + } + for (const auto& callee : caller_to_callee_[func->name()]) { + set_entry_points(callee, func->name()); + } + } + return true; } diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 512e0b2ef6..67821ad17d 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -37,9 +37,11 @@ #include "src/ast/if_statement.h" #include "src/ast/loop_statement.h" #include "src/ast/member_accessor_expression.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/sint_literal.h" +#include "src/ast/stage_decoration.h" #include "src/ast/struct.h" #include "src/ast/struct_member.h" #include "src/ast/switch_statement.h" @@ -4479,5 +4481,107 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints) { EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty()); } +TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { + ast::type::F32Type f32; + + // fn b() {} + // fn c() { b(); } + // fn a() { c(); } + // fn ep_1() { a(); b(); } + // fn ep_2() { c();} + // + // c -> {ep_1, ep_2} + // a -> {ep_1} + // b -> {ep_1, ep_2} + // ep_1 -> {} + // ep_2 -> {} + + ast::VariableList params; + auto func_b = std::make_unique("b", std::move(params), &f32); + auto* func_b_ptr = func_b.get(); + + auto body = std::make_unique(); + func_b->set_body(std::move(body)); + + auto func_c = std::make_unique("c", std::move(params), &f32); + auto* func_c_ptr = func_c.get(); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("second"), + std::make_unique( + std::make_unique("b"), + ast::ExpressionList{}))); + func_c->set_body(std::move(body)); + + auto func_a = std::make_unique("a", std::move(params), &f32); + auto* func_a_ptr = func_a.get(); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("first"), + std::make_unique( + std::make_unique("c"), + ast::ExpressionList{}))); + func_a->set_body(std::move(body)); + + auto ep_1 = std::make_unique("ep_1", std::move(params), &f32); + ep_1->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* ep_1_ptr = ep_1.get(); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("call_a"), + std::make_unique( + std::make_unique("a"), + ast::ExpressionList{}))); + body->append(std::make_unique( + std::make_unique("call_b"), + std::make_unique( + std::make_unique("b"), + ast::ExpressionList{}))); + ep_1->set_body(std::move(body)); + + auto ep_2 = std::make_unique("ep_2", std::move(params), &f32); + ep_2->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* ep_2_ptr = ep_2.get(); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("call_c"), + std::make_unique( + std::make_unique("c"), + ast::ExpressionList{}))); + ep_2->set_body(std::move(body)); + + mod()->AddFunction(std::move(func_b)); + mod()->AddFunction(std::move(func_c)); + mod()->AddFunction(std::move(func_a)); + mod()->AddFunction(std::move(ep_1)); + mod()->AddFunction(std::move(ep_2)); + + // Register the functions and calculate the callers + ASSERT_TRUE(td()->Determine()) << td()->error(); + + const auto& b_eps = func_b_ptr->ancestor_entry_points(); + ASSERT_EQ(2u, b_eps.size()); + EXPECT_EQ("ep_1", b_eps[0]); + EXPECT_EQ("ep_2", b_eps[1]); + + const auto& a_eps = func_a_ptr->ancestor_entry_points(); + ASSERT_EQ(1u, a_eps.size()); + EXPECT_EQ("ep_1", a_eps[0]); + + const auto& c_eps = func_c_ptr->ancestor_entry_points(); + ASSERT_EQ(2u, c_eps.size()); + EXPECT_EQ("ep_1", c_eps[0]); + EXPECT_EQ("ep_2", c_eps[1]); + + EXPECT_TRUE(ep_1_ptr->ancestor_entry_points().empty()); + EXPECT_TRUE(ep_2_ptr->ancestor_entry_points().empty()); +} + } // namespace } // namespace tint