[metal-writer] Add entry point support.

This CL adds preliminary entry point support to the Metal backend.

Bug: tint:8
Change-Id: I7b904621d706d4503d5054711de64872f79cf2fa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23708
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-06-23 18:22:28 +00:00
parent 6366f68121
commit ec3e2d4abd
3 changed files with 105 additions and 7 deletions

View File

@ -44,12 +44,16 @@ GeneratorImpl::GeneratorImpl() = default;
GeneratorImpl::~GeneratorImpl() = default; GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate(const ast::Module& module) { bool GeneratorImpl::Generate(const ast::Module& module) {
module_ = &module;
for (const auto& func : module.functions()) { for (const auto& func : module.functions()) {
if (!EmitFunction(func.get())) { if (!EmitFunction(func.get())) {
return false; return false;
} }
out_ << std::endl; out_ << std::endl;
} }
module_ = nullptr;
return true; return true;
} }
@ -227,14 +231,51 @@ bool GeneratorImpl::EmitExpression(ast::Expression* expr) {
return false; return false;
} }
void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
switch (stage) {
case ast::PipelineStage::kFragment:
out_ << "fragment";
break;
case ast::PipelineStage::kVertex:
out_ << "vertex";
break;
case ast::PipelineStage::kCompute:
out_ << "kernel";
break;
case ast::PipelineStage::kNone:
break;
}
return;
}
bool GeneratorImpl::EmitFunction(ast::Function* func) { bool GeneratorImpl::EmitFunction(ast::Function* func) {
make_indent(); make_indent();
// TODO(dsinclair): Technically this is wrong as you could, in theory, have
// multiple entry points pointing at the same function. I'm ignoring that for
// now. It will either go away with the entry_point changes in the spec
// or we'll have to figure out how to deal with it.
auto name = func->name();
for (const auto& ep : module_->entry_points()) {
if (ep->function_name() == name) {
EmitStage(ep->stage());
out_ << " ";
if (!ep->name().empty()) {
name = ep->name();
}
break;
}
}
if (!EmitType(func->return_type(), "")) { if (!EmitType(func->return_type(), "")) {
return false; return false;
} }
out_ << " " << func->name() << "("; out_ << " " << name << "(";
bool first = true; bool first = true;
for (const auto& v : func->params()) { for (const auto& v : func->params()) {

View File

@ -76,6 +76,9 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the scalar constructor expression /// @param expr the scalar constructor expression
/// @returns true if the scalar constructor is emitted /// @returns true if the scalar constructor is emitted
bool EmitScalarConstructor(ast::ScalarConstructorExpression* expr); bool EmitScalarConstructor(ast::ScalarConstructorExpression* expr);
/// Handles emitting a pipeline stage name
/// @param stage the stage to emit
void EmitStage(ast::PipelineStage stage);
/// Handles a brace-enclosed list of statements. /// Handles a brace-enclosed list of statements.
/// @param statements the statements to output /// @param statements the statements to output
/// @returns true if the statements were emitted /// @returns true if the statements were emitted
@ -97,6 +100,9 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the type constructor expression /// @param expr the type constructor expression
/// @returns true if the constructor is emitted /// @returns true if the constructor is emitted
bool EmitTypeConstructor(ast::TypeConstructorExpression* expr); bool EmitTypeConstructor(ast::TypeConstructorExpression* expr);
private:
const ast::Module* module_ = nullptr;
}; };
} // namespace msl } // namespace msl

View File

@ -14,6 +14,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/ast/function.h" #include "src/ast/function.h"
#include "src/ast/module.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/type/array_type.h" #include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
@ -32,19 +33,24 @@ using MslGeneratorImplTest = testing::Test;
TEST_F(MslGeneratorImplTest, Emit_Function) { TEST_F(MslGeneratorImplTest, Emit_Function) {
ast::type::VoidType void_type; ast::type::VoidType void_type;
ast::Function func("my_func", {}, &void_type); auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{},
&void_type);
ast::StatementList body; ast::StatementList body;
body.push_back(std::make_unique<ast::ReturnStatement>()); body.push_back(std::make_unique<ast::ReturnStatement>());
func.set_body(std::move(body)); func->set_body(std::move(body));
ast::Module m;
m.AddFunction(std::move(func));
GeneratorImpl g; GeneratorImpl g;
g.increment_indent(); g.increment_indent();
ASSERT_TRUE(g.EmitFunction(&func)); ASSERT_TRUE(g.Generate(m)) << g.error();
EXPECT_EQ(g.result(), R"( void my_func() { EXPECT_EQ(g.result(), R"( void my_func() {
return; return;
} }
)"); )");
} }
@ -59,19 +65,64 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) {
std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32)); std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
ast::type::VoidType void_type; ast::type::VoidType void_type;
ast::Function func("my_func", std::move(params), &void_type); auto func =
std::make_unique<ast::Function>("my_func", std::move(params), &void_type);
ast::StatementList body; ast::StatementList body;
body.push_back(std::make_unique<ast::ReturnStatement>()); body.push_back(std::make_unique<ast::ReturnStatement>());
func.set_body(std::move(body)); func->set_body(std::move(body));
ast::Module m;
m.AddFunction(std::move(func));
GeneratorImpl g; GeneratorImpl g;
g.increment_indent(); g.increment_indent();
ASSERT_TRUE(g.EmitFunction(&func)); ASSERT_TRUE(g.Generate(m)) << g.error();
EXPECT_EQ(g.result(), R"( void my_func(float a, int b) { EXPECT_EQ(g.result(), R"( void my_func(float a, int b) {
return; return;
} }
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_NoName) {
ast::type::VoidType void_type;
auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
&void_type);
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
"frag_main");
ast::Module m;
m.AddFunction(std::move(func));
m.AddEntryPoint(std::move(ep));
GeneratorImpl g;
ASSERT_TRUE(g.Generate(m)) << g.error();
EXPECT_EQ(g.result(), R"(fragment void frag_main() {
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) {
ast::type::VoidType void_type;
auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
&void_type);
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
"main", "comp_main");
ast::Module m;
m.AddFunction(std::move(func));
m.AddEntryPoint(std::move(ep));
GeneratorImpl g;
ASSERT_TRUE(g.Generate(m)) << g.error();
EXPECT_EQ(g.result(), R"(kernel void main() {
}
)"); )");
} }