[spirv-writer] Only add used variables to entry point.

This Cl updates the entry point code to only output Input/Output
variabes which are referenced by the function instead of all
Input/Output variables.

Bug: tint:28
Change-Id: Idc429e02cac8dac7fc7b609cbd7f88039695829e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23623
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-06-22 20:52:24 +00:00 committed by David Neto
parent 194e0cca3b
commit 13d2a3b96c
10 changed files with 296 additions and 50 deletions

View File

@ -42,6 +42,15 @@ Function::Function(Function&&) = default;
Function::~Function() = default;
void Function::add_referenced_module_variable(Variable* var) {
for (const auto* v : referenced_module_vars_) {
if (v->name() == var->name()) {
return;
}
}
referenced_module_vars_.push_back(var);
}
bool Function::IsValid() const {
for (const auto& param : params_) {
if (param == nullptr || !param->IsValid())

View File

@ -68,6 +68,15 @@ class Function : public Node {
/// @returns the function params
const VariableList& params() const { return params_; }
/// Adds the given variable to the list of referenced module variables if it
/// is not already included.
/// @param var the module variable to add
void add_referenced_module_variable(Variable* var);
/// @returns the referenced module variables
const std::vector<Variable*>& referenced_module_variables() const {
return referenced_module_vars_;
}
/// Sets the return type of the function
/// @param type the return type
void set_return_type(type::Type* type) { return_type_ = type; }
@ -98,6 +107,7 @@ class Function : public Node {
VariableList params_;
type::Type* return_type_ = nullptr;
StatementList body_;
std::vector<Variable*> referenced_module_vars_;
};
/// A list of unique functions

View File

@ -57,6 +57,26 @@ TEST_F(FunctionTest, Creation_WithSource) {
EXPECT_EQ(src.column, 2u);
}
TEST_F(FunctionTest, AddDuplicateReferencedVariables) {
type::VoidType void_type;
type::I32Type i32;
Variable v("var", StorageClass::kInput, &i32);
Function f("func", VariableList{}, &void_type);
f.add_referenced_module_variable(&v);
ASSERT_EQ(f.referenced_module_variables().size(), 1u);
EXPECT_EQ(f.referenced_module_variables()[0], &v);
f.add_referenced_module_variable(&v);
ASSERT_EQ(f.referenced_module_variables().size(), 1u);
Variable v2("var2", StorageClass::kOutput, &i32);
f.add_referenced_module_variable(&v2);
ASSERT_EQ(f.referenced_module_variables().size(), 2u);
EXPECT_EQ(f.referenced_module_variables()[1], &v2);
}
TEST_F(FunctionTest, IsValid) {
type::VoidType void_type;
type::I32Type i32;

View File

@ -105,7 +105,7 @@ class Variable : public Node {
/// @param name the name to set
void set_name(const std::string& name) { name_ = name; }
/// @returns the variable name
const std::string& name() { return name_; }
const std::string& name() const { return name_; }
/// Sets the value type if a const or formal parameter, or the
/// store type if a var.

View File

@ -161,6 +161,19 @@ void TypeDeterminer::set_error(const Source& src, const std::string& msg) {
error_ += msg;
}
void TypeDeterminer::set_referenced_from_function_if_needed(
ast::Variable* var) {
if (current_function_ == nullptr) {
return;
}
if (var->storage_class() == ast::StorageClass::kNone ||
var->storage_class() == ast::StorageClass::kFunction) {
return;
}
current_function_->add_referenced_module_variable(var);
}
bool TypeDeterminer::Determine() {
for (const auto& var : mod_->global_variables()) {
variable_stack_.set_global(var->name(), var.get());
@ -190,6 +203,8 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
name_to_function_[func->name()] = func;
current_function_ = func;
variable_stack_.push_scope();
for (const auto& param : func->params()) {
variable_stack_.set(param->name(), param.get());
@ -200,6 +215,8 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) {
}
variable_stack_.pop_scope();
current_function_ = nullptr;
return true;
}
@ -567,6 +584,8 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
var->type(), var->storage_class())));
}
set_referenced_from_function_if_needed(var);
return true;
}

View File

@ -104,6 +104,7 @@ class TypeDeterminer {
private:
void set_error(const Source& src, const std::string& msg);
void set_referenced_from_function_if_needed(ast::Variable* var);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineAs(ast::AsExpression* expr);
@ -121,6 +122,7 @@ class TypeDeterminer {
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
std::unordered_map<std::string, ast::Function*> name_to_function_;
ast::Function* current_function_ = nullptr;
};
} // namespace tint

View File

@ -743,6 +743,93 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
EXPECT_TRUE(ident.result_type()->IsF32());
}
TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) {
ast::type::F32Type f32;
auto in_var = std::make_unique<ast::Variable>(
"in_var", ast::StorageClass::kInput, &f32);
auto out_var = std::make_unique<ast::Variable>(
"out_var", ast::StorageClass::kOutput, &f32);
auto sb_var = std::make_unique<ast::Variable>(
"sb_var", ast::StorageClass::kStorageBuffer, &f32);
auto wg_var = std::make_unique<ast::Variable>(
"wg_var", ast::StorageClass::kWorkgroup, &f32);
auto priv_var = std::make_unique<ast::Variable>(
"priv_var", ast::StorageClass::kPrivate, &f32);
auto in_ptr = in_var.get();
auto out_ptr = out_var.get();
auto sb_ptr = sb_var.get();
auto wg_ptr = wg_var.get();
auto priv_ptr = priv_var.get();
mod()->AddGlobalVariable(std::move(in_var));
mod()->AddGlobalVariable(std::move(out_var));
mod()->AddGlobalVariable(std::move(sb_var));
mod()->AddGlobalVariable(std::move(wg_var));
mod()->AddGlobalVariable(std::move(priv_var));
ast::VariableList params;
auto func =
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
auto func_ptr = func.get();
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("out_var"),
std::make_unique<ast::IdentifierExpression>("in_var")));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("wg_var"),
std::make_unique<ast::IdentifierExpression>("wg_var")));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("sb_var"),
std::make_unique<ast::IdentifierExpression>("sb_var")));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("priv_var"),
std::make_unique<ast::IdentifierExpression>("priv_var")));
func->set_body(std::move(body));
mod()->AddFunction(std::move(func));
// Register the function
EXPECT_TRUE(td()->Determine());
const auto& vars = func_ptr->referenced_module_variables();
ASSERT_EQ(vars.size(), 5);
EXPECT_EQ(vars[0], out_ptr);
EXPECT_EQ(vars[1], in_ptr);
EXPECT_EQ(vars[2], wg_ptr);
EXPECT_EQ(vars[3], sb_ptr);
EXPECT_EQ(vars[4], priv_ptr);
}
TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
ast::type::F32Type f32;
auto var = std::make_unique<ast::Variable>(
"in_var", ast::StorageClass::kFunction, &f32);
ast::VariableList params;
auto func =
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
auto func_ptr = func.get();
ast::StatementList body;
body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("var"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f))));
func->set_body(std::move(body));
mod()->AddFunction(std::move(func));
// Register the function
EXPECT_TRUE(td()->Determine());
EXPECT_EQ(func_ptr->referenced_module_variables().size(), 0);
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
ast::type::I32Type i32;
ast::type::F32Type f32;

View File

@ -164,6 +164,8 @@ bool Builder::Build() {
}
}
// Note, the entry points must be generated after the functions as they need
// to be able to lookup the function information based on the name.
for (const auto& ep : mod_->entry_points()) {
if (!GenerateEntryPoint(ep.get())) {
return false;
@ -296,10 +298,16 @@ bool Builder::GenerateEntryPoint(ast::EntryPoint* ep) {
OperandList operands = {Operand::Int(stage), Operand::Int(id),
Operand::String(name)};
// TODO(dsinclair): This could be made smarter by only listing the
// input/output variables which are used by the entry point instead of just
// listing all module scoped variables of type input/output.
for (const auto& var : mod_->global_variables()) {
auto* func = func_name_to_func_[ep->function_name()];
if (func == nullptr) {
error_ = "processing an entry point when the function has not been seen.";
return false;
}
for (const auto* var : func->referenced_module_variables()) {
// For SPIR-V 1.3 we only output Input/output variables. If we update to
// SPIR-V 1.4 or later this should be all variables.
if (var->storage_class() != ast::StorageClass::kInput &&
var->storage_class() != ast::StorageClass::kOutput) {
continue;
@ -425,6 +433,7 @@ bool Builder::GenerateFunction(ast::Function* func) {
scope_stack_.pop_scope();
func_name_to_id_[func->name()] = func_id;
func_name_to_func_[func->name()] = func;
return true;
}

View File

@ -84,36 +84,6 @@ class Builder {
return id;
}
/// Sets the id for a given function name
/// @param name the name to set
/// @param id the id to set
void set_func_name_to_id(const std::string& name, uint32_t id) {
func_name_to_id_[name] = id;
}
/// Retrives the id for the given function name
/// @param name the function name to search for
/// @returns the id for the given name or 0 on failure
uint32_t id_for_func_name(const std::string& name) {
if (func_name_to_id_.count(name) == 0) {
return 0;
}
return func_name_to_id_[name];
}
/// Retrieves the id for an entry point function, or 0 if not found.
/// Emits an error if not found.
/// @param ep the entry point
/// @returns 0 on error
uint32_t id_for_entry_point(ast::EntryPoint* ep) {
auto id = id_for_func_name(ep->function_name());
if (id == 0) {
error_ = "unable to find ID for function: " + ep->function_name();
return 0;
}
return id;
}
/// Iterates over all the instructions in the correct order and calls the
/// given callback
/// @param cb the callback to execute
@ -402,6 +372,29 @@ class Builder {
/// automatically.
Operand result_op();
/// Retrives the id for the given function name
/// @param name the function name to search for
/// @returns the id for the given name or 0 on failure
uint32_t id_for_func_name(const std::string& name) {
if (func_name_to_id_.count(name) == 0) {
return 0;
}
return func_name_to_id_[name];
}
/// Retrieves the id for an entry point function, or 0 if not found.
/// Emits an error if not found.
/// @param ep the entry point
/// @returns 0 on error
uint32_t id_for_entry_point(ast::EntryPoint* ep) {
auto id = id_for_func_name(ep->function_name());
if (id == 0) {
error_ = "unable to find ID for function: " + ep->function_name();
return 0;
}
return id;
}
ast::Module* mod_;
std::string error_;
uint32_t next_id_ = 1;
@ -415,6 +408,7 @@ class Builder {
std::unordered_map<std::string, uint32_t> import_name_to_id_;
std::unordered_map<std::string, uint32_t> func_name_to_id_;
std::unordered_map<std::string, ast::Function*> func_name_to_func_;
std::unordered_map<std::string, uint32_t> type_name_to_id_;
std::unordered_map<std::string, uint32_t> const_to_id_;
ScopeStack<uint32_t> scope_stack_;

View File

@ -17,10 +17,16 @@
#include "gtest/gtest.h"
#include "spirv/unified1/spirv.h"
#include "spirv/unified1/spirv.hpp11"
#include "src/ast/assignment_statement.h"
#include "src/ast/entry_point.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/pipeline_stage.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/void_type.h"
#include "src/ast/variable.h"
#include "src/context.h"
#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@ -32,24 +38,30 @@ namespace {
using BuilderTest = testing::Test;
TEST_F(BuilderTest, EntryPoint) {
ast::type::VoidType void_type;
ast::Function func("frag_main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");
ast::Module mod;
Builder b(&mod);
b.set_func_name_to_id("frag_main", 2);
ASSERT_TRUE(b.GenerateEntryPoint(&ep));
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %2 "main"
EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %3 "main"
)");
}
TEST_F(BuilderTest, EntryPoint_WithoutName) {
ast::type::VoidType void_type;
ast::Function func("compute_main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kCompute, "", "compute_main");
ast::Module mod;
Builder b(&mod);
b.set_func_name_to_id("compute_main", 3);
ASSERT_TRUE(b.GenerateEntryPoint(&ep));
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
EXPECT_EQ(DumpInstructions(b.preamble()),
R"(OpEntryPoint GLCompute %3 "compute_main"
@ -77,12 +89,15 @@ using EntryPointStageTest = testing::TestWithParam<EntryPointStageData>;
TEST_P(EntryPointStageTest, Emit) {
auto params = GetParam();
ast::type::VoidType void_type;
ast::Function func("main", {}, &void_type);
ast::EntryPoint ep(params.stage, "", "main");
ast::Module mod;
Builder b(&mod);
b.set_func_name_to_id("main", 3);
ASSERT_TRUE(b.GenerateEntryPoint(&ep));
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
auto preamble = b.preamble();
ASSERT_EQ(preamble.size(), 1u);
@ -101,8 +116,12 @@ INSTANTIATE_TEST_SUITE_P(
EntryPointStageData{ast::PipelineStage::kCompute,
SpvExecutionModelGLCompute}));
TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) {
TEST_F(BuilderTest, EntryPoint_WithUnusedInterfaceIds) {
ast::type::F32Type f32;
ast::type::VoidType void_type;
ast::Function func("main", {}, &void_type);
auto v_in =
std::make_unique<ast::Variable>("my_in", ast::StorageClass::kInput, &f32);
auto v_out = std::make_unique<ast::Variable>(
@ -121,11 +140,12 @@ TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) {
mod.AddGlobalVariable(std::move(v_out));
mod.AddGlobalVariable(std::move(v_wg));
b.set_func_name_to_id("main", 3);
ASSERT_TRUE(b.GenerateEntryPoint(&ep));
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in"
OpName %4 "my_out"
OpName %7 "my_wg"
OpName %11 "main"
)");
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
%2 = OpTypePointer Input %3
@ -135,35 +155,111 @@ OpName %7 "my_wg"
%4 = OpVariable %5 Output %6
%8 = OpTypePointer Workgroup %3
%7 = OpVariable %8 Workgroup
%10 = OpTypeVoid
%9 = OpTypeFunction %10
)");
EXPECT_EQ(DumpInstructions(b.preamble()),
R"(OpEntryPoint Vertex %3 "main" %1 %4
R"(OpEntryPoint Vertex %11 "main"
)");
}
TEST_F(BuilderTest, EntryPoint_WithUsedInterfaceIds) {
ast::type::F32Type f32;
ast::type::VoidType void_type;
ast::Function func("main", {}, &void_type);
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("my_out"),
std::make_unique<ast::IdentifierExpression>("my_in")));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("my_wg"),
std::make_unique<ast::IdentifierExpression>("my_wg")));
// Add duplicate usages so we show they don't get output multiple times.
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("my_out"),
std::make_unique<ast::IdentifierExpression>("my_in")));
func.set_body(std::move(body));
auto v_in =
std::make_unique<ast::Variable>("my_in", ast::StorageClass::kInput, &f32);
auto v_out = std::make_unique<ast::Variable>(
"my_out", ast::StorageClass::kOutput, &f32);
auto v_wg = std::make_unique<ast::Variable>(
"my_wg", ast::StorageClass::kWorkgroup, &f32);
ast::EntryPoint ep(ast::PipelineStage::kVertex, "", "main");
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(v_in.get());
td.RegisterVariableForTesting(v_out.get());
td.RegisterVariableForTesting(v_wg.get());
ASSERT_TRUE(td.DetermineFunction(&func)) << td.error();
Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(v_in.get())) << b.error();
EXPECT_TRUE(b.GenerateGlobalVariable(v_out.get())) << b.error();
EXPECT_TRUE(b.GenerateGlobalVariable(v_wg.get())) << b.error();
mod.AddGlobalVariable(std::move(v_in));
mod.AddGlobalVariable(std::move(v_out));
mod.AddGlobalVariable(std::move(v_wg));
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error();
EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in"
OpName %4 "my_out"
OpName %7 "my_wg"
OpName %11 "main"
)");
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
%2 = OpTypePointer Input %3
%1 = OpVariable %2 Input
%5 = OpTypePointer Output %3
%6 = OpConstantNull %3
%4 = OpVariable %5 Output %6
%8 = OpTypePointer Workgroup %3
%7 = OpVariable %8 Workgroup
%10 = OpTypeVoid
%9 = OpTypeFunction %10
)");
EXPECT_EQ(DumpInstructions(b.preamble()),
R"(OpEntryPoint Vertex %11 "main" %4 %1
)");
}
TEST_F(BuilderTest, ExecutionModel_Fragment_OriginUpperLeft) {
ast::type::VoidType void_type;
ast::Function func("frag_main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");
ast::Module mod;
Builder b(&mod);
b.set_func_name_to_id("frag_main", 2);
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateExecutionModes(&ep));
EXPECT_EQ(DumpInstructions(b.preamble()),
R"(OpExecutionMode %2 OriginUpperLeft
R"(OpExecutionMode %3 OriginUpperLeft
)");
}
TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize) {
ast::type::VoidType void_type;
ast::Function func("main", {}, &void_type);
ast::EntryPoint ep(ast::PipelineStage::kCompute, "main", "main");
ast::Module mod;
Builder b(&mod);
b.set_func_name_to_id("main", 2);
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
ASSERT_TRUE(b.GenerateExecutionModes(&ep));
EXPECT_EQ(DumpInstructions(b.preamble()),
R"(OpExecutionMode %2 LocalSize 1 1 1
R"(OpExecutionMode %3 LocalSize 1 1 1
)");
}