[hlsl-writer] Add support for input locations and builtins.

This CL adds the beginning of support for input/output locations and
builtins in the HLSL backend.

Bug: tint:7
Change-Id: I8fb01707b50635a800b0d7317cf4a8f62f12cfca
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26780
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-08-19 17:37:25 +00:00 committed by Commit Bot service account
parent be89a06b03
commit 34fd95ced0
4 changed files with 395 additions and 11 deletions

View File

@ -20,6 +20,7 @@
#include "src/ast/binary_expression.h" #include "src/ast/binary_expression.h"
#include "src/ast/bool_literal.h" #include "src/ast/bool_literal.h"
#include "src/ast/case_statement.h" #include "src/ast/case_statement.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h" #include "src/ast/else_statement.h"
#include "src/ast/float_literal.h" #include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h" #include "src/ast/identifier_expression.h"
@ -46,6 +47,11 @@ namespace writer {
namespace hlsl { namespace hlsl {
namespace { namespace {
const char kInStructNameSuffix[] = "in";
const char kOutStructNameSuffix[] = "out";
const char kTintStructInVarPrefix[] = "tint_in";
const char kTintStructOutVarPrefix[] = "tint_out";
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
if (stmts->empty()) { if (stmts->empty()) {
return false; return false;
@ -74,6 +80,11 @@ bool GeneratorImpl::Generate() {
out_ << std::endl; out_ << std::endl;
} }
for (const auto& ep : module_->entry_points()) {
if (!EmitEntryPointData(ep.get())) {
return false;
}
}
for (const auto& func : module_->functions()) { for (const auto& func : module_->functions()) {
if (!EmitFunction(func.get())) { if (!EmitFunction(func.get())) {
return false; return false;
@ -89,6 +100,17 @@ bool GeneratorImpl::Generate() {
return true; return true;
} }
std::string GeneratorImpl::generate_name(const std::string& prefix) {
std::string name = prefix;
uint32_t i = 0;
while (namer_.IsMapped(name)) {
name = prefix + "_" + std::to_string(i);
++i;
}
namer_.RegisterRemappedName(name);
return name;
}
std::string GeneratorImpl::current_ep_var_name(VarType type) { std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = ""; std::string name = "";
switch (type) { switch (type) {
@ -431,8 +453,12 @@ bool GeneratorImpl::EmitExpression(ast::Expression* expr) {
return false; return false;
} }
bool GeneratorImpl::global_is_in_struct(ast::Variable*) const { bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
return false; return var->IsDecorated() &&
(var->AsDecorated()->HasLocationDecoration() ||
var->AsDecorated()->HasBuiltinDecoration()) &&
(var->storage_class() == ast::StorageClass::kInput ||
var->storage_class() == ast::StorageClass::kOutput);
} }
bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) { bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
@ -499,6 +525,25 @@ bool GeneratorImpl::EmitElse(ast::ElseStatement* stmt) {
return EmitBlock(stmt->body()); return EmitBlock(stmt->body());
} }
bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
for (auto data : func->referenced_location_variables()) {
auto* var = data.first;
if (var->storage_class() == ast::StorageClass::kOutput ||
var->storage_class() == ast::StorageClass::kInput) {
return true;
}
}
for (auto data : func->referenced_builtin_variables()) {
auto* var = data.first;
if (var->storage_class() == ast::StorageClass::kOutput ||
var->storage_class() == ast::StorageClass::kInput) {
return true;
}
}
return false;
}
bool GeneratorImpl::EmitFunction(ast::Function* func) { bool GeneratorImpl::EmitFunction(ast::Function* func) {
make_indent(); make_indent();
@ -507,6 +552,33 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
return true; return true;
} }
// TODO(dsinclair): This could be smarter. If the input/outputs for multiple
// entry points are the same we could generate a single struct and then have
// this determine it's the same struct and just emit once.
bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 &&
has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) {
for (const auto& ep_name : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
return false;
}
out_ << std::endl;
}
} else {
// Emit as non-duplicated
if (!EmitFunctionInternal(func, false, "")) {
return false;
}
out_ << std::endl;
}
return true;
}
bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
const std::string& ep_name) {
auto name = func->name(); auto name = func->name();
if (!EmitType(func->return_type(), "")) { if (!EmitType(func->return_type(), "")) {
@ -516,6 +588,30 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
out_ << " " << namer_.NameFor(name) << "("; out_ << " " << namer_.NameFor(name) << "(";
bool first = true; bool first = true;
// If we're emitting duplicate functions that means the function takes
// the stage_in or stage_out value from the entry point, emit them.
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
auto in_it = ep_name_to_in_data_.find(ep_name);
if (in_it != ep_name_to_in_data_.end()) {
out_ << "in " << in_it->second.struct_name << " "
<< in_it->second.var_name;
first = false;
}
auto out_it = ep_name_to_out_data_.find(ep_name);
if (out_it != ep_name_to_out_data_.end()) {
if (!first) {
out_ << ", ";
}
out_ << "out " << out_it->second.struct_name << " "
<< out_it->second.var_name;
first = false;
}
}
for (const auto& v : func->params()) { for (const auto& v : func->params()) {
if (!first) { if (!first) {
out_ << ", "; out_ << ", ";
@ -533,19 +629,188 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
out_ << ") "; out_ << ") ";
current_ep_name_ = ep_name;
if (!EmitBlockAndNewline(func->body())) { if (!EmitBlockAndNewline(func->body())) {
return false; return false;
} }
current_ep_name_ = "";
return true;
}
bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
auto* func = module_->FindFunctionByName(ep->function_name());
if (func == nullptr) {
error_ = "Unable to find entry point function: " + ep->function_name();
return false;
}
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>>
out_variables;
for (auto data : func->referenced_location_variables()) {
auto* var = data.first;
auto* deco = data.second;
if (var->storage_class() == ast::StorageClass::kInput) {
in_variables.push_back({var, deco});
} else if (var->storage_class() == ast::StorageClass::kOutput) {
out_variables.push_back({var, deco});
}
}
for (auto data : func->referenced_builtin_variables()) {
auto* var = data.first;
auto* deco = data.second;
if (var->storage_class() == ast::StorageClass::kInput) {
in_variables.push_back({var, deco});
} else if (var->storage_class() == ast::StorageClass::kOutput) {
out_variables.push_back({var, deco});
}
}
auto ep_name = ep->name();
if (ep_name.empty()) {
ep_name = ep->function_name();
}
// TODO(dsinclair): There is a potential bug here. Entry points can have the
// same name in WGSL if they have different pipeline stages. This does not
// take that into account and will emit duplicate struct names. I'm ignoring
// this until https://github.com/gpuweb/gpuweb/issues/662 is resolved as it
// may remove this issue and entry point names will need to be unique.
if (!in_variables.empty()) {
auto in_struct_name = generate_name(ep_name + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix);
ep_name_to_in_data_[ep_name] = {in_struct_name, in_var_name};
make_indent();
out_ << "struct " << in_struct_name << " {" << std::endl;
increment_indent();
for (auto& data : in_variables) {
auto* var = data.first;
auto* deco = data.second;
make_indent();
if (!EmitType(var->type(), var->name())) {
return false;
}
out_ << " " << var->name() << " : ";
if (deco->IsLocation()) {
out_ << "TEXCOORD" << deco->AsLocation()->value();
} else if (deco->IsBuiltin()) {
auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
if (attr.empty()) {
error_ = "unsupported builtin";
return false;
}
out_ << attr;
} else {
error_ = "unsupported variable decoration for entry point output";
return false;
}
out_ << ";" << std::endl;
}
decrement_indent();
make_indent();
out_ << "};" << std::endl << std::endl;
}
if (!out_variables.empty()) {
auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix);
auto out_var_name = generate_name(kTintStructOutVarPrefix);
ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name};
make_indent();
out_ << "struct " << out_struct_name << " {" << std::endl;
increment_indent();
for (auto& data : out_variables) {
auto* var = data.first;
auto* deco = data.second;
make_indent();
if (!EmitType(var->type(), var->name())) {
return false;
}
out_ << " " << var->name() << " : ";
if (deco->IsLocation()) {
auto loc = deco->AsLocation()->value();
if (ep->stage() == ast::PipelineStage::kVertex) {
out_ << "TEXCOORD" << loc;
} else if (ep->stage() == ast::PipelineStage::kFragment) {
out_ << "SV_Target" << loc << "";
} else {
error_ = "invalid location variable for pipeline stage";
return false;
}
} else if (deco->IsBuiltin()) {
auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
if (attr.empty()) {
error_ = "unsupported builtin";
return false;
}
out_ << attr;
} else {
error_ = "unsupported variable decoration for entry point output";
return false;
}
out_ << ";" << std::endl;
}
decrement_indent();
make_indent();
out_ << "};" << std::endl << std::endl;
}
return true; return true;
} }
std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
switch (builtin) {
case ast::Builtin::kPosition:
return "SV_Position";
case ast::Builtin::kVertexIdx:
return "SV_VertexID";
case ast::Builtin::kInstanceIdx:
return "SV_InstanceID";
case ast::Builtin::kFrontFacing:
return "SV_IsFrontFacing";
case ast::Builtin::kFragCoord:
return "SV_Position";
case ast::Builtin::kFragDepth:
return "SV_Depth";
// TODO(dsinclair): Ignore for now. This has been removed as a builtin
// in the spec. Need to update Tint to match.
// https://github.com/gpuweb/gpuweb/pull/824
case ast::Builtin::kWorkgroupSize:
return "";
case ast::Builtin::kLocalInvocationId:
return "SV_GroupThreadID";
case ast::Builtin::kLocalInvocationIdx:
return "SV_GroupIndex";
case ast::Builtin::kGlobalInvocationId:
return "SV_DispatchThreadID";
default:
break;
}
return "";
}
bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
make_indent(); make_indent();
auto current_ep_name = ep->name(); current_ep_name_ = ep->name();
if (current_ep_name.empty()) { if (current_ep_name_.empty()) {
current_ep_name = ep->function_name(); current_ep_name_ = ep->function_name();
} }
auto* func = module_->FindFunctionByName(ep->function_name()); auto* func = module_->FindFunctionByName(ep->function_name());
@ -554,19 +819,43 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
return false; return false;
} }
out_ << "void " << namer_.NameFor(current_ep_name) << "() {" << std::endl; auto out_data = ep_name_to_out_data_.find(current_ep_name_);
bool has_out_data = out_data != ep_name_to_out_data_.end();
if (has_out_data) {
out_ << out_data->second.struct_name;
} else {
out_ << "void";
}
out_ << " " << namer_.NameFor(current_ep_name_) << "(";
auto in_data = ep_name_to_in_data_.find(current_ep_name_);
if (in_data != ep_name_to_in_data_.end()) {
out_ << in_data->second.struct_name << " " << in_data->second.var_name;
}
out_ << ") {" << std::endl;
increment_indent(); increment_indent();
if (has_out_data) {
make_indent();
out_ << out_data->second.struct_name << " " << out_data->second.var_name
<< ";" << std::endl;
}
generating_entry_point_ = true;
for (const auto& s : *(func->body())) { for (const auto& s : *(func->body())) {
if (!EmitStatement(s.get())) { if (!EmitStatement(s.get())) {
return false; return false;
} }
} }
generating_entry_point_ = false;
decrement_indent(); decrement_indent();
make_indent(); make_indent();
out_ << "}" << std::endl; out_ << "}" << std::endl;
current_ep_name_ = "";
return true; return true;
} }

View File

@ -110,6 +110,19 @@ class GeneratorImpl : public TextGenerator {
/// @param func the function to generate /// @param func the function to generate
/// @returns true if the function was emitted /// @returns true if the function was emitted
bool EmitFunction(ast::Function* func); bool EmitFunction(ast::Function* func);
/// Internal helper for emitting functions
/// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point
/// @param ep_name the current entry point or blank if none set
/// @returns true if the function was emitted.
bool EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
const std::string& ep_name);
/// Handles emitting information for an entry point
/// @param ep the entry point
/// @returns true if the entry point data was emitted
bool EmitEntryPointData(ast::EntryPoint* ep);
/// Handles emitting the entry point function /// Handles emitting the entry point function
/// @param ep the entry point /// @param ep the entry point
/// @returns true if the entry point function was emitted /// @returns true if the entry point function was emitted
@ -168,6 +181,21 @@ class GeneratorImpl : public TextGenerator {
/// @param var the variable to check /// @param var the variable to check
/// @returns true if the global is in an input or output struct /// @returns true if the global is in an input or output struct
bool global_is_in_struct(ast::Variable* var) const; bool global_is_in_struct(ast::Variable* var) const;
/// Generates a name for the prefix
/// @param prefix the prefix of the name to generate
/// @returns the name
std::string generate_name(const std::string& prefix);
/// Converts a builtin to an attribute name
/// @param builtin the builtin to convert
/// @returns the string name of the builtin or blank on error
std::string builtin_to_attribute(ast::Builtin builtin) const;
/// Determines if any used module variable requires an input or output struct.
/// @param func the function to check
/// @returns true if an input or output struct is required.
bool has_referenced_var_needing_struct(ast::Function* func);
/// @returns the namer for testing
Namer* namer_for_testing() { return &namer_; }
private: private:
enum class VarType { kIn, kOut }; enum class VarType { kIn, kOut };

View File

@ -67,6 +67,7 @@ TEST_F(HlslGeneratorImplTest, Emit_Function) {
EXPECT_EQ(g.result(), R"( void my_func() { EXPECT_EQ(g.result(), R"( void my_func() {
return; return;
} }
)"); )");
} }
@ -90,6 +91,7 @@ TEST_F(HlslGeneratorImplTest, Emit_Function_Name_Collision) {
EXPECT_EQ(g.result(), R"( void GeometryShader_tint_0() { EXPECT_EQ(g.result(), R"( void GeometryShader_tint_0() {
return; return;
} }
)"); )");
} }
@ -121,6 +123,7 @@ TEST_F(HlslGeneratorImplTest, Emit_Function_WithParams) {
EXPECT_EQ(g.result(), R"( void my_func(float a, int b) { EXPECT_EQ(g.result(), R"( void my_func(float a, int b) {
return; return;
} }
)"); )");
} }
@ -144,7 +147,7 @@ TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_NoName) {
)"); )");
} }
TEST_F(HlslGeneratorImplTest, DISABLED_Emit_Function_EntryPoint_WithInOutVars) { TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) {
ast::type::VoidType void_type; ast::type::VoidType void_type;
ast::type::F32Type f32; ast::type::F32Type f32;
@ -207,8 +210,7 @@ frag_main_out frag_main(frag_main_in tint_in) {
)"); )");
} }
TEST_F(HlslGeneratorImplTest, TEST_F(HlslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) {
DISABLED_Emit_Function_EntryPoint_WithInOut_Builtins) {
ast::type::VoidType void_type; ast::type::VoidType void_type;
ast::type::F32Type f32; ast::type::F32Type f32;
ast::type::VectorType vec4(&f32, 4); ast::type::VectorType vec4(&f32, 4);
@ -262,7 +264,7 @@ TEST_F(HlslGeneratorImplTest,
GeneratorImpl g(&mod); GeneratorImpl g(&mod);
ASSERT_TRUE(g.Generate()) << g.error(); ASSERT_TRUE(g.Generate()) << g.error();
EXPECT_EQ(g.result(), R"(struct frag_main_in { EXPECT_EQ(g.result(), R"(struct frag_main_in {
float gl_FragCoord : SV_Position; vector<float, 4> coord : SV_Position;
}; };
struct frag_main_out { struct frag_main_out {
@ -271,7 +273,7 @@ struct frag_main_out {
frag_main_out frag_main(frag_main_in tint_in) { frag_main_out frag_main(frag_main_in tint_in) {
frag_main_out tint_out; frag_main_out tint_out;
tint_out.depth = tint_in.gl_FragCoord.x; tint_out.depth = tint_in.coord.x;
return tint_out; return tint_out;
} }
@ -377,6 +379,7 @@ TEST_F(HlslGeneratorImplTest,
EXPECT_EQ(g.result(), R"( ... )"); EXPECT_EQ(g.result(), R"( ... )");
} }
// TODO(dsinclair): Requires CallExpression
TEST_F( TEST_F(
HlslGeneratorImplTest, HlslGeneratorImplTest,
DISABLED_Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) { DISABLED_Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) {
@ -480,6 +483,7 @@ ep_1_out ep_1(ep_1_in tint_in) {
)"); )");
} }
// TODO(dsinclair): Requires CallExpression
TEST_F(HlslGeneratorImplTest, TEST_F(HlslGeneratorImplTest,
DISABLED_Emit_Function_Called_By_EntryPoints_NoUsedGlobals) { DISABLED_Emit_Function_Called_By_EntryPoints_NoUsedGlobals) {
ast::type::VoidType void_type; ast::type::VoidType void_type;
@ -558,6 +562,7 @@ fragment ep_1_out ep_1() {
)"); )");
} }
// TODO(dsinclair): Requires CallExpression
TEST_F( TEST_F(
HlslGeneratorImplTest, HlslGeneratorImplTest,
DISABLED_Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { DISABLED_Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {
@ -1098,6 +1103,7 @@ TEST_F(HlslGeneratorImplTest, Emit_Function_WithArrayParams) {
EXPECT_EQ(g.result(), R"( void my_func(float a[5]) { EXPECT_EQ(g.result(), R"( void my_func(float a[5]) {
return; return;
} }
)"); )");
} }

View File

@ -19,6 +19,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/ast/entry_point.h" #include "src/ast/entry_point.h"
#include "src/ast/function.h" #include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/module.h" #include "src/ast/module.h"
#include "src/ast/type/void_type.h" #include "src/ast/type/void_type.h"
@ -46,6 +47,66 @@ void my_func() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest, InputStructName) {
ast::Module m;
GeneratorImpl g(&m);
ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
}
TEST_F(HlslGeneratorImplTest, InputStructName_ConflictWithExisting) {
ast::Module m;
GeneratorImpl g(&m);
// Register the struct name as existing.
auto* namer = g.namer_for_testing();
namer->NameFor("func_main_out");
ASSERT_EQ(g.generate_name("func_main_out"), "func_main_out_0");
}
TEST_F(HlslGeneratorImplTest, NameConflictWith_InputStructName) {
ast::Module m;
GeneratorImpl g(&m);
ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
ast::IdentifierExpression ident("func_main_in");
ASSERT_TRUE(g.EmitIdentifier(&ident));
EXPECT_EQ(g.result(), "func_main_in_0");
}
struct HlslBuiltinData {
ast::Builtin builtin;
const char* attribute_name;
};
inline std::ostream& operator<<(std::ostream& out, HlslBuiltinData data) {
out << data.builtin;
return out;
}
using HlslBuiltinConversionTest = testing::TestWithParam<HlslBuiltinData>;
TEST_P(HlslBuiltinConversionTest, Emit) {
auto params = GetParam();
ast::Module m;
GeneratorImpl g(&m);
EXPECT_EQ(g.builtin_to_attribute(params.builtin),
std::string(params.attribute_name));
}
INSTANTIATE_TEST_SUITE_P(
HlslGeneratorImplTest,
HlslBuiltinConversionTest,
testing::Values(
HlslBuiltinData{ast::Builtin::kPosition, "SV_Position"},
HlslBuiltinData{ast::Builtin::kVertexIdx, "SV_VertexID"},
HlslBuiltinData{ast::Builtin::kInstanceIdx, "SV_InstanceID"},
HlslBuiltinData{ast::Builtin::kFrontFacing, "SV_IsFrontFacing"},
HlslBuiltinData{ast::Builtin::kFragCoord, "SV_Position"},
HlslBuiltinData{ast::Builtin::kFragDepth, "SV_Depth"},
HlslBuiltinData{ast::Builtin::kWorkgroupSize, ""},
HlslBuiltinData{ast::Builtin::kLocalInvocationId, "SV_GroupThreadID"},
HlslBuiltinData{ast::Builtin::kLocalInvocationIdx, "SV_GroupIndex"},
HlslBuiltinData{ast::Builtin::kGlobalInvocationId,
"SV_DispatchThreadID"}));
} // namespace } // namespace
} // namespace hlsl } // namespace hlsl
} // namespace writer } // namespace writer