[msl-writer] Add builtin support

This CL extends module scoped variables to include support for builtins.

Bug: tint:8
Change-Id: I9e4363be32401bfdd45ad5d1727d9432aca206fe
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24786
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-07-15 20:51:16 +00:00
parent 5423d91d87
commit 7caf6e5959
8 changed files with 449 additions and 19 deletions

View File

@ -35,6 +35,15 @@ bool DecoratedVariable::HasLocationDecoration() const {
return false;
}
bool DecoratedVariable::HasBuiltinDecoration() const {
for (const auto& deco : decorations_) {
if (deco->IsBuiltin()) {
return true;
}
}
return false;
}
bool DecoratedVariable::IsDecorated() const {
return true;
}

View File

@ -47,6 +47,8 @@ class DecoratedVariable : public Variable {
/// @returns true if the decorations include a LocationDecoration
bool HasLocationDecoration() const;
/// @returns true if the deocrations include a BuiltinDecoration
bool HasBuiltinDecoration() const;
/// @returns true if this is a decorated variable
bool IsDecorated() const override;

View File

@ -333,11 +333,26 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
if (!first) {
out_ << ", ";
}
out_ << var_name;
first = false;
out_ << var_name;
}
// TODO(dsinclair): Emit builtins
auto* func = module_->FindFunctionByName(ident->name());
if (func == nullptr) {
error_ = "Unable to find function: " + name;
return false;
}
for (const auto& data : func->referenced_builtin_variables()) {
auto* var = data.first;
if (var->storage_class() != ast::StorageClass::kInput) {
continue;
}
if (!first) {
out_ << ", ";
}
first = false;
out_ << var->name();
}
const auto& params = expr->params();
for (const auto& param : params) {
@ -517,15 +532,25 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
}
std::vector<std::pair<ast::Variable*, uint32_t>> in_locations;
std::vector<std::pair<ast::Variable*, uint32_t>> out_locations;
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>>
out_variables;
for (auto data : func->referenced_location_variables()) {
auto var = data.first;
auto locn_deco = data.second;
auto deco = data.second;
if (var->storage_class() == ast::StorageClass::kInput) {
in_locations.push_back({var, locn_deco->value()});
in_locations.push_back({var, deco->value()});
} else if (var->storage_class() == ast::StorageClass::kOutput) {
out_locations.push_back({var, locn_deco->value()});
out_variables.push_back({var, deco});
}
}
for (auto data : func->referenced_builtin_variables()) {
auto var = data.first;
auto deco = data.second;
if (var->storage_class() == ast::StorageClass::kOutput) {
out_variables.push_back({var, deco});
}
}
@ -575,7 +600,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
out_ << "};" << std::endl << std::endl;
}
if (!out_locations.empty()) {
if (!out_variables.empty()) {
auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix);
auto out_var_name = generate_name(kTintStructOutVarPrefix);
ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name};
@ -584,9 +609,9 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
out_ << "struct " << out_struct_name << " {" << std::endl;
increment_indent();
for (auto& data : out_locations) {
for (auto& data : out_variables) {
auto* var = data.first;
uint32_t loc = data.second;
auto* deco = data.second;
make_indent();
if (!EmitType(var->type(), var->name())) {
@ -594,12 +619,26 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
}
out_ << " " << var->name() << " [[";
if (ep->stage() == ast::PipelineStage::kVertex) {
out_ << "user(locn" << loc << ")";
} else if (ep->stage() == ast::PipelineStage::kFragment) {
out_ << "color(" << loc << ")";
if (deco->IsLocation()) {
auto loc = deco->AsLocation()->value();
if (ep->stage() == ast::PipelineStage::kVertex) {
out_ << "user(locn" << loc << ")";
} else if (ep->stage() == ast::PipelineStage::kFragment) {
out_ << "color(" << loc << ")";
} else {
error_ = "invalid location variable for pipeline stage";
return false;
}
} else if (deco->IsBuiltin()) {
auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
if (attr.empty()) {
error_ = "unsupported builtin";
return false;
}
out_ << attr;
} else {
error_ = "invalid location variable for pipeline stage";
error_ = "unsupported variable decoration for entry point output";
return false;
}
out_ << "]];" << std::endl;
@ -739,7 +778,22 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
}
}
// TODO(dsinclair): Handle any entry point builtin params used here
for (const auto& data : func->referenced_builtin_variables()) {
auto* var = data.first;
if (var->storage_class() != ast::StorageClass::kInput) {
continue;
}
if (!first) {
out_ << ", ";
}
first = false;
out_ << "thread ";
if (!EmitType(var->type(), "")) {
return false;
}
out_ << "& " << var->name();
}
// TODO(dsinclair): Binding/Set inputs
@ -771,6 +825,41 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
return true;
}
std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
switch (builtin) {
case ast::Builtin::kPosition:
return "position";
case ast::Builtin::kVertexIdx:
return "vertex_id";
case ast::Builtin::kInstanceIdx:
return "instance_id";
case ast::Builtin::kFrontFacing:
return "front_facing";
case ast::Builtin::kFragCoord:
return "position";
case ast::Builtin::kFragDepth:
return "depth(any)";
// TODO(dsinclair): Ignore for now, I believe it will be removed from WGSL
// https://github.com/gpuweb/gpuweb/issues/920
case ast::Builtin::kNumWorkgroups:
return "";
// TODO(dsinclair): Ignore for now. This has been removed as a builtin
// in the spec. Need to update Tint to match.
// https://github.com/gpuweb/gpuweb/pull/824
case ast::Builtin::kWorkgroupSize:
return "";
case ast::Builtin::kLocalInvocationId:
return "thread_position_in_threadgroup";
case ast::Builtin::kLocalInvocationIdx:
return "thread_index_in_threadgroup";
case ast::Builtin::kGlobalInvocationId:
return "thread_position_in_grid";
default:
break;
}
return "";
}
bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
make_indent();
@ -799,13 +888,38 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
}
out_ << " " << namer_.NameFor(current_ep_name_) << "(";
bool first = true;
auto in_data = ep_name_to_in_data_.find(current_ep_name_);
if (in_data != ep_name_to_in_data_.end()) {
out_ << in_data->second.struct_name << " " << in_data->second.var_name
<< " [[stage_in]]";
first = false;
}
// TODO(dsinclair): Output other builtin inputs
for (auto data : func->referenced_builtin_variables()) {
auto* var = data.first;
if (var->storage_class() != ast::StorageClass::kInput) {
continue;
}
if (!first) {
out_ << ", ";
}
first = false;
auto* builtin = data.second;
if (!EmitType(var->type(), "")) {
return false;
}
auto attr = builtin_to_attribute(builtin->value());
if (attr.empty()) {
error_ = "unknown builtin";
return false;
}
out_ << " " << var->name() << " [[" << attr << "]]";
}
// TODO(dsinclair): Binding/Set inputs
@ -845,9 +959,14 @@ bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
ast::Variable* var = nullptr;
if (global_variables_.get(ident->name(), &var)) {
if (var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() &&
bool in_or_out_struct_has_location =
var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() &&
(var->storage_class() == ast::StorageClass::kInput ||
var->storage_class() == ast::StorageClass::kOutput)) {
var->storage_class() == ast::StorageClass::kOutput);
bool in_struct_has_builtin =
var->IsDecorated() && var->AsDecorated()->HasBuiltinDecoration() &&
var->storage_class() == ast::StorageClass::kOutput;
if (in_or_out_struct_has_location || in_struct_has_builtin) {
auto var_type = var->storage_class() == ast::StorageClass::kInput
? VarType::kIn
: VarType::kOut;

View File

@ -198,6 +198,11 @@ class GeneratorImpl : public TextGenerator {
/// @returns the name
std::string generate_name(const std::string& prefix);
/// Converts a builtin to an attribute name
/// @param builtin the builtin to convert
/// @returns the string name of the builtin or blank on error
std::string builtin_to_attribute(ast::Builtin builtin) const;
/// @returns the namer for testing
Namer* namer_for_testing() { return &namer_; }

View File

@ -16,7 +16,10 @@
#include "gtest/gtest.h"
#include "src/ast/call_expression.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/module.h"
#include "src/ast/type/void_type.h"
#include "src/writer/msl/generator_impl.h"
namespace tint {
@ -27,22 +30,42 @@ namespace {
using MslGeneratorImplTest = testing::Test;
TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) {
ast::type::VoidType void_type;
auto id = std::make_unique<ast::IdentifierExpression>("my_func");
ast::CallExpression call(std::move(id), {});
auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{},
&void_type);
ast::Module m;
m.AddFunction(std::move(func));
GeneratorImpl g;
g.set_module_for_testing(&m);
ASSERT_TRUE(g.EmitExpression(&call)) << g.error();
EXPECT_EQ(g.result(), "my_func()");
}
TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) {
ast::type::VoidType void_type;
auto id = std::make_unique<ast::IdentifierExpression>("my_func");
ast::ExpressionList params;
params.push_back(std::make_unique<ast::IdentifierExpression>("param1"));
params.push_back(std::make_unique<ast::IdentifierExpression>("param2"));
ast::CallExpression call(std::move(id), std::move(params));
auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{},
&void_type);
ast::Module m;
m.AddFunction(std::move(func));
GeneratorImpl g;
g.set_module_for_testing(&m);
ASSERT_TRUE(g.EmitExpression(&call)) << g.error();
EXPECT_EQ(g.result(), "my_func(param1, param2)");
}

View File

@ -18,9 +18,12 @@
#include "src/ast/entry_point.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/location_decoration.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type/void_type.h"
#include "src/ast/variable.h"
#include "src/context.h"
#include "src/type_determiner.h"
@ -419,6 +422,78 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Output) {
EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
}
TEST_F(MslGeneratorImplTest, EmitEntryPointData_Builtins) {
// Output builtins go in the output struct, input builtins will be passed
// as input parameters to the entry point function.
// [[builtin frag_coord]] var<in> coord : vec4<f32>;
// [[builtin frag_depth]] var<out> depth : f32;
//
// struct main_out {
// float depth [[depth(any)]];
// };
ast::type::F32Type f32;
ast::type::VoidType void_type;
ast::type::VectorType vec4(&f32, 4);
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"coord", ast::StorageClass::kInput, &vec4));
ast::VariableDecorationList decos;
decos.push_back(
std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragCoord));
coord_var->set_decorations(std::move(decos));
auto depth_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"depth", ast::StorageClass::kOutput, &f32));
decos.push_back(
std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragDepth));
depth_var->set_decorations(std::move(decos));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(coord_var.get());
td.RegisterVariableForTesting(depth_var.get());
mod.AddGlobalVariable(std::move(coord_var));
mod.AddGlobalVariable(std::move(depth_var));
ast::VariableList params;
auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
&void_type);
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("depth"),
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("coord"),
std::make_unique<ast::IdentifierExpression>("x"))));
func->set_body(std::move(body));
mod.AddFunction(std::move(func));
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
"main", "frag_main");
auto* ep_ptr = ep.get();
mod.AddEntryPoint(std::move(ep));
ASSERT_TRUE(td.Determine()) << td.error();
GeneratorImpl g;
g.set_module_for_testing(&mod);
ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
EXPECT_EQ(g.result(), R"(struct main_out {
float depth [[depth(any)]];
};
)");
}
} // namespace
} // namespace msl
} // namespace writer

View File

@ -22,6 +22,7 @@
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/location_decoration.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
@ -29,6 +30,7 @@
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type/void_type.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
@ -216,8 +218,76 @@ fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) {
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) {
ast::type::VoidType void_type;
ast::type::F32Type f32;
ast::type::VectorType vec4(&f32, 4);
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"coord", ast::StorageClass::kInput, &vec4));
ast::VariableDecorationList decos;
decos.push_back(
std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragCoord));
coord_var->set_decorations(std::move(decos));
auto depth_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"depth", ast::StorageClass::kOutput, &f32));
decos.push_back(
std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragDepth));
depth_var->set_decorations(std::move(decos));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(coord_var.get());
td.RegisterVariableForTesting(depth_var.get());
mod.AddGlobalVariable(std::move(coord_var));
mod.AddGlobalVariable(std::move(depth_var));
ast::VariableList params;
auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
&void_type);
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("depth"),
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("coord"),
std::make_unique<ast::IdentifierExpression>("x"))));
body.push_back(std::make_unique<ast::ReturnStatement>());
func->set_body(std::move(body));
mod.AddFunction(std::move(func));
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
"frag_main");
mod.AddEntryPoint(std::move(ep));
ASSERT_TRUE(td.Determine()) << td.error();
GeneratorImpl g;
ASSERT_TRUE(g.Generate(mod)) << g.error();
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
struct frag_main_out {
float depth [[depth(any)]];
};
fragment frag_main_out frag_main(float4 coord [[position]]) {
frag_main_out tint_out = {};
tint_out.depth = coord.x;
return tint_out;
}
)");
}
TEST_F(MslGeneratorImplTest,
Emit_Function_Called_By_EntryPoints_WithGlobals_And_Params) {
Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) {
ast::type::VoidType void_type;
ast::type::F32Type f32;
@ -318,6 +388,99 @@ fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) {
)");
}
TEST_F(MslGeneratorImplTest,
Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {
ast::type::VoidType void_type;
ast::type::F32Type f32;
ast::type::VectorType vec4(&f32, 4);
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"coord", ast::StorageClass::kInput, &vec4));
ast::VariableDecorationList decos;
decos.push_back(
std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragCoord));
coord_var->set_decorations(std::move(decos));
auto depth_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"depth", ast::StorageClass::kOutput, &f32));
decos.push_back(
std::make_unique<ast::BuiltinDecoration>(ast::Builtin::kFragDepth));
depth_var->set_decorations(std::move(decos));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(coord_var.get());
td.RegisterVariableForTesting(depth_var.get());
mod.AddGlobalVariable(std::move(coord_var));
mod.AddGlobalVariable(std::move(depth_var));
ast::VariableList params;
params.push_back(std::make_unique<ast::Variable>(
"param", ast::StorageClass::kFunction, &f32));
auto sub_func =
std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("depth"),
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("coord"),
std::make_unique<ast::IdentifierExpression>("x"))));
body.push_back(std::make_unique<ast::ReturnStatement>(
std::make_unique<ast::IdentifierExpression>("param")));
sub_func->set_body(std::move(body));
mod.AddFunction(std::move(sub_func));
auto func_1 = std::make_unique<ast::Function>("frag_1_main",
std::move(params), &void_type);
ast::ExpressionList expr;
expr.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("depth"),
std::make_unique<ast::CallExpression>(
std::make_unique<ast::IdentifierExpression>("sub_func"),
std::move(expr))));
body.push_back(std::make_unique<ast::ReturnStatement>());
func_1->set_body(std::move(body));
mod.AddFunction(std::move(func_1));
auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
"ep_1", "frag_1_main");
mod.AddEntryPoint(std::move(ep1));
ASSERT_TRUE(td.Determine()) << td.error();
GeneratorImpl g;
ASSERT_TRUE(g.Generate(mod)) << g.error();
EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
struct ep_1_out {
float depth [[depth(any)]];
};
float sub_func_ep_1(thread ep_1_out& tint_out, thread float4& coord, float param) {
tint_out.depth = coord.x;
return param;
}
fragment ep_1_out ep_1(float4 coord [[position]]) {
ep_1_out tint_out = {};
tint_out.depth = sub_func_ep_1(tint_out, coord, 1.00000000f);
return tint_out;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) {
ast::type::VoidType void_type;
ast::type::F32Type f32;

View File

@ -74,6 +74,40 @@ TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) {
EXPECT_EQ(g.result(), "func_main_in_0");
}
struct MslBuiltinData {
ast::Builtin builtin;
const char* attribute_name;
};
inline std::ostream& operator<<(std::ostream& out, MslBuiltinData data) {
out << data.builtin;
return out;
}
using MslBuiltinConversionTest = testing::TestWithParam<MslBuiltinData>;
TEST_P(MslBuiltinConversionTest, Emit) {
auto params = GetParam();
GeneratorImpl g;
EXPECT_EQ(g.builtin_to_attribute(params.builtin),
std::string(params.attribute_name));
}
INSTANTIATE_TEST_SUITE_P(
MslGeneratorImplTest,
MslBuiltinConversionTest,
testing::Values(MslBuiltinData{ast::Builtin::kPosition, "position"},
MslBuiltinData{ast::Builtin::kVertexIdx, "vertex_id"},
MslBuiltinData{ast::Builtin::kInstanceIdx, "instance_id"},
MslBuiltinData{ast::Builtin::kFrontFacing, "front_facing"},
MslBuiltinData{ast::Builtin::kFragCoord, "position"},
MslBuiltinData{ast::Builtin::kFragDepth, "depth(any)"},
MslBuiltinData{ast::Builtin::kNumWorkgroups, ""},
MslBuiltinData{ast::Builtin::kWorkgroupSize, ""},
MslBuiltinData{ast::Builtin::kLocalInvocationId,
"thread_position_in_threadgroup"},
MslBuiltinData{ast::Builtin::kLocalInvocationIdx,
"thread_index_in_threadgroup"},
MslBuiltinData{ast::Builtin::kGlobalInvocationId,
"thread_position_in_grid"}));
} // namespace
} // namespace msl
} // namespace writer