[msl-writer] Generate entry point functions.
This CL generates entry point functions and duplicate functions as needed to call from the entry points. Bug: tint:8 Change-Id: I8092ce463248e7a887c26ae05b0774e8fa21ab94 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24764 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
c5a5f9666f
commit
df415a8919
|
@ -26,6 +26,15 @@ DecoratedVariable::DecoratedVariable(DecoratedVariable&&) = default;
|
||||||
|
|
||||||
DecoratedVariable::~DecoratedVariable() = default;
|
DecoratedVariable::~DecoratedVariable() = default;
|
||||||
|
|
||||||
|
bool DecoratedVariable::HasLocationDecoration() const {
|
||||||
|
for (const auto& deco : decorations_) {
|
||||||
|
if (deco->IsLocation()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool DecoratedVariable::IsDecorated() const {
|
bool DecoratedVariable::IsDecorated() const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,9 @@ class DecoratedVariable : public Variable {
|
||||||
/// @returns the decorations attached to this variable
|
/// @returns the decorations attached to this variable
|
||||||
const VariableDecorationList& decorations() const { return decorations_; }
|
const VariableDecorationList& decorations() const { return decorations_; }
|
||||||
|
|
||||||
|
/// @returns true if the decorations include a LocationDecoration
|
||||||
|
bool HasLocationDecoration() const;
|
||||||
|
|
||||||
/// @returns true if this is a decorated variable
|
/// @returns true if this is a decorated variable
|
||||||
bool IsDecorated() const override;
|
bool IsDecorated() const override;
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,14 @@ Function* Module::FindFunctionByName(const std::string& name) const {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Module::IsFunctionEntryPoint(const std::string& name) const {
|
||||||
|
for (const auto& ep : entry_points_) {
|
||||||
|
if (ep->function_name() == name)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool Module::IsValid() const {
|
bool Module::IsValid() const {
|
||||||
for (const auto& import : imports_) {
|
for (const auto& import : imports_) {
|
||||||
if (import == nullptr || !import->IsValid()) {
|
if (import == nullptr || !import->IsValid()) {
|
||||||
|
|
|
@ -65,6 +65,11 @@ class Module {
|
||||||
/// @returns the entry points in the module
|
/// @returns the entry points in the module
|
||||||
const EntryPointList& entry_points() const { return entry_points_; }
|
const EntryPointList& entry_points() const { return entry_points_; }
|
||||||
|
|
||||||
|
/// Checks if the given function name is an entry point function
|
||||||
|
/// @param name the function name
|
||||||
|
/// @returns true if name is an entry point function
|
||||||
|
bool IsFunctionEntryPoint(const std::string& name) const;
|
||||||
|
|
||||||
/// Adds a type alias to the module
|
/// Adds a type alias to the module
|
||||||
/// @param type the alias to add
|
/// @param type the alias to add
|
||||||
void AddAliasType(type::AliasType* type) { alias_types_.push_back(type); }
|
void AddAliasType(type::AliasType* type) { alias_types_.push_back(type); }
|
||||||
|
|
|
@ -91,6 +91,19 @@ TEST_F(ModuleTest, LookupFunction) {
|
||||||
EXPECT_EQ(func_ptr, m.FindFunctionByName("main"));
|
EXPECT_EQ(func_ptr, m.FindFunctionByName("main"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ModuleTest, IsEntryPoint) {
|
||||||
|
type::F32Type f32;
|
||||||
|
Module m;
|
||||||
|
|
||||||
|
auto func = std::make_unique<Function>("other_func", VariableList{}, &f32);
|
||||||
|
m.AddFunction(std::move(func));
|
||||||
|
|
||||||
|
m.AddEntryPoint(
|
||||||
|
std::make_unique<EntryPoint>(PipelineStage::kVertex, "main", "vtx_main"));
|
||||||
|
EXPECT_TRUE(m.IsFunctionEntryPoint("vtx_main"));
|
||||||
|
EXPECT_FALSE(m.IsFunctionEntryPoint("other_func"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ModuleTest, LookupFunctionMissing) {
|
TEST_F(ModuleTest, LookupFunctionMissing) {
|
||||||
Module m;
|
Module m;
|
||||||
EXPECT_EQ(nullptr, m.FindFunctionByName("Missing"));
|
EXPECT_EQ(nullptr, m.FindFunctionByName("Missing"));
|
||||||
|
|
|
@ -59,6 +59,8 @@ namespace {
|
||||||
|
|
||||||
const char kInStructNameSuffix[] = "in";
|
const char kInStructNameSuffix[] = "in";
|
||||||
const char kOutStructNameSuffix[] = "out";
|
const char kOutStructNameSuffix[] = "out";
|
||||||
|
const char kTintStructInVarPrefix[] = "tint_in";
|
||||||
|
const char kTintStructOutVarPrefix[] = "tint_out";
|
||||||
|
|
||||||
bool last_is_break_or_fallthrough(const ast::StatementList& stmts) {
|
bool last_is_break_or_fallthrough(const ast::StatementList& stmts) {
|
||||||
if (stmts.empty()) {
|
if (stmts.empty()) {
|
||||||
|
@ -78,13 +80,11 @@ void GeneratorImpl::set_module_for_testing(ast::Module* mod) {
|
||||||
module_ = mod;
|
module_ = mod;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GeneratorImpl::generate_struct_name(ast::EntryPoint* ep,
|
std::string GeneratorImpl::generate_name(const std::string& prefix) {
|
||||||
const std::string& type) {
|
std::string name = prefix;
|
||||||
std::string base_name = ep->function_name() + "_" + type;
|
|
||||||
std::string name = base_name;
|
|
||||||
uint32_t i = 0;
|
uint32_t i = 0;
|
||||||
while (namer_.IsMapped(name)) {
|
while (namer_.IsMapped(name)) {
|
||||||
name = base_name + "_" + std::to_string(i);
|
name = prefix + "_" + std::to_string(i);
|
||||||
++i;
|
++i;
|
||||||
}
|
}
|
||||||
namer_.RegisterRemappedName(name);
|
namer_.RegisterRemappedName(name);
|
||||||
|
@ -96,6 +96,10 @@ bool GeneratorImpl::Generate(const ast::Module& module) {
|
||||||
|
|
||||||
out_ << "#include <metal_stdlib>" << std::endl << std::endl;
|
out_ << "#include <metal_stdlib>" << std::endl << std::endl;
|
||||||
|
|
||||||
|
for (const auto& global : module.global_variables()) {
|
||||||
|
global_variables_.set(global->name(), global.get());
|
||||||
|
}
|
||||||
|
|
||||||
for (auto* const alias : module.alias_types()) {
|
for (auto* const alias : module.alias_types()) {
|
||||||
if (!EmitAliasType(alias)) {
|
if (!EmitAliasType(alias)) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -106,7 +110,7 @@ bool GeneratorImpl::Generate(const ast::Module& module) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& ep : module.entry_points()) {
|
for (const auto& ep : module.entry_points()) {
|
||||||
if (!EmitEntryPoint(ep.get())) {
|
if (!EmitEntryPointData(ep.get())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -115,6 +119,12 @@ bool GeneratorImpl::Generate(const ast::Module& module) {
|
||||||
if (!EmitFunction(func.get())) {
|
if (!EmitFunction(func.get())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& ep : module.entry_points()) {
|
||||||
|
if (!EmitEntryPointFunction(ep.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
out_ << std::endl;
|
out_ << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,12 +293,32 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ident->has_path()) {
|
if (!ident->has_path()) {
|
||||||
if (!EmitExpression(expr->func())) {
|
auto name = ident->name();
|
||||||
return false;
|
auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
|
||||||
|
if (it != ep_func_name_remapped_.end()) {
|
||||||
|
name = it->second;
|
||||||
}
|
}
|
||||||
out_ << "(";
|
out_ << name << "(";
|
||||||
|
|
||||||
bool first = true;
|
bool first = true;
|
||||||
|
|
||||||
|
auto in_it = ep_name_to_in_data_.find(current_ep_name_);
|
||||||
|
if (in_it != ep_name_to_in_data_.end()) {
|
||||||
|
out_ << in_it->second.var_name;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out_it = ep_name_to_out_data_.find(current_ep_name_);
|
||||||
|
if (out_it != ep_name_to_out_data_.end()) {
|
||||||
|
if (!first) {
|
||||||
|
out_ << ", ";
|
||||||
|
}
|
||||||
|
out_ << out_it->second.var_name;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(dsinclair): Emit builtins
|
||||||
|
|
||||||
const auto& params = expr->params();
|
const auto& params = expr->params();
|
||||||
for (const auto& param : params) {
|
for (const auto& param : params) {
|
||||||
if (!first) {
|
if (!first) {
|
||||||
|
@ -459,7 +489,7 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) {
|
bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
|
||||||
auto* func = module_->FindFunctionByName(ep->function_name());
|
auto* func = module_->FindFunctionByName(ep->function_name());
|
||||||
if (func == nullptr) {
|
if (func == nullptr) {
|
||||||
error_ = "Unable to find entry point function: " + ep->function_name();
|
error_ = "Unable to find entry point function: " + ep->function_name();
|
||||||
|
@ -491,9 +521,20 @@ bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto ep_name = ep->name();
|
||||||
|
if (ep_name.empty()) {
|
||||||
|
ep_name = ep->function_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(dsinclair): There is a potential bug here. Entry points can have the
|
||||||
|
// same name in WGSL if they have different pipeline stages. This does not
|
||||||
|
// take that into account and will emit duplicate struct names. I'm ignoring
|
||||||
|
// this until https://github.com/gpuweb/gpuweb/issues/662 is resolved as it
|
||||||
|
// may remove this issue and entry point names will need to be unique.
|
||||||
if (!in_locations.empty()) {
|
if (!in_locations.empty()) {
|
||||||
auto in_struct_name = generate_struct_name(ep, kInStructNameSuffix);
|
auto in_struct_name = generate_name(ep_name + "_" + kInStructNameSuffix);
|
||||||
ep_name_to_in_struct_[ep->name()] = in_struct_name;
|
auto in_var_name = generate_name(kTintStructInVarPrefix);
|
||||||
|
ep_name_to_in_data_[ep_name] = {in_struct_name, in_var_name};
|
||||||
|
|
||||||
make_indent();
|
make_indent();
|
||||||
out_ << "struct " << in_struct_name << " {" << std::endl;
|
out_ << "struct " << in_struct_name << " {" << std::endl;
|
||||||
|
@ -527,8 +568,9 @@ bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!out_locations.empty()) {
|
if (!out_locations.empty()) {
|
||||||
auto out_struct_name = generate_struct_name(ep, kOutStructNameSuffix);
|
auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix);
|
||||||
ep_name_to_out_struct_[ep->name()] = out_struct_name;
|
auto out_var_name = generate_name(kTintStructOutVarPrefix);
|
||||||
|
ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name};
|
||||||
|
|
||||||
make_indent();
|
make_indent();
|
||||||
out_ << "struct " << out_struct_name << " {" << std::endl;
|
out_ << "struct " << out_struct_name << " {" << std::endl;
|
||||||
|
@ -615,33 +657,82 @@ void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
|
||||||
bool GeneratorImpl::EmitFunction(ast::Function* func) {
|
bool GeneratorImpl::EmitFunction(ast::Function* func) {
|
||||||
make_indent();
|
make_indent();
|
||||||
|
|
||||||
// TODO(dsinclair): Technically this is wrong as you could, in theory, have
|
// Entry points will be emitted later, skip for now.
|
||||||
// multiple entry points pointing at the same function. I'm ignoring that for
|
if (module_->IsFunctionEntryPoint(func->name())) {
|
||||||
// now. It will either go away with the entry_point changes in the spec
|
return true;
|
||||||
// or we'll have to figure out how to deal with it.
|
|
||||||
|
|
||||||
auto name = func->name();
|
|
||||||
|
|
||||||
for (const auto& ep : module_->entry_points()) {
|
|
||||||
if (ep->function_name() == name) {
|
|
||||||
EmitStage(ep->stage());
|
|
||||||
out_ << " ";
|
|
||||||
|
|
||||||
if (!ep->name().empty()) {
|
|
||||||
name = ep->name();
|
|
||||||
}
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(dsinclair): This could be smarter. If the input/outputs for multiple
|
||||||
|
// entry points are the same we could generate a single struct and then have
|
||||||
|
// this determine it's the same struct and just emit once.
|
||||||
|
bool emit_duplicate_functions =
|
||||||
|
func->ancestor_entry_points().size() > 0 &&
|
||||||
|
func->referenced_module_variables().size() > 0;
|
||||||
|
|
||||||
|
if (emit_duplicate_functions) {
|
||||||
|
for (const auto& ep_name : func->ancestor_entry_points()) {
|
||||||
|
if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out_ << std::endl;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Emit as non-duplicated
|
||||||
|
if (!EmitFunctionInternal(func, false, "")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out_ << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
|
||||||
|
bool emit_duplicate_functions,
|
||||||
|
const std::string& ep_name) {
|
||||||
|
auto name = func->name();
|
||||||
|
|
||||||
if (!EmitType(func->return_type(), "")) {
|
if (!EmitType(func->return_type(), "")) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
out_ << " " << namer_.NameFor(name) << "(";
|
out_ << " ";
|
||||||
|
|
||||||
|
if (emit_duplicate_functions) {
|
||||||
|
name = generate_name(name + "_" + ep_name);
|
||||||
|
ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
|
||||||
|
} else {
|
||||||
|
name = namer_.NameFor(name);
|
||||||
|
}
|
||||||
|
out_ << name << "(";
|
||||||
|
|
||||||
bool first = true;
|
bool first = true;
|
||||||
|
|
||||||
|
// If we're emitting duplicate functions that means the function takes
|
||||||
|
// the stage_in or stage_out value from the entry point, emit them.
|
||||||
|
//
|
||||||
|
// We emit both of them if they're there regardless of if they're both used.
|
||||||
|
if (emit_duplicate_functions) {
|
||||||
|
auto in_it = ep_name_to_in_data_.find(ep_name);
|
||||||
|
if (in_it != ep_name_to_in_data_.end()) {
|
||||||
|
out_ << "thread " << in_it->second.struct_name << "& "
|
||||||
|
<< in_it->second.var_name;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out_it = ep_name_to_out_data_.find(ep_name);
|
||||||
|
if (out_it != ep_name_to_out_data_.end()) {
|
||||||
|
if (!first) {
|
||||||
|
out_ << ", ";
|
||||||
|
}
|
||||||
|
out_ << "thread " << out_it->second.struct_name << "& "
|
||||||
|
<< out_it->second.var_name;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(dsinclair): Handle any entry point builtin params used here
|
||||||
|
|
||||||
for (const auto& v : func->params()) {
|
for (const auto& v : func->params()) {
|
||||||
if (!first) {
|
if (!first) {
|
||||||
out_ << ", ";
|
out_ << ", ";
|
||||||
|
@ -656,9 +747,79 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
|
||||||
out_ << " " << v->name();
|
out_ << " " << v->name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out_ << ")";
|
out_ << ")";
|
||||||
|
|
||||||
return EmitStatementBlockAndNewline(func->body());
|
current_ep_name_ = ep_name;
|
||||||
|
|
||||||
|
if (!EmitStatementBlockAndNewline(func->body())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
current_ep_name_ = "";
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
|
||||||
|
make_indent();
|
||||||
|
|
||||||
|
current_ep_name_ = ep->name();
|
||||||
|
if (current_ep_name_.empty()) {
|
||||||
|
current_ep_name_ = ep->function_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* func = module_->FindFunctionByName(ep->function_name());
|
||||||
|
if (func == nullptr) {
|
||||||
|
error_ = "unable to find function for entry point: " + ep->function_name();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
EmitStage(ep->stage());
|
||||||
|
out_ << " ";
|
||||||
|
|
||||||
|
// This is an entry point, the return type is the entry point output structure
|
||||||
|
// if one exists, or void otherwise.
|
||||||
|
auto out_data = ep_name_to_out_data_.find(current_ep_name_);
|
||||||
|
bool has_out_data = out_data != ep_name_to_out_data_.end();
|
||||||
|
if (has_out_data) {
|
||||||
|
out_ << out_data->second.struct_name;
|
||||||
|
} else {
|
||||||
|
out_ << "void";
|
||||||
|
}
|
||||||
|
out_ << " " << namer_.NameFor(current_ep_name_) << "(";
|
||||||
|
|
||||||
|
auto in_data = ep_name_to_in_data_.find(current_ep_name_);
|
||||||
|
if (in_data != ep_name_to_in_data_.end()) {
|
||||||
|
out_ << in_data->second.struct_name << " " << in_data->second.var_name
|
||||||
|
<< " [[stage_in]]";
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(dsinclair): Output other builtin inputs
|
||||||
|
out_ << ") {" << std::endl;
|
||||||
|
|
||||||
|
increment_indent();
|
||||||
|
|
||||||
|
if (has_out_data) {
|
||||||
|
make_indent();
|
||||||
|
out_ << out_data->second.struct_name << " " << out_data->second.var_name
|
||||||
|
<< " = {};" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
generating_entry_point_ = true;
|
||||||
|
for (const auto& s : func->body()) {
|
||||||
|
if (!EmitStatement(s.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
generating_entry_point_ = false;
|
||||||
|
|
||||||
|
decrement_indent();
|
||||||
|
make_indent();
|
||||||
|
out_ << "}" << std::endl;
|
||||||
|
|
||||||
|
current_ep_name_ = "";
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
|
bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
|
||||||
|
@ -668,7 +829,30 @@ bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
|
||||||
error_ = "Identifier paths not handled yet.";
|
error_ = "Identifier paths not handled yet.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ast::Variable* var = nullptr;
|
||||||
|
if (global_variables_.get(ident->name(), &var)) {
|
||||||
|
if (var->storage_class() == ast::StorageClass::kInput &&
|
||||||
|
var->IsDecorated() && var->AsDecorated()->HasLocationDecoration()) {
|
||||||
|
auto it = ep_name_to_in_data_.find(current_ep_name_);
|
||||||
|
if (it == ep_name_to_in_data_.end()) {
|
||||||
|
error_ = "unable to find entry point data for input";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out_ << it->second.var_name << ".";
|
||||||
|
} else if (var->storage_class() == ast::StorageClass::kOutput &&
|
||||||
|
var->IsDecorated() &&
|
||||||
|
var->AsDecorated()->HasLocationDecoration()) {
|
||||||
|
auto it = ep_name_to_out_data_.find(current_ep_name_);
|
||||||
|
if (it == ep_name_to_out_data_.end()) {
|
||||||
|
error_ = "unable to find entry point data for output";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out_ << it->second.var_name << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
out_ << namer_.NameFor(ident->name());
|
out_ << namer_.NameFor(ident->name());
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -785,7 +969,13 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
|
||||||
make_indent();
|
make_indent();
|
||||||
|
|
||||||
out_ << "return";
|
out_ << "return";
|
||||||
if (stmt->has_value()) {
|
|
||||||
|
if (generating_entry_point_) {
|
||||||
|
auto out_data = ep_name_to_out_data_.find(current_ep_name_);
|
||||||
|
if (out_data != ep_name_to_out_data_.end()) {
|
||||||
|
out_ << " " << out_data->second.var_name;
|
||||||
|
}
|
||||||
|
} else if (stmt->has_value()) {
|
||||||
out_ << " ";
|
out_ << " ";
|
||||||
if (!EmitExpression(stmt->value())) {
|
if (!EmitExpression(stmt->value())) {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "src/ast/module.h"
|
#include "src/ast/module.h"
|
||||||
#include "src/ast/scalar_constructor_expression.h"
|
#include "src/ast/scalar_constructor_expression.h"
|
||||||
#include "src/ast/type_constructor_expression.h"
|
#include "src/ast/type_constructor_expression.h"
|
||||||
|
#include "src/scope_stack.h"
|
||||||
#include "src/writer/msl/namer.h"
|
#include "src/writer/msl/namer.h"
|
||||||
#include "src/writer/text_generator.h"
|
#include "src/writer/text_generator.h"
|
||||||
|
|
||||||
|
@ -93,7 +94,11 @@ class GeneratorImpl : public TextGenerator {
|
||||||
/// Handles emitting information for an entry point
|
/// Handles emitting information for an entry point
|
||||||
/// @param ep the entry point
|
/// @param ep the entry point
|
||||||
/// @returns true if the entry point data was emitted
|
/// @returns true if the entry point data was emitted
|
||||||
bool EmitEntryPoint(ast::EntryPoint* ep);
|
bool EmitEntryPointData(ast::EntryPoint* ep);
|
||||||
|
/// Handles emitting the entry point function
|
||||||
|
/// @param ep the entry point
|
||||||
|
/// @returns true if the entry point function was emitted
|
||||||
|
bool EmitEntryPointFunction(ast::EntryPoint* ep);
|
||||||
/// Handles generate an Expression
|
/// Handles generate an Expression
|
||||||
/// @param expr the expression
|
/// @param expr the expression
|
||||||
/// @returns true if the expression was emitted
|
/// @returns true if the expression was emitted
|
||||||
|
@ -102,6 +107,15 @@ class GeneratorImpl : public TextGenerator {
|
||||||
/// @param func the function to generate
|
/// @param func the function to generate
|
||||||
/// @returns true if the function was emitted
|
/// @returns true if the function was emitted
|
||||||
bool EmitFunction(ast::Function* func);
|
bool EmitFunction(ast::Function* func);
|
||||||
|
/// Internal helper for emitting functions
|
||||||
|
/// @param func the function to emit
|
||||||
|
/// @param emit_duplicate_functions set true if we need to duplicate per entry
|
||||||
|
/// point
|
||||||
|
/// @param ep_name the current entry point or blank if none set
|
||||||
|
/// @returns true if the function was emitted.
|
||||||
|
bool EmitFunctionInternal(ast::Function* func,
|
||||||
|
bool emit_duplicate_functions,
|
||||||
|
const std::string& ep_name);
|
||||||
/// Handles generating an identifier expression
|
/// Handles generating an identifier expression
|
||||||
/// @param expr the identifier expression
|
/// @param expr the identifier expression
|
||||||
/// @returns true if the identifeir was emitted
|
/// @returns true if the identifeir was emitted
|
||||||
|
@ -179,22 +193,33 @@ class GeneratorImpl : public TextGenerator {
|
||||||
/// @param mod the module to set.
|
/// @param mod the module to set.
|
||||||
void set_module_for_testing(ast::Module* mod);
|
void set_module_for_testing(ast::Module* mod);
|
||||||
|
|
||||||
/// Generates a name for the input struct
|
/// Generates a name for the prefix
|
||||||
/// @param ep the entry point to generate for
|
/// @param prefix the prefix of the name to generate
|
||||||
/// @param type the type of struct to generate
|
/// @returns the name
|
||||||
/// @returns the input struct name
|
std::string generate_name(const std::string& prefix);
|
||||||
std::string generate_struct_name(ast::EntryPoint* ep,
|
|
||||||
const std::string& type);
|
|
||||||
|
|
||||||
/// @returns the namer for testing
|
/// @returns the namer for testing
|
||||||
Namer* namer_for_testing() { return &namer_; }
|
Namer* namer_for_testing() { return &namer_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Namer namer_;
|
Namer namer_;
|
||||||
|
ScopeStack<ast::Variable*> global_variables_;
|
||||||
|
std::string current_ep_name_;
|
||||||
|
bool generating_entry_point_ = false;
|
||||||
const ast::Module* module_ = nullptr;
|
const ast::Module* module_ = nullptr;
|
||||||
uint32_t loop_emission_counter_ = 0;
|
uint32_t loop_emission_counter_ = 0;
|
||||||
std::unordered_map<std::string, std::string> ep_name_to_in_struct_;
|
|
||||||
std::unordered_map<std::string, std::string> ep_name_to_out_struct_;
|
struct EntryPointData {
|
||||||
|
std::string struct_name;
|
||||||
|
std::string var_name;
|
||||||
|
};
|
||||||
|
std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_;
|
||||||
|
std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_;
|
||||||
|
|
||||||
|
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
|
||||||
|
// function name. If there is no entry for a given key then function did
|
||||||
|
// not need to be remapped for the entry point and can be emitted directly.
|
||||||
|
std::unordered_map<std::string, std::string> ep_func_name_remapped_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace msl
|
} // namespace msl
|
||||||
|
|
|
@ -33,7 +33,7 @@ namespace {
|
||||||
|
|
||||||
using MslGeneratorImplTest = testing::Test;
|
using MslGeneratorImplTest = testing::Test;
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) {
|
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Input) {
|
||||||
// [[location 0]] var<in> foo : f32;
|
// [[location 0]] var<in> foo : f32;
|
||||||
// [[location 1]] var<in> bar : i32;
|
// [[location 1]] var<in> bar : i32;
|
||||||
//
|
//
|
||||||
|
@ -81,8 +81,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) {
|
||||||
|
|
||||||
mod.AddFunction(std::move(func));
|
mod.AddFunction(std::move(func));
|
||||||
|
|
||||||
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
|
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex, "",
|
||||||
"main", "vtx_main");
|
"vtx_main");
|
||||||
auto* ep_ptr = ep.get();
|
auto* ep_ptr = ep.get();
|
||||||
|
|
||||||
mod.AddEntryPoint(std::move(ep));
|
mod.AddEntryPoint(std::move(ep));
|
||||||
|
@ -91,7 +91,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) {
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
g.set_module_for_testing(&mod);
|
g.set_module_for_testing(&mod);
|
||||||
ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
|
ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
|
||||||
EXPECT_EQ(g.result(), R"(struct vtx_main_in {
|
EXPECT_EQ(g.result(), R"(struct vtx_main_in {
|
||||||
float foo [[attribute(0)]];
|
float foo [[attribute(0)]];
|
||||||
int bar [[attribute(1)]];
|
int bar [[attribute(1)]];
|
||||||
|
@ -100,7 +100,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) {
|
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Output) {
|
||||||
// [[location 0]] var<out> foo : f32;
|
// [[location 0]] var<out> foo : f32;
|
||||||
// [[location 1]] var<out> bar : i32;
|
// [[location 1]] var<out> bar : i32;
|
||||||
//
|
//
|
||||||
|
@ -148,8 +148,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) {
|
||||||
|
|
||||||
mod.AddFunction(std::move(func));
|
mod.AddFunction(std::move(func));
|
||||||
|
|
||||||
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
|
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex, "",
|
||||||
"main", "vtx_main");
|
"vtx_main");
|
||||||
auto* ep_ptr = ep.get();
|
auto* ep_ptr = ep.get();
|
||||||
|
|
||||||
mod.AddEntryPoint(std::move(ep));
|
mod.AddEntryPoint(std::move(ep));
|
||||||
|
@ -158,7 +158,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) {
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
g.set_module_for_testing(&mod);
|
g.set_module_for_testing(&mod);
|
||||||
ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
|
ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
|
||||||
EXPECT_EQ(g.result(), R"(struct vtx_main_out {
|
EXPECT_EQ(g.result(), R"(struct vtx_main_out {
|
||||||
float foo [[user(locn0)]];
|
float foo [[user(locn0)]];
|
||||||
int bar [[user(locn1)]];
|
int bar [[user(locn1)]];
|
||||||
|
@ -167,7 +167,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) {
|
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Input) {
|
||||||
// [[location 0]] var<in> foo : f32;
|
// [[location 0]] var<in> foo : f32;
|
||||||
// [[location 1]] var<in> bar : i32;
|
// [[location 1]] var<in> bar : i32;
|
||||||
//
|
//
|
||||||
|
@ -225,8 +225,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) {
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
g.set_module_for_testing(&mod);
|
g.set_module_for_testing(&mod);
|
||||||
ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
|
ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
|
||||||
EXPECT_EQ(g.result(), R"(struct frag_main_in {
|
EXPECT_EQ(g.result(), R"(struct main_in {
|
||||||
float foo [[user(locn0)]];
|
float foo [[user(locn0)]];
|
||||||
int bar [[user(locn1)]];
|
int bar [[user(locn1)]];
|
||||||
};
|
};
|
||||||
|
@ -234,7 +234,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) {
|
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Output) {
|
||||||
// [[location 0]] var<out> foo : f32;
|
// [[location 0]] var<out> foo : f32;
|
||||||
// [[location 1]] var<out> bar : i32;
|
// [[location 1]] var<out> bar : i32;
|
||||||
//
|
//
|
||||||
|
@ -292,8 +292,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) {
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
g.set_module_for_testing(&mod);
|
g.set_module_for_testing(&mod);
|
||||||
ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
|
ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
|
||||||
EXPECT_EQ(g.result(), R"(struct frag_main_out {
|
EXPECT_EQ(g.result(), R"(struct main_out {
|
||||||
float foo [[color(0)]];
|
float foo [[color(0)]];
|
||||||
int bar [[color(1)]];
|
int bar [[color(1)]];
|
||||||
};
|
};
|
||||||
|
@ -301,7 +301,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Input) {
|
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Input) {
|
||||||
// [[location 0]] var<in> foo : f32;
|
// [[location 0]] var<in> foo : f32;
|
||||||
// [[location 1]] var<in> bar : i32;
|
// [[location 1]] var<in> bar : i32;
|
||||||
//
|
//
|
||||||
|
@ -356,11 +356,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Input) {
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
g.set_module_for_testing(&mod);
|
g.set_module_for_testing(&mod);
|
||||||
ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error();
|
ASSERT_FALSE(g.EmitEntryPointData(ep_ptr)) << g.error();
|
||||||
EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
|
EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Output) {
|
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Output) {
|
||||||
// [[location 0]] var<out> foo : f32;
|
// [[location 0]] var<out> foo : f32;
|
||||||
// [[location 1]] var<out> bar : i32;
|
// [[location 1]] var<out> bar : i32;
|
||||||
//
|
//
|
||||||
|
@ -415,7 +415,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Output) {
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
g.set_module_for_testing(&mod);
|
g.set_module_for_testing(&mod);
|
||||||
ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error();
|
ASSERT_FALSE(g.EmitEntryPointData(ep_ptr)) << g.error();
|
||||||
EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
|
EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,14 +13,27 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "src/ast/assignment_statement.h"
|
||||||
|
#include "src/ast/binary_expression.h"
|
||||||
|
#include "src/ast/call_expression.h"
|
||||||
|
#include "src/ast/decorated_variable.h"
|
||||||
|
#include "src/ast/float_literal.h"
|
||||||
#include "src/ast/function.h"
|
#include "src/ast/function.h"
|
||||||
|
#include "src/ast/identifier_expression.h"
|
||||||
|
#include "src/ast/if_statement.h"
|
||||||
|
#include "src/ast/location_decoration.h"
|
||||||
#include "src/ast/module.h"
|
#include "src/ast/module.h"
|
||||||
#include "src/ast/return_statement.h"
|
#include "src/ast/return_statement.h"
|
||||||
|
#include "src/ast/scalar_constructor_expression.h"
|
||||||
|
#include "src/ast/sint_literal.h"
|
||||||
#include "src/ast/type/array_type.h"
|
#include "src/ast/type/array_type.h"
|
||||||
#include "src/ast/type/f32_type.h"
|
#include "src/ast/type/f32_type.h"
|
||||||
#include "src/ast/type/i32_type.h"
|
#include "src/ast/type/i32_type.h"
|
||||||
#include "src/ast/type/void_type.h"
|
#include "src/ast/type/void_type.h"
|
||||||
#include "src/ast/variable.h"
|
#include "src/ast/variable.h"
|
||||||
|
#include "src/ast/variable_decl_statement.h"
|
||||||
|
#include "src/context.h"
|
||||||
|
#include "src/type_determiner.h"
|
||||||
#include "src/writer/msl/generator_impl.h"
|
#include "src/writer/msl/generator_impl.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
@ -138,6 +151,415 @@ fragment void frag_main() {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) {
|
||||||
|
ast::type::VoidType void_type;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
auto foo_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
|
||||||
|
|
||||||
|
ast::VariableDecorationList decos;
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(0));
|
||||||
|
foo_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
auto bar_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(1));
|
||||||
|
bar_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
td.RegisterVariableForTesting(foo_var.get());
|
||||||
|
td.RegisterVariableForTesting(bar_var.get());
|
||||||
|
|
||||||
|
mod.AddGlobalVariable(std::move(foo_var));
|
||||||
|
mod.AddGlobalVariable(std::move(bar_var));
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
|
||||||
|
&void_type);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bar"),
|
||||||
|
std::make_unique<ast::IdentifierExpression>("foo")));
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>());
|
||||||
|
func->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(func));
|
||||||
|
|
||||||
|
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
|
||||||
|
"frag_main");
|
||||||
|
mod.AddEntryPoint(std::move(ep));
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
|
|
||||||
|
GeneratorImpl g;
|
||||||
|
ASSERT_TRUE(g.Generate(mod)) << g.error();
|
||||||
|
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
|
||||||
|
|
||||||
|
struct frag_main_in {
|
||||||
|
float foo [[user(locn0)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct frag_main_out {
|
||||||
|
float bar [[color(1)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) {
|
||||||
|
frag_main_out tint_out = {};
|
||||||
|
tint_out.bar = tint_in.foo;
|
||||||
|
return tint_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MslGeneratorImplTest,
|
||||||
|
Emit_Function_Called_By_EntryPoints_WithGlobals_And_Params) {
|
||||||
|
ast::type::VoidType void_type;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
auto foo_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
|
||||||
|
|
||||||
|
ast::VariableDecorationList decos;
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(0));
|
||||||
|
foo_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
auto bar_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(1));
|
||||||
|
bar_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
auto val_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("val", ast::StorageClass::kOutput, &f32));
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(0));
|
||||||
|
val_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
td.RegisterVariableForTesting(foo_var.get());
|
||||||
|
td.RegisterVariableForTesting(bar_var.get());
|
||||||
|
td.RegisterVariableForTesting(val_var.get());
|
||||||
|
|
||||||
|
mod.AddGlobalVariable(std::move(foo_var));
|
||||||
|
mod.AddGlobalVariable(std::move(bar_var));
|
||||||
|
mod.AddGlobalVariable(std::move(val_var));
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
params.push_back(std::make_unique<ast::Variable>(
|
||||||
|
"param", ast::StorageClass::kFunction, &f32));
|
||||||
|
auto sub_func =
|
||||||
|
std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bar"),
|
||||||
|
std::make_unique<ast::IdentifierExpression>("foo")));
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("val"),
|
||||||
|
std::make_unique<ast::IdentifierExpression>("param")));
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("foo")));
|
||||||
|
sub_func->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(sub_func));
|
||||||
|
|
||||||
|
auto func_1 = std::make_unique<ast::Function>("frag_1_main",
|
||||||
|
std::move(params), &void_type);
|
||||||
|
|
||||||
|
ast::ExpressionList expr;
|
||||||
|
expr.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bar"),
|
||||||
|
std::make_unique<ast::CallExpression>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("sub_func"),
|
||||||
|
std::move(expr))));
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>());
|
||||||
|
func_1->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(func_1));
|
||||||
|
|
||||||
|
auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
|
||||||
|
"ep_1", "frag_1_main");
|
||||||
|
mod.AddEntryPoint(std::move(ep1));
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
|
|
||||||
|
GeneratorImpl g;
|
||||||
|
ASSERT_TRUE(g.Generate(mod)) << g.error();
|
||||||
|
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
|
||||||
|
|
||||||
|
struct ep_1_in {
|
||||||
|
float foo [[user(locn0)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ep_1_out {
|
||||||
|
float bar [[color(1)]];
|
||||||
|
float val [[color(0)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out, float param) {
|
||||||
|
tint_out.bar = tint_in.foo;
|
||||||
|
tint_out.val = param;
|
||||||
|
return tint_in.foo;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) {
|
||||||
|
ep_1_out tint_out = {};
|
||||||
|
tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.00000000f);
|
||||||
|
return tint_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) {
|
||||||
|
ast::type::VoidType void_type;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
auto foo_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
|
||||||
|
|
||||||
|
ast::VariableDecorationList decos;
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(0));
|
||||||
|
foo_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
auto bar_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(1));
|
||||||
|
bar_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
td.RegisterVariableForTesting(foo_var.get());
|
||||||
|
td.RegisterVariableForTesting(bar_var.get());
|
||||||
|
|
||||||
|
mod.AddGlobalVariable(std::move(foo_var));
|
||||||
|
mod.AddGlobalVariable(std::move(bar_var));
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
auto sub_func =
|
||||||
|
std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bar"),
|
||||||
|
std::make_unique<ast::IdentifierExpression>("foo")));
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("foo")));
|
||||||
|
sub_func->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(sub_func));
|
||||||
|
|
||||||
|
auto func_1 = std::make_unique<ast::Function>("frag_1_main",
|
||||||
|
std::move(params), &void_type);
|
||||||
|
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bar"),
|
||||||
|
std::make_unique<ast::CallExpression>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("sub_func"),
|
||||||
|
ast::ExpressionList{})));
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>());
|
||||||
|
func_1->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(func_1));
|
||||||
|
|
||||||
|
auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
|
||||||
|
"ep_1", "frag_1_main");
|
||||||
|
auto ep2 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
|
||||||
|
"ep_2", "frag_1_main");
|
||||||
|
mod.AddEntryPoint(std::move(ep1));
|
||||||
|
mod.AddEntryPoint(std::move(ep2));
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
|
|
||||||
|
GeneratorImpl g;
|
||||||
|
ASSERT_TRUE(g.Generate(mod)) << g.error();
|
||||||
|
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
|
||||||
|
|
||||||
|
struct ep_1_in {
|
||||||
|
float foo [[user(locn0)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ep_1_out {
|
||||||
|
float bar [[color(1)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ep_2_in {
|
||||||
|
float foo [[user(locn0)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ep_2_out {
|
||||||
|
float bar [[color(1)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out) {
|
||||||
|
tint_out.bar = tint_in.foo;
|
||||||
|
return tint_in.foo;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sub_func_ep_2(thread ep_2_in& tint_in, thread ep_2_out& tint_out) {
|
||||||
|
tint_out.bar = tint_in.foo;
|
||||||
|
return tint_in.foo;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) {
|
||||||
|
ep_1_out tint_out = {};
|
||||||
|
tint_out.bar = sub_func_ep_1(tint_in, tint_out);
|
||||||
|
return tint_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment ep_2_out ep_2(ep_2_in tint_in [[stage_in]]) {
|
||||||
|
ep_2_out tint_out = {};
|
||||||
|
tint_out.bar = sub_func_ep_2(tint_in, tint_out);
|
||||||
|
return tint_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MslGeneratorImplTest,
|
||||||
|
Emit_Function_EntryPoints_WithGlobal_Nested_Return) {
|
||||||
|
ast::type::VoidType void_type;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
|
auto bar_var = std::make_unique<ast::DecoratedVariable>(
|
||||||
|
std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
|
||||||
|
ast::VariableDecorationList decos;
|
||||||
|
decos.push_back(std::make_unique<ast::LocationDecoration>(1));
|
||||||
|
bar_var->set_decorations(std::move(decos));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
td.RegisterVariableForTesting(bar_var.get());
|
||||||
|
mod.AddGlobalVariable(std::move(bar_var));
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
auto func_1 = std::make_unique<ast::Function>("frag_1_main",
|
||||||
|
std::move(params), &void_type);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bar"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::FloatLiteral>(&f32, 1.0f))));
|
||||||
|
|
||||||
|
ast::StatementList list;
|
||||||
|
list.push_back(std::make_unique<ast::ReturnStatement>());
|
||||||
|
body.push_back(std::make_unique<ast::IfStatement>(
|
||||||
|
std::make_unique<ast::BinaryExpression>(
|
||||||
|
ast::BinaryOp::kEqual,
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::SintLiteral>(&i32, 1)),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::SintLiteral>(&i32, 1))),
|
||||||
|
std::move(list)));
|
||||||
|
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>());
|
||||||
|
func_1->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(func_1));
|
||||||
|
|
||||||
|
auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
|
||||||
|
"ep_1", "frag_1_main");
|
||||||
|
mod.AddEntryPoint(std::move(ep1));
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
|
|
||||||
|
GeneratorImpl g;
|
||||||
|
ASSERT_TRUE(g.Generate(mod)) << g.error();
|
||||||
|
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
|
||||||
|
|
||||||
|
struct ep_1_out {
|
||||||
|
float bar [[color(1)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
fragment ep_1_out ep_1() {
|
||||||
|
ep_1_out tint_out = {};
|
||||||
|
tint_out.bar = 1.00000000f;
|
||||||
|
if ((1 == 1)) {
|
||||||
|
return tint_out;
|
||||||
|
}
|
||||||
|
return tint_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MslGeneratorImplTest,
|
||||||
|
Emit_Function_Called_Two_EntryPoints_WithoutGlobals) {
|
||||||
|
ast::type::VoidType void_type;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
auto sub_func =
|
||||||
|
std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>(
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::FloatLiteral>(&f32, 1.0))));
|
||||||
|
sub_func->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(sub_func));
|
||||||
|
|
||||||
|
auto func_1 = std::make_unique<ast::Function>("frag_1_main",
|
||||||
|
std::move(params), &void_type);
|
||||||
|
|
||||||
|
body.push_back(std::make_unique<ast::VariableDeclStatement>(
|
||||||
|
std::make_unique<ast::Variable>("foo", ast::StorageClass::kFunction,
|
||||||
|
&f32)));
|
||||||
|
body.back()->AsVariableDecl()->variable()->set_constructor(
|
||||||
|
std::make_unique<ast::CallExpression>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("sub_func"),
|
||||||
|
ast::ExpressionList{}));
|
||||||
|
|
||||||
|
body.push_back(std::make_unique<ast::ReturnStatement>());
|
||||||
|
func_1->set_body(std::move(body));
|
||||||
|
|
||||||
|
mod.AddFunction(std::move(func_1));
|
||||||
|
|
||||||
|
auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
|
||||||
|
"ep_1", "frag_1_main");
|
||||||
|
auto ep2 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
|
||||||
|
"ep_2", "frag_1_main");
|
||||||
|
mod.AddEntryPoint(std::move(ep1));
|
||||||
|
mod.AddEntryPoint(std::move(ep2));
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.Determine()) << td.error();
|
||||||
|
|
||||||
|
GeneratorImpl g;
|
||||||
|
ASSERT_TRUE(g.Generate(mod)) << g.error();
|
||||||
|
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
|
||||||
|
|
||||||
|
float sub_func() {
|
||||||
|
return 1.00000000f;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment void ep_1() {
|
||||||
|
float foo = sub_func();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment void ep_2() {
|
||||||
|
float foo = sub_func();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
)");
|
||||||
|
}
|
||||||
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) {
|
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) {
|
||||||
ast::type::VoidType void_type;
|
ast::type::VoidType void_type;
|
||||||
|
|
||||||
|
|
|
@ -51,29 +51,23 @@ compute void my_func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, InputStructName) {
|
TEST_F(MslGeneratorImplTest, InputStructName) {
|
||||||
ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
|
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in");
|
ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, InputStructName_ConflictWithExisting) {
|
TEST_F(MslGeneratorImplTest, InputStructName_ConflictWithExisting) {
|
||||||
ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
|
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
|
|
||||||
// Register the struct name as existing.
|
// Register the struct name as existing.
|
||||||
auto* namer = g.namer_for_testing();
|
auto* namer = g.namer_for_testing();
|
||||||
namer->NameFor("func_main_out");
|
namer->NameFor("func_main_out");
|
||||||
|
|
||||||
ASSERT_EQ(g.generate_struct_name(&ep, "out"), "func_main_out_0");
|
ASSERT_EQ(g.generate_name("func_main_out"), "func_main_out_0");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) {
|
TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) {
|
||||||
ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
|
|
||||||
|
|
||||||
GeneratorImpl g;
|
GeneratorImpl g;
|
||||||
ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in");
|
ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
|
||||||
|
|
||||||
ast::IdentifierExpression ident("func_main_in");
|
ast::IdentifierExpression ident("func_main_in");
|
||||||
ASSERT_TRUE(g.EmitIdentifier(&ident));
|
ASSERT_TRUE(g.EmitIdentifier(&ident));
|
||||||
|
|
|
@ -28,9 +28,9 @@ fn vtx_main() -> void {
|
||||||
entry_point vertex as "main" = vtx_main;
|
entry_point vertex as "main" = vtx_main;
|
||||||
|
|
||||||
# Fragment shader
|
# Fragment shader
|
||||||
[[location 0]] var outColor : ptr<out, vec4<f32>>;
|
[[location 0]] var<out> outColor : vec4<f32>;
|
||||||
fn frag_main() -> void {
|
fn frag_main() -> void {
|
||||||
outColor = vec4<f32>(1, 0, 0, 1);
|
outColor = vec4<f32>(1, 0, 0, 1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
entry_point fragment as "main" = frag_main;
|
entry_point fragment = frag_main;
|
||||||
|
|
Loading…
Reference in New Issue