diff --git a/BUILD.gn b/BUILD.gn index e9c4b5ee42..e75d512c3e 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -1097,6 +1097,7 @@ source_set("tint_unittests_hlsl_writer_src") { "src/writer/hlsl/generator_impl_continue_test.cc", "src/writer/hlsl/generator_impl_discard_test.cc", "src/writer/hlsl/generator_impl_entry_point_test.cc", + "src/writer/hlsl/generator_impl_function_entry_point_data_test.cc", "src/writer/hlsl/generator_impl_function_test.cc", "src/writer/hlsl/generator_impl_identifier_test.cc", "src/writer/hlsl/generator_impl_if_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aa142150f6..c2edcd0bf6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -607,6 +607,7 @@ if (${TINT_BUILD_HLSL_WRITER}) writer/hlsl/generator_impl_continue_test.cc writer/hlsl/generator_impl_discard_test.cc writer/hlsl/generator_impl_entry_point_test.cc + writer/hlsl/generator_impl_function_entry_point_data_test.cc writer/hlsl/generator_impl_function_test.cc writer/hlsl/generator_impl_identifier_test.cc writer/hlsl/generator_impl_if_test.cc diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index f81e36c6bf..b219595423 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -139,6 +139,18 @@ bool GeneratorImpl::Generate(std::ostream& out) { return false; } } + + // Make sure all entry point data is emitted before the entry point functions + for (const auto& func : module_->functions()) { + if (!func->IsEntryPoint()) { + continue; + } + + if (!EmitEntryPointData(out, func.get())) { + return false; + } + } + for (const auto& func : module_->functions()) { if (!EmitFunction(out, func.get())) { return false; @@ -151,6 +163,15 @@ bool GeneratorImpl::Generate(std::ostream& out) { out << std::endl; } + for (const auto& func : module_->functions()) { + if (!func->IsEntryPoint()) { + continue; + } + if (!EmitEntryPointFunction(out, func.get())) { + return false; + } + out << std::endl; + } return true; } @@ -1085,7 +1106,7 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) { make_indent(out); // Entry points will be emitted later, skip for now. - if (module_->IsFunctionEntryPoint(func->name())) { + if (func->IsEntryPoint() || module_->IsFunctionEntryPoint(func->name())) { return true; } @@ -1381,6 +1402,186 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::EntryPoint* ep) { return true; } +bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { + std::vector> in_variables; + std::vector> outvariables; + for (auto data : func->referenced_location_variables()) { + auto* var = data.first; + auto* deco = data.second; + + if (var->storage_class() == ast::StorageClass::kInput) { + in_variables.push_back({var, deco}); + } else if (var->storage_class() == ast::StorageClass::kOutput) { + outvariables.push_back({var, deco}); + } + } + + for (auto data : func->referenced_builtin_variables()) { + auto* var = data.first; + auto* deco = data.second; + + if (var->storage_class() == ast::StorageClass::kInput) { + in_variables.push_back({var, deco}); + } else if (var->storage_class() == ast::StorageClass::kOutput) { + outvariables.push_back({var, deco}); + } + } + + bool emitted_uniform = false; + for (auto data : func->referenced_uniform_variables()) { + auto* var = data.first; + // TODO(dsinclair): We're using the binding to make up the buffer number but + // we should instead be using a provided mapping that uses both buffer and + // set. https://bugs.chromium.org/p/tint/issues/detail?id=104 + auto* binding = data.second.binding; + if (binding == nullptr) { + error_ = "unable to find binding information for uniform: " + var->name(); + return false; + } + // auto* set = data.second.set; + + auto* type = var->type()->UnwrapAliasesIfNeeded(); + if (type->IsStruct()) { + auto* strct = type->AsStruct(); + + out << "ConstantBuffer<" << strct->name() << "> " << var->name() + << " : register(b" << binding->value() << ");" << std::endl; + } else { + // TODO(dsinclair): There is outstanding spec work to require all uniform + // buffers to be [[block]] decorated, which means structs. This is + // currently not the case, so this code handles the cases where the data + // is not a block. + // Relevant: https://github.com/gpuweb/gpuweb/issues/1004 + // https://github.com/gpuweb/gpuweb/issues/1008 + out << "cbuffer : register(b" << binding->value() << ") {" << std::endl; + + increment_indent(); + make_indent(out); + if (!EmitType(out, type, "")) { + return false; + } + out << " " << var->name() << ";" << std::endl; + decrement_indent(); + out << "};" << std::endl; + } + + emitted_uniform = true; + } + if (emitted_uniform) { + out << std::endl; + } + + bool emitted_storagebuffer = false; + for (auto data : func->referenced_storagebuffer_variables()) { + auto* var = data.first; + auto* binding = data.second.binding; + + out << "RWByteAddressBuffer " << var->name() << " : register(u" + << binding->value() << ");" << std::endl; + emitted_storagebuffer = true; + } + if (emitted_storagebuffer) { + out << std::endl; + } + + if (!in_variables.empty()) { + auto in_struct_name = + generate_name(func->name() + "_" + kInStructNameSuffix); + auto in_var_name = generate_name(kTintStructInVarPrefix); + ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; + + make_indent(out); + out << "struct " << in_struct_name << " {" << std::endl; + + increment_indent(); + + for (auto& data : in_variables) { + auto* var = data.first; + auto* deco = data.second; + + make_indent(out); + if (!EmitType(out, var->type(), var->name())) { + return false; + } + + out << " " << var->name() << " : "; + if (deco->IsLocation()) { + if (func->pipeline_stage() == ast::PipelineStage::kCompute) { + error_ = "invalid location variable for pipeline stage"; + return false; + } + out << "TEXCOORD" << deco->AsLocation()->value(); + } else if (deco->IsBuiltin()) { + auto attr = builtin_to_attribute(deco->AsBuiltin()->value()); + if (attr.empty()) { + error_ = "unsupported builtin"; + return false; + } + out << attr; + } else { + error_ = "unsupported variable decoration for entry point output"; + return false; + } + out << ";" << std::endl; + } + decrement_indent(); + make_indent(out); + + out << "};" << std::endl << std::endl; + } + + if (!outvariables.empty()) { + auto outstruct_name = + generate_name(func->name() + "_" + kOutStructNameSuffix); + auto outvar_name = generate_name(kTintStructOutVarPrefix); + ep_name_to_out_data_[func->name()] = {outstruct_name, outvar_name}; + + make_indent(out); + out << "struct " << outstruct_name << " {" << std::endl; + + increment_indent(); + for (auto& data : outvariables) { + auto* var = data.first; + auto* deco = data.second; + + make_indent(out); + if (!EmitType(out, var->type(), var->name())) { + return false; + } + + out << " " << var->name() << " : "; + + if (deco->IsLocation()) { + auto loc = deco->AsLocation()->value(); + if (func->pipeline_stage() == ast::PipelineStage::kVertex) { + out << "TEXCOORD" << loc; + } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) { + out << "SV_Target" << loc << ""; + } else { + error_ = "invalid location variable for pipeline stage"; + return false; + } + } else if (deco->IsBuiltin()) { + auto attr = builtin_to_attribute(deco->AsBuiltin()->value()); + if (attr.empty()) { + error_ = "unsupported builtin"; + return false; + } + out << attr; + } else { + error_ = "unsupported variable decoration for entry point output"; + return false; + } + out << ";" << std::endl; + } + decrement_indent(); + make_indent(out); + out << "};" << std::endl << std::endl; + } + + return true; +} + std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const { switch (builtin) { case ast::Builtin::kPosition: @@ -1472,6 +1673,62 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, return true; } +bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, + ast::Function* func) { + make_indent(out); + + current_ep_name_ = func->name(); + + if (func->pipeline_stage() == ast::PipelineStage::kCompute) { + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = func->workgroup_size(); + out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y) + << ", " << std::to_string(z) << ")]" << std::endl; + make_indent(out); + } + + auto outdata = ep_name_to_out_data_.find(current_ep_name_); + bool has_outdata = outdata != ep_name_to_out_data_.end(); + if (has_outdata) { + out << outdata->second.struct_name; + } else { + out << "void"; + } + out << " " << namer_.NameFor(current_ep_name_) << "("; + + auto in_data = ep_name_to_in_data_.find(current_ep_name_); + if (in_data != ep_name_to_in_data_.end()) { + out << in_data->second.struct_name << " " << in_data->second.var_name; + } + out << ") {" << std::endl; + + increment_indent(); + + if (has_outdata) { + make_indent(out); + out << outdata->second.struct_name << " " << outdata->second.var_name << ";" + << std::endl; + } + + generating_entry_point_ = true; + for (const auto& s : *(func->body())) { + if (!EmitStatement(out, s.get())) { + return false; + } + } + generating_entry_point_ = false; + + decrement_indent(); + make_indent(out); + out << "}" << std::endl; + + current_ep_name_ = ""; + + return true; +} + bool GeneratorImpl::EmitLiteral(std::ostream& out, ast::Literal* lit) { if (lit->IsBool()) { out << (lit->AsBool()->IsTrue() ? "true" : "false"); diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 5da53a67f6..264c119ccb 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -196,11 +196,21 @@ class GeneratorImpl { /// @param ep the entry point /// @returns true if the entry point data was emitted bool EmitEntryPointData(std::ostream& out, ast::EntryPoint* ep); + /// Handles emitting information for an entry point + /// @param out the output stream + /// @param func the entry point + /// @returns true if the entry point data was emitted + bool EmitEntryPointData(std::ostream& out, ast::Function* func); /// Handles emitting the entry point function /// @param out the output stream /// @param ep the entry point /// @returns true if the entry point function was emitted bool EmitEntryPointFunction(std::ostream& out, ast::EntryPoint* ep); + /// Handles emitting the entry point function + /// @param out the output stream + /// @param func the entry point + /// @returns true if the entry point function was emitted + bool EmitEntryPointFunction(std::ostream& out, ast::Function* func); /// Handles an if statement /// @param out the output stream /// @param stmt the statement to emit diff --git a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc new file mode 100644 index 0000000000..4a65bef053 --- /dev/null +++ b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc @@ -0,0 +1,449 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/assignment_statement.h" +#include "src/ast/decorated_variable.h" +#include "src/ast/entry_point.h" +#include "src/ast/identifier_expression.h" +#include "src/ast/location_decoration.h" +#include "src/ast/member_accessor_expression.h" +#include "src/ast/module.h" +#include "src/ast/pipeline_stage.h" +#include "src/ast/return_statement.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" +#include "src/ast/type/vector_type.h" +#include "src/ast/type/void_type.h" +#include "src/ast/variable.h" +#include "src/context.h" +#include "src/type_determiner.h" +#include "src/writer/hlsl/test_helper.h" + +#include + +namespace tint { +namespace writer { +namespace hlsl { +namespace { + +using HlslGeneratorImplTest_EntryPoint = TestHelper; + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Vertex_Input) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct vtx_main_in { + // float foo : TEXCOORD0; + // int bar : TEXCOORD1; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kInput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = + std::make_unique("vtx_main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(result(), R"(struct vtx_main_in { + float foo : TEXCOORD0; + int bar : TEXCOORD1; +}; + +)"); +} + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Vertex_Output) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct vtx_main_out { + // float foo : TEXCOORD0; + // int bar : TEXCOORD1; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = + std::make_unique("vtx_main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(result(), R"(struct vtx_main_out { + float foo : TEXCOORD0; + int bar : TEXCOORD1; +}; + +)"); +} + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Fragment_Input) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct frag_main_in { + // float foo : TEXCOORD0; + // int bar : TEXCOORD1; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kInput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(result(), R"(struct main_in { + float foo : TEXCOORD0; + int bar : TEXCOORD1; +}; + +)"); +} + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Fragment_Output) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct frag_main_out { + // float foo : SV_Target0; + // int bar : SV_Target1; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(result(), R"(struct main_out { + float foo : SV_Target0; + int bar : SV_Target1; +}; + +)"); +} + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Compute_Input) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // -> Error, not allowed + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kInput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_FALSE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(gen().error(), R"(invalid location variable for pipeline stage)"); +} + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Compute_Output) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // -> Error not allowed + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_FALSE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(gen().error(), R"(invalid location variable for pipeline stage)"); +} + +TEST_F(HlslGeneratorImplTest_EntryPoint, + Emit_Function_EntryPointData_Builtins) { + // [[builtin frag_coord]] var coord : vec4; + // [[builtin frag_depth]] var depth : f32; + // + // struct main_in { + // vector coord : SV_Position; + // }; + // + // struct main_out { + // float depth : SV_Depth; + // }; + + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + td().RegisterVariableForTesting(depth_var.get()); + + mod()->AddGlobalVariable(std::move(coord_var)); + mod()->AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + auto func = + std::make_unique("main", std::move(params), &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + EXPECT_EQ(result(), R"(struct main_in { + vector coord : SV_Position; +}; + +struct main_out { + float depth : SV_Depth; +}; + +)"); +} + +} // namespace +} // namespace hlsl +} // namespace writer +} // namespace tint diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 2056427ea0..4c25dd7c54 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -24,10 +24,12 @@ #include "src/ast/location_decoration.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/set_decoration.h" #include "src/ast/sint_literal.h" +#include "src/ast/stage_decoration.h" #include "src/ast/struct.h" #include "src/ast/struct_member_offset_decoration.h" #include "src/ast/type/alias_type.h" @@ -198,6 +200,63 @@ frag_main_out frag_main(frag_main_in tint_in) { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_WithInOutVars) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct frag_main_in { + float foo : TEXCOORD0; +}; + +struct frag_main_out { + float bar : SV_Target1; +}; + +frag_main_out frag_main(frag_main_in tint_in) { + frag_main_out tint_out; + tint_out.bar = tint_in.foo; + return tint_out; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithInOut_Builtins) { ast::type::VoidType void_type; @@ -264,6 +323,70 @@ frag_main_out frag_main(frag_main_in tint_in) { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_WithInOut_Builtins) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + td().RegisterVariableForTesting(depth_var.get()); + + mod()->AddGlobalVariable(std::move(coord_var)); + mod()->AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct frag_main_in { + vector coord : SV_Position; +}; + +struct frag_main_out { + float depth : SV_Depth; +}; + +frag_main_out frag_main(frag_main_in tint_in) { + frag_main_out tint_out; + tint_out.depth = tint_in.coord.x; + return tint_out; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_With_Uniform) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -316,6 +439,57 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_With_Uniform) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kUniform, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("x"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(cbuffer : register(b0) { + vector coord; +}; + +void frag_main() { + float v = coord.x; + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_With_UniformStruct) { ast::type::VoidType void_type; @@ -386,6 +560,74 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_With_UniformStruct) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + ast::StructMemberList members; + members.push_back(std::make_unique( + "coord", &vec4, ast::StructMemberDecorationList{})); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Uniforms"); + auto alias = std::make_unique("Uniforms", &s); + + auto coord_var = + std::make_unique(std::make_unique( + "uniforms", ast::StorageClass::kUniform, alias.get())); + + mod()->AddAliasType(alias.get()); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique( + std::make_unique("uniforms"), + std::make_unique("coord")), + std::make_unique("x"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct Uniforms { + vector coord; +}; + +ConstantBuffer uniforms : register(b0); + +void frag_main() { + float v = uniforms.coord.x; + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_With_StorageBuffer_Read) { ast::type::VoidType void_type; @@ -454,6 +696,72 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_With_StorageBuffer_Read) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &s)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("b"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(RWByteAddressBuffer coord : register(u0); + +void frag_main() { + float v = asfloat(coord.Load(4)); + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_With_StorageBuffer_Store) { ast::type::VoidType void_type; @@ -524,6 +832,74 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_With_StorageBuffer_Store) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &s)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto assign = std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("b")), + std::make_unique( + std::make_unique(&f32, 2.0f))); + + auto body = std::make_unique(); + body->append(std::move(assign)); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(RWByteAddressBuffer coord : register(u0); + +void frag_main() { + coord.Store(4, asuint(2.00000000f)); + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) { ast::type::VoidType void_type; @@ -621,6 +997,102 @@ ep_1_out ep_1(ep_1_in tint_in) { )"); } +TEST_F( + HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + auto val_var = std::make_unique( + std::make_unique("val", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(0)); + val_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(foo_var.get()); + td().RegisterVariableForTesting(bar_var.get()); + td().RegisterVariableForTesting(val_var.get()); + + mod()->AddGlobalVariable(std::move(foo_var)); + mod()->AddGlobalVariable(std::move(bar_var)); + mod()->AddGlobalVariable(std::move(val_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("val"), + std::make_unique("param"))); + body->append(std::make_unique( + std::make_unique("foo"))); + sub_func->set_body(std::move(body)); + + mod()->AddFunction(std::move(sub_func)); + + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod()->AddFunction(std::move(func_1)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct ep_1_in { + float foo : TEXCOORD0; +}; + +struct ep_1_out { + float bar : SV_Target1; + float val : SV_Target0; +}; + +float sub_func_ep_1(in ep_1_in tint_in, out ep_1_out tint_out, float param) { + tint_out.bar = tint_in.foo; + tint_out.val = param; + return tint_in.foo; +} + +ep_1_out ep_1(ep_1_in tint_in) { + ep_1_out tint_out; + tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_By_EntryPoints_NoUsedGlobals) { ast::type::VoidType void_type; @@ -694,6 +1166,77 @@ ep_1_out ep_1() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_Called_By_EntryPoints_NoUsedGlobals) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(depth_var.get()); + + mod()->AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("param"))); + sub_func->set_body(std::move(body)); + + mod()->AddFunction(std::move(sub_func)); + + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod()->AddFunction(std::move(func_1)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct ep_1_out { + float depth : SV_Depth; +}; + +float sub_func(float param) { + return param; +} + +ep_1_out ep_1() { + ep_1_out tint_out; + tint_out.depth = sub_func(1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { ast::type::VoidType void_type; @@ -786,6 +1329,97 @@ ep_1_out ep_1(ep_1_in tint_in) { )"); } +TEST_F( + HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + td().RegisterVariableForTesting(depth_var.get()); + + mod()->AddGlobalVariable(std::move(coord_var)); + mod()->AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + body->append(std::make_unique( + std::make_unique("param"))); + sub_func->set_body(std::move(body)); + + mod()->AddFunction(std::move(sub_func)); + + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod()->AddFunction(std::move(func_1)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct ep_1_in { + vector coord : SV_Position; +}; + +struct ep_1_out { + float depth : SV_Depth; +}; + +float sub_func_ep_1(in ep_1_in tint_in, out ep_1_out tint_out, float param) { + tint_out.depth = tint_in.coord.x; + return param; +} + +ep_1_out ep_1(ep_1_in tint_in) { + ep_1_out tint_out; + tint_out.depth = sub_func_ep_1(tint_in, tint_out, 1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_By_EntryPoint_With_Uniform) { ast::type::VoidType void_type; @@ -862,6 +1496,80 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_Called_By_EntryPoint_With_Uniform) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kUniform, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + sub_func->set_body(std::move(body)); + + mod()->AddFunction(std::move(sub_func)); + + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("sub_func"), + std::move(expr))); + + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(cbuffer : register(b0) { + vector coord; +}; + +float sub_func(float param) { + return coord.x; +} + +void frag_main() { + float v = sub_func(1.00000000f); + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_By_EntryPoint_With_StorageBuffer) { ast::type::VoidType void_type; @@ -936,6 +1644,78 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_Called_By_EntryPoint_With_StorageBuffer) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + sub_func->set_body(std::move(body)); + + mod()->AddFunction(std::move(sub_func)); + + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("sub_func"), + std::move(expr))); + + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(RWByteAddressBuffer coord : register(u0); + +float sub_func(float param) { + return asfloat(coord.Load((4 * 0))); +} + +void frag_main() { + float v = sub_func(1.00000000f); + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_Two_EntryPoints_WithGlobals) { ast::type::VoidType void_type; @@ -1101,6 +1881,68 @@ ep_1_out ep_1() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoints_WithGlobal_Nested_Return) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(bar_var.get()); + mod()->AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique(&f32, 1.0f)))); + + auto list = std::make_unique(); + list->append(std::make_unique()); + + body->append(std::make_unique( + std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 1)), + std::make_unique( + std::make_unique(&i32, 1))), + std::move(list))); + + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod()->AddFunction(std::move(func_1)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct ep_1_out { + float bar : SV_Target1; +}; + +ep_1_out ep_1() { + ep_1_out tint_out; + tint_out.bar = 1.00000000f; + if ((1 == 1)) { + return tint_out; + } + return tint_out; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Called_Two_EntryPoints_WithoutGlobals) { ast::type::VoidType void_type; @@ -1160,6 +2002,7 @@ void ep_2() { )"); } + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithName) { ast::type::VoidType void_type; @@ -1197,6 +2040,24 @@ TEST_F(HlslGeneratorImplTest_Function, )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_WithNameCollision) { + ast::type::VoidType void_type; + + auto func = std::make_unique("GeometryShader", + ast::VariableList{}, &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(void GeometryShader_tint_0() { +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_Compute) { ast::type::VoidType void_type; @@ -1224,6 +2085,32 @@ void main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_Compute) { + ast::type::VoidType void_type; + + ast::VariableList params; + auto func = + std::make_unique("main", std::move(params), &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + + auto body = std::make_unique(); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"([numthreads(1, 1, 1)] +void main() { + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_Compute_WithWorkgroup) { ast::type::VoidType void_type; @@ -1253,6 +2140,33 @@ void main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_Compute_WithWorkgroup) { + ast::type::VoidType void_type; + + ast::VariableList params; + auto func = + std::make_unique("main", std::move(params), &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + func->add_decoration(std::make_unique(2u, 4u, 6u)); + + auto body = std::make_unique(); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"([numthreads(2, 4, 6)] +void main() { + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { ast::type::F32Type f32; ast::type::ArrayType ary(&f32, 5);