[type-determiner] Update to work with entry point and function stages.

This Cl updates the type determiner to work with both styles of entry
point definition.

Change-Id: Ic48f1a5f0a5820821f9a74380896426a97483049
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28666
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
dan sinclair 2020-09-21 17:51:31 +00:00 committed by Commit Bot service account
parent 767ea855ab
commit 0592643782
4 changed files with 129 additions and 0 deletions

View File

@ -17,6 +17,7 @@
#include <sstream> #include <sstream>
#include "src/ast/decorated_variable.h" #include "src/ast/decorated_variable.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/workgroup_decoration.h" #include "src/ast/workgroup_decoration.h"
namespace tint { namespace tint {
@ -56,6 +57,15 @@ std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
return {1, 1, 1}; return {1, 1, 1};
} }
ast::PipelineStage Function::pipeline_stage() const {
for (const auto& deco : decorations_) {
if (deco->IsStage()) {
return deco->AsStage()->value();
}
}
return ast::PipelineStage::kNone;
}
void Function::add_referenced_module_variable(Variable* var) { void Function::add_referenced_module_variable(Variable* var) {
for (const auto* v : referenced_module_vars_) { for (const auto* v : referenced_module_vars_) {
if (v->name() == var->name()) { if (v->name() == var->name()) {

View File

@ -28,6 +28,7 @@
#include "src/ast/function_decoration.h" #include "src/ast/function_decoration.h"
#include "src/ast/location_decoration.h" #include "src/ast/location_decoration.h"
#include "src/ast/node.h" #include "src/ast/node.h"
#include "src/ast/pipeline_stage.h"
#include "src/ast/set_decoration.h" #include "src/ast/set_decoration.h"
#include "src/ast/statement.h" #include "src/ast/statement.h"
#include "src/ast/type/type.h" #include "src/ast/type/type.h"
@ -100,6 +101,9 @@ class Function : public Node {
/// return if no workgroup size was set. /// return if no workgroup size was set.
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const; std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
/// @returns the functions pipeline stage or None if not set
ast::PipelineStage pipeline_stage() const;
/// Adds the given variable to the list of referenced module variables if it /// Adds the given variable to the list of referenced module variables if it
/// is not already included. /// is not already included.
/// @param var the module variable to add /// @param var the module variable to add

View File

@ -217,6 +217,17 @@ bool TypeDeterminer::Determine() {
} }
} }
// Walk over the caller to callee information and update functions with which
// entry points call those functions.
for (const auto& func : mod_->functions()) {
if (func->pipeline_stage() == ast::PipelineStage::kNone) {
continue;
}
for (const auto& callee : caller_to_callee_[func->name()]) {
set_entry_points(callee, func->name());
}
}
return true; return true;
} }

View File

@ -37,9 +37,11 @@
#include "src/ast/if_statement.h" #include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h" #include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h" #include "src/ast/member_accessor_expression.h"
#include "src/ast/pipeline_stage.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h" #include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h" #include "src/ast/sint_literal.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct.h" #include "src/ast/struct.h"
#include "src/ast/struct_member.h" #include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
@ -4479,5 +4481,107 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints) {
EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty()); EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty());
} }
TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
ast::type::F32Type f32;
// fn b() {}
// fn c() { b(); }
// fn a() { c(); }
// fn ep_1() { a(); b(); }
// fn ep_2() { c();}
//
// c -> {ep_1, ep_2}
// a -> {ep_1}
// b -> {ep_1, ep_2}
// ep_1 -> {}
// ep_2 -> {}
ast::VariableList params;
auto func_b = std::make_unique<ast::Function>("b", std::move(params), &f32);
auto* func_b_ptr = func_b.get();
auto body = std::make_unique<ast::BlockStatement>();
func_b->set_body(std::move(body));
auto func_c = std::make_unique<ast::Function>("c", std::move(params), &f32);
auto* func_c_ptr = func_c.get();
body = std::make_unique<ast::BlockStatement>();
body->append(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("second"),
std::make_unique<ast::CallExpression>(
std::make_unique<ast::IdentifierExpression>("b"),
ast::ExpressionList{})));
func_c->set_body(std::move(body));
auto func_a = std::make_unique<ast::Function>("a", std::move(params), &f32);
auto* func_a_ptr = func_a.get();
body = std::make_unique<ast::BlockStatement>();
body->append(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("first"),
std::make_unique<ast::CallExpression>(
std::make_unique<ast::IdentifierExpression>("c"),
ast::ExpressionList{})));
func_a->set_body(std::move(body));
auto ep_1 = std::make_unique<ast::Function>("ep_1", std::move(params), &f32);
ep_1->add_decoration(
std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex));
auto* ep_1_ptr = ep_1.get();
body = std::make_unique<ast::BlockStatement>();
body->append(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("call_a"),
std::make_unique<ast::CallExpression>(
std::make_unique<ast::IdentifierExpression>("a"),
ast::ExpressionList{})));
body->append(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("call_b"),
std::make_unique<ast::CallExpression>(
std::make_unique<ast::IdentifierExpression>("b"),
ast::ExpressionList{})));
ep_1->set_body(std::move(body));
auto ep_2 = std::make_unique<ast::Function>("ep_2", std::move(params), &f32);
ep_2->add_decoration(
std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex));
auto* ep_2_ptr = ep_2.get();
body = std::make_unique<ast::BlockStatement>();
body->append(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("call_c"),
std::make_unique<ast::CallExpression>(
std::make_unique<ast::IdentifierExpression>("c"),
ast::ExpressionList{})));
ep_2->set_body(std::move(body));
mod()->AddFunction(std::move(func_b));
mod()->AddFunction(std::move(func_c));
mod()->AddFunction(std::move(func_a));
mod()->AddFunction(std::move(ep_1));
mod()->AddFunction(std::move(ep_2));
// Register the functions and calculate the callers
ASSERT_TRUE(td()->Determine()) << td()->error();
const auto& b_eps = func_b_ptr->ancestor_entry_points();
ASSERT_EQ(2u, b_eps.size());
EXPECT_EQ("ep_1", b_eps[0]);
EXPECT_EQ("ep_2", b_eps[1]);
const auto& a_eps = func_a_ptr->ancestor_entry_points();
ASSERT_EQ(1u, a_eps.size());
EXPECT_EQ("ep_1", a_eps[0]);
const auto& c_eps = func_c_ptr->ancestor_entry_points();
ASSERT_EQ(2u, c_eps.size());
EXPECT_EQ("ep_1", c_eps[0]);
EXPECT_EQ("ep_2", c_eps[1]);
EXPECT_TRUE(ep_1_ptr->ancestor_entry_points().empty());
EXPECT_TRUE(ep_2_ptr->ancestor_entry_points().empty());
}
} // namespace } // namespace
} // namespace tint } // namespace tint