Add transform to strip entry points from a module

Remove the Generator::GenerateEntryPoint() APIs as they were mostly
unimplemented and not used by anything except the Tint sample app,
which now uses the new transform.

Change-Id: I1ccb303d6c3aa15e622c193d33b753e22bf39a95
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49160
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2021-04-28 15:33:03 +00:00 committed by Commit Bot service account
parent f5f311e264
commit 0949bdf68f
20 changed files with 577 additions and 258 deletions

View File

@ -30,6 +30,7 @@
#include "src/transform/first_index_offset.h" #include "src/transform/first_index_offset.h"
#include "src/transform/manager.h" #include "src/transform/manager.h"
#include "src/transform/renamer.h" #include "src/transform/renamer.h"
#include "src/transform/single_entry_point.h"
#include "src/transform/vertex_pulling.h" #include "src/transform/vertex_pulling.h"
#include "src/writer/writer.h" #include "src/writer/writer.h"

View File

@ -59,7 +59,6 @@ struct Options {
Format format = Format::kNone; Format format = Format::kNone;
bool emit_single_entry_point = false; bool emit_single_entry_point = false;
tint::ast::PipelineStage stage;
std::string ep_name; std::string ep_name;
std::vector<std::string> transforms; std::vector<std::string> transforms;
@ -77,7 +76,7 @@ const char kUsage[] = R"(Usage: tint [options] <input-file>
.metal -> msl .metal -> msl
.hlsl -> hlsl .hlsl -> hlsl
If none matches, then default to SPIR-V assembly. If none matches, then default to SPIR-V assembly.
-ep <compute|fragment|vertex> <name> -- Output single entry point -ep <name> -- Output single entry point
--output-file <name> -- Output file name. Use "-" for standard output --output-file <name> -- Output file name. Use "-" for standard output
-o <name> -- Output file name. Use "-" for standard output -o <name> -- Output file name. Use "-" for standard output
--transform <name list> -- Runs transforms, name list is comma separated --transform <name list> -- Runs transforms, name list is comma separated
@ -185,19 +184,6 @@ Format infer_format(const std::string& filename) {
return Format::kNone; return Format::kNone;
} }
tint::ast::PipelineStage convert_to_pipeline_stage(const std::string& name) {
if (name == "compute") {
return tint::ast::PipelineStage::kCompute;
}
if (name == "fragment") {
return tint::ast::PipelineStage::kFragment;
}
if (name == "vertex") {
return tint::ast::PipelineStage::kVertex;
}
return tint::ast::PipelineStage::kNone;
}
std::vector<std::string> split_transform_names(std::string list) { std::vector<std::string> split_transform_names(std::string list) {
std::vector<std::string> res; std::vector<std::string> res;
@ -373,17 +359,10 @@ bool ParseArgs(const std::vector<std::string>& args, Options* opts) {
return false; return false;
} }
} else if (arg == "-ep") { } else if (arg == "-ep") {
if (i + 2 >= args.size()) { if (i + 1 >= args.size()) {
std::cerr << "Missing values for -ep" << std::endl; std::cerr << "Missing value for -ep" << std::endl;
return false; return false;
} }
i++;
opts->stage = convert_to_pipeline_stage(args[i]);
if (opts->stage == tint::ast::PipelineStage::kNone) {
std::cerr << "Invalid pipeline stage: " << args[i] << std::endl;
return false;
}
i++; i++;
opts->ep_name = args[i]; opts->ep_name = args[i];
opts->emit_single_entry_point = true; opts->emit_single_entry_point = true;
@ -697,6 +676,13 @@ int main(int argc, const char** argv) {
} }
} }
if (options.emit_single_entry_point) {
transform_manager.append(
std::make_unique<tint::transform::SingleEntryPoint>());
transform_inputs.Add<tint::transform::SingleEntryPoint::Config>(
options.ep_name);
}
switch (options.format) { switch (options.format) {
#if TINT_BUILD_SPV_WRITER #if TINT_BUILD_SPV_WRITER
case Format::kSpirv: case Format::kSpirv:
@ -801,17 +787,10 @@ int main(int argc, const char** argv) {
return 1; return 1;
} }
if (options.emit_single_entry_point) {
if (!writer->GenerateEntryPoint(options.stage, options.ep_name)) {
std::cerr << "Failed to generate: " << writer->error() << std::endl;
return 1;
}
} else {
if (!writer->Generate()) { if (!writer->Generate()) {
std::cerr << "Failed to generate: " << writer->error() << std::endl; std::cerr << "Failed to generate: " << writer->error() << std::endl;
return 1; return 1;
} }
}
#if TINT_BUILD_SPV_WRITER #if TINT_BUILD_SPV_WRITER
bool dawn_validation_failed = false; bool dawn_validation_failed = false;

View File

@ -519,6 +519,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/manager.h", "transform/manager.h",
"transform/renamer.cc", "transform/renamer.cc",
"transform/renamer.h", "transform/renamer.h",
"transform/single_entry_point.cc",
"transform/single_entry_point.h",
"transform/transform.cc", "transform/transform.cc",
"transform/transform.h", "transform/transform.h",
"transform/vertex_pulling.cc", "transform/vertex_pulling.cc",

View File

@ -275,6 +275,8 @@ set(TINT_LIB_SRCS
transform/manager.h transform/manager.h
transform/renamer.cc transform/renamer.cc
transform/renamer.h transform/renamer.h
transform/single_entry_point.cc
transform/single_entry_point.h
transform/transform.cc transform/transform.cc
transform/transform.h transform/transform.h
transform/vertex_pulling.cc transform/vertex_pulling.cc
@ -763,7 +765,6 @@ if(${TINT_BUILD_TESTS})
writer/wgsl/generator_impl_constructor_test.cc writer/wgsl/generator_impl_constructor_test.cc
writer/wgsl/generator_impl_continue_test.cc writer/wgsl/generator_impl_continue_test.cc
writer/wgsl/generator_impl_discard_test.cc writer/wgsl/generator_impl_discard_test.cc
writer/wgsl/generator_impl_entry_point_test.cc
writer/wgsl/generator_impl_fallthrough_test.cc writer/wgsl/generator_impl_fallthrough_test.cc
writer/wgsl/generator_impl_function_test.cc writer/wgsl/generator_impl_function_test.cc
writer/wgsl/generator_impl_global_decl_test.cc writer/wgsl/generator_impl_global_decl_test.cc
@ -791,6 +792,7 @@ if(${TINT_BUILD_TESTS})
transform/emit_vertex_point_size_test.cc transform/emit_vertex_point_size_test.cc
transform/first_index_offset_test.cc transform/first_index_offset_test.cc
transform/renamer_test.cc transform/renamer_test.cc
transform/single_entry_point.cc
transform/test_helper.h transform/test_helper.h
transform/vertex_pulling_test.cc transform/vertex_pulling_test.cc
) )

View File

@ -0,0 +1,105 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/single_entry_point.h"
#include <unordered_set>
#include <utility>
#include "src/program_builder.h"
#include "src/sem/function.h"
#include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config);
namespace tint {
namespace transform {
SingleEntryPoint::SingleEntryPoint() = default;
SingleEntryPoint::~SingleEntryPoint() = default;
Output SingleEntryPoint::Run(const Program* in, const DataMap& data) {
ProgramBuilder out;
auto* cfg = data.Get<Config>();
if (cfg == nullptr) {
out.Diagnostics().add_error("missing transform data for SingleEntryPoint");
return Output(Program(std::move(out)));
}
// Find the target entry point.
ast::Function* entry_point = nullptr;
for (auto* f : in->AST().Functions()) {
if (!f->IsEntryPoint()) {
continue;
}
if (in->Symbols().NameFor(f->symbol()) == cfg->entry_point_name) {
entry_point = f;
break;
}
}
if (entry_point == nullptr) {
out.Diagnostics().add_error("entry point '" + cfg->entry_point_name +
"' not found");
return Output(Program(std::move(out)));
}
CloneContext ctx(&out, in);
auto* sem = in->Sem().Get(entry_point);
// Build set of referenced module-scope variables for faster lookups later.
std::unordered_set<const ast::Variable*> referenced_vars;
for (auto* var : sem->ReferencedModuleVariables()) {
referenced_vars.emplace(var->Declaration());
}
// Clone any module-scope variables, types, and functions that are statically
// referenced by the target entry point.
for (auto* decl : in->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<sem::Type>()) {
// TODO(jrprice): Strip unused types.
out.AST().AddConstructedType(ctx.Clone(ty));
} else if (auto* var = decl->As<ast::Variable>()) {
if (var->is_const() || referenced_vars.count(var)) {
out.AST().AddGlobalVariable(ctx.Clone(var));
}
} else if (auto* func = decl->As<ast::Function>()) {
if (in->Sem().Get(func)->HasAncestorEntryPoint(entry_point->symbol())) {
out.AST().AddFunction(ctx.Clone(func));
}
} else {
TINT_UNREACHABLE(out.Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
return Output(Program(std::move(out)));
}
}
// Clone the entry point.
out.AST().AddFunction(ctx.Clone(entry_point));
return Output(Program(std::move(out)));
}
SingleEntryPoint::Config::Config(std::string entry_point)
: entry_point_name(entry_point) {}
SingleEntryPoint::Config::Config(const Config&) = default;
SingleEntryPoint::Config::~Config() = default;
SingleEntryPoint::Config& SingleEntryPoint::Config::operator=(const Config&) =
default;
} // namespace transform
} // namespace tint

View File

@ -0,0 +1,67 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_SINGLE_ENTRY_POINT_H_
#define SRC_TRANSFORM_SINGLE_ENTRY_POINT_H_
#include <string>
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// Strip all but one entry point a module.
///
/// All module-scope variables, types, and functions that are not used by the
/// target entry point will also be removed.
class SingleEntryPoint : public Transform {
public:
/// Configuration options for the transform
struct Config : public Castable<Config, Data> {
/// Constructor
/// @param entry_point the name of the entry point to keep
explicit Config(std::string entry_point = "");
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// Assignment operator
/// @returns this Config
Config& operator=(const Config&);
/// The name of the entry point to keep.
std::string entry_point_name;
};
/// Constructor
SingleEntryPoint();
/// Destructor
~SingleEntryPoint() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_SINGLE_ENTRY_POINT_H_

View File

@ -0,0 +1,385 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/single_entry_point.h"
#include <utility>
#include "src/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using SingleEntryPointTest = TransformTest;
TEST_F(SingleEntryPointTest, Error_MissingTransformData) {
auto* src = "";
auto* expect = "error: missing transform data for SingleEntryPoint";
auto got = Run<SingleEntryPoint>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_NoEntryPoints) {
auto* src = "";
auto* expect = "error: entry point 'main' not found";
DataMap data;
data.Add<SingleEntryPoint::Config>("main");
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_InvalidEntryPoint) {
auto* src = R"(
[[stage(vertex)]]
fn main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = "error: entry point '_' not found";
SingleEntryPoint::Config cfg("_");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_NotAnEntryPoint) {
auto* src = R"(
fn foo() {}
[[stage(fragment)]]
fn main() {}
)";
auto* expect = "error: entry point 'foo' not found";
SingleEntryPoint::Config cfg("foo");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, SingleEntryPoint) {
auto* src = R"(
[[stage(compute)]]
fn main() {
}
)";
SingleEntryPoint::Config cfg("main");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(src, str(got));
}
TEST_F(SingleEntryPointTest, MultipleEntryPoints) {
auto* src = R"(
[[stage(vertex)]]
fn vert_main() {
}
[[stage(fragment)]]
fn frag_main() {
}
[[stage(compute)]]
fn comp_main1() {
}
[[stage(compute)]]
fn comp_main2() {
}
)";
auto* expect = R"(
[[stage(compute)]]
fn comp_main1() {
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalVariables) {
auto* src = R"(
var<private> a : f32;
var<private> b : f32;
var<private> c : f32;
var<private> d : f32;
[[stage(vertex)]]
fn vert_main() {
a = 0.0;
}
[[stage(fragment)]]
fn frag_main() {
b = 0.0;
}
[[stage(compute)]]
fn comp_main1() {
c = 0.0;
}
[[stage(compute)]]
fn comp_main2() {
d = 0.0;
}
)";
auto* expect = R"(
var<private> c : f32;
[[stage(compute)]]
fn comp_main1() {
c = 0.0;
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalConstants) {
auto* src = R"(
let a : f32 = 1.0;
let b : f32 = 1.0;
let c : f32 = 1.0;
let d : f32 = 1.0;
[[stage(vertex)]]
fn vert_main() {
let local_a : f32 = a;
}
[[stage(fragment)]]
fn frag_main() {
let local_b : f32 = b;
}
[[stage(compute)]]
fn comp_main1() {
let local_c : f32 = c;
}
[[stage(compute)]]
fn comp_main2() {
let local_d : f32 = d;
}
)";
auto* expect = R"(
let a : f32 = 1.0;
let b : f32 = 1.0;
let c : f32 = 1.0;
let d : f32 = 1.0;
[[stage(compute)]]
fn comp_main1() {
let local_c : f32 = c;
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, CalledFunctions) {
auto* src = R"(
fn inner1() {
}
fn inner2() {
}
fn inner_shared() {
}
fn outer1() {
inner1();
inner_shared();
}
fn outer2() {
inner2();
inner_shared();
}
[[stage(compute)]]
fn comp_main1() {
outer1();
}
[[stage(compute)]]
fn comp_main2() {
outer2();
}
)";
auto* expect = R"(
fn inner1() {
}
fn inner_shared() {
}
fn outer1() {
inner1();
inner_shared();
}
[[stage(compute)]]
fn comp_main1() {
outer1();
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalsReferencedByCalledFunctions) {
auto* src = R"(
var<private> inner1_var : f32;
var<private> inner2_var : f32;
var<private> inner_shared_var : f32;
var<private> outer1_var : f32;
var<private> outer2_var : f32;
fn inner1() {
inner1_var = 0.0;
}
fn inner2() {
inner2_var = 0.0;
}
fn inner_shared() {
inner_shared_var = 0.0;
}
fn outer1() {
inner1();
inner_shared();
outer1_var = 0.0;
}
fn outer2() {
inner2();
inner_shared();
outer2_var = 0.0;
}
[[stage(compute)]]
fn comp_main1() {
outer1();
}
[[stage(compute)]]
fn comp_main2() {
outer2();
}
)";
auto* expect = R"(
var<private> inner1_var : f32;
var<private> inner_shared_var : f32;
var<private> outer1_var : f32;
fn inner1() {
inner1_var = 0.0;
}
fn inner_shared() {
inner_shared_var = 0.0;
}
fn outer1() {
inner1();
inner_shared();
outer1_var = 0.0;
}
[[stage(compute)]]
fn comp_main1() {
outer1();
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@ -31,10 +31,6 @@ bool Generator::Generate() {
return ret; return ret;
} }
bool Generator::GenerateEntryPoint(ast::PipelineStage, const std::string&) {
return false;
}
std::string Generator::result() const { std::string Generator::result() const {
return out_.str(); return out_.str();
} }

View File

@ -39,13 +39,6 @@ class Generator : public Text {
/// @returns true on successful generation; false otherwise /// @returns true on successful generation; false otherwise
bool Generate() override; bool Generate() override;
/// Converts a single entry point
/// @param stage the pipeline stage
/// @param name the entry point name
/// @returns true on succes; false on failure
bool GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) override;
/// @returns the result data /// @returns the result data
std::string result() const override; std::string result() const override;

View File

@ -31,10 +31,6 @@ bool Generator::Generate() {
return ret; return ret;
} }
bool Generator::GenerateEntryPoint(ast::PipelineStage, const std::string&) {
return false;
}
std::string Generator::result() const { std::string Generator::result() const {
return impl_->result(); return impl_->result();
} }

View File

@ -39,13 +39,6 @@ class Generator : public Text {
/// @returns true on successful generation; false otherwise /// @returns true on successful generation; false otherwise
bool Generate() override; bool Generate() override;
/// Converts a single entry point
/// @param stage the pipeline stage
/// @param name the entry point name
/// @returns true on succes; false on failure
bool GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) override;
/// @returns the result data /// @returns the result data
std::string result() const override; std::string result() const override;

View File

@ -35,10 +35,6 @@ bool Generator::Generate() {
return true; return true;
} }
bool Generator::GenerateEntryPoint(ast::PipelineStage, const std::string&) {
return false;
}
} // namespace spirv } // namespace spirv
} // namespace writer } // namespace writer
} // namespace tint } // namespace tint

View File

@ -40,13 +40,6 @@ class Generator : public writer::Writer {
/// @returns true on successful generation; false otherwise /// @returns true on successful generation; false otherwise
bool Generate() override; bool Generate() override;
/// Converts a single entry point
/// @param stage the pipeline stage
/// @param name the entry point name
/// @returns true on succes; false on failure
bool GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) override;
/// @returns the result data /// @returns the result data
const std::vector<uint32_t>& result() const { return writer_->result(); } const std::vector<uint32_t>& result() const { return writer_->result(); }

View File

@ -31,15 +31,6 @@ bool Generator::Generate() {
return ret; return ret;
} }
bool Generator::GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) {
auto ret = impl_->GenerateEntryPoint(stage, name);
if (!ret) {
error_ = impl_->error();
}
return ret;
}
std::string Generator::result() const { std::string Generator::result() const {
return impl_->result(); return impl_->result();
} }

View File

@ -39,13 +39,6 @@ class Generator : public Text {
/// @returns true on successful generation; false otherwise /// @returns true on successful generation; false otherwise
bool Generate() override; bool Generate() override;
/// Converts a single entry point
/// @param stage the pipeline stage
/// @param name the entry point name
/// @returns true on succes; false on failure
bool GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) override;
/// @returns the result data /// @returns the result data
std::string result() const override; std::string result() const override;

View File

@ -104,17 +104,6 @@ bool GeneratorImpl::Generate(const ast::Function* entry) {
return true; return true;
} }
bool GeneratorImpl::GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) {
auto* func =
program_->AST().Functions().Find(program_->Symbols().Get(name), stage);
if (func == nullptr) {
diagnostics_.add_error("Unable to find requested entry point: " + name);
return false;
}
return Generate(func);
}
bool GeneratorImpl::EmitConstructedType(const sem::Type* ty) { bool GeneratorImpl::EmitConstructedType(const sem::Type* ty) {
make_indent(); make_indent();
if (auto* alias = ty->As<sem::Alias>()) { if (auto* alias = ty->As<sem::Alias>()) {

View File

@ -55,12 +55,6 @@ class GeneratorImpl : public TextGenerator {
/// @returns true on successful generation; false otherwise /// @returns true on successful generation; false otherwise
bool Generate(const ast::Function* entry); bool Generate(const ast::Function* entry);
/// Generates a single entry point
/// @param stage the pipeline stage
/// @param name the entry point name
/// @returns true on successful generation; false otherwise
bool GenerateEntryPoint(ast::PipelineStage stage, const std::string& name);
/// Handles generating a constructed type /// Handles generating a constructed type
/// @param ty the constructed to generate /// @param ty the constructed to generate
/// @returns true if the constructed was emitted /// @returns true if the constructed was emitted

View File

@ -1,149 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/call_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/variable_decl_statement.h"
#include "src/writer/wgsl/test_helper.h"
namespace tint {
namespace writer {
namespace wgsl {
namespace {
using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_UnusedFunction) {
Func("func_unused", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::DecorationList{});
Func("func_used", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::DecorationList{});
auto* call_func = Call("func_used");
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::CallStatement>(call_func),
},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(
gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main"))
<< gen.error();
EXPECT_EQ(gen.result(), R"( fn func_used() {
}
[[stage(compute)]]
fn main() {
func_used();
}
)");
}
TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_UnusedVariable) {
Global("global_unused", ty.f32(), ast::StorageClass::kPrivate);
Global("global_used", ty.f32(), ast::StorageClass::kPrivate);
Func("main", {}, ty.void_(),
{
Assign("global_used", 1.f),
},
{
Stage(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(
gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main"))
<< gen.error();
EXPECT_EQ(gen.result(), R"( var<private> global_used : f32;
[[stage(compute)]]
fn main() {
global_used = 1.0;
}
)");
}
TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_GlobalsInterleaved) {
Global("a0", ty.f32(), ast::StorageClass::kPrivate);
auto s0 = Structure("S0", {Member("a", ty.i32())});
Func("func", {}, ty.f32(),
{
Return("a0"),
});
Global("a1", ty.f32(), ast::StorageClass::kOutput);
auto s1 = Structure("S1", {Member("a", ty.i32())});
Func("main", {}, ty.void_(),
{
Decl(Var("s0", s0, ast::StorageClass::kFunction)),
Decl(Var("s1", s1, ast::StorageClass::kFunction)),
Assign("a1", Call("func")),
},
{
Stage(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(
gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main"))
<< gen.error();
EXPECT_EQ(gen.result(), R"( var<private> a0 : f32;
struct S0 {
a : i32;
};
fn func() -> f32 {
return a0;
}
var<out> a1 : f32;
struct S1 {
a : i32;
};
[[stage(compute)]]
fn main() {
var s0 : S0;
var s1 : S1;
a1 = func();
}
)");
}
} // namespace
} // namespace wgsl
} // namespace writer
} // namespace tint

View File

@ -34,13 +34,6 @@ class Writer {
/// @returns true on success; false on failure /// @returns true on success; false on failure
virtual bool Generate() = 0; virtual bool Generate() = 0;
/// Converts a single entry point
/// @param stage the pipeline stage
/// @param name the entry point name
/// @returns true on succes; false on failure
virtual bool GenerateEntryPoint(ast::PipelineStage stage,
const std::string& name) = 0;
protected: protected:
/// Sets the error string /// Sets the error string
/// @param msg the error message /// @param msg the error message

View File

@ -296,6 +296,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/transform/emit_vertex_point_size_test.cc", "../src/transform/emit_vertex_point_size_test.cc",
"../src/transform/first_index_offset_test.cc", "../src/transform/first_index_offset_test.cc",
"../src/transform/renamer_test.cc", "../src/transform/renamer_test.cc",
"../src/transform/single_entry_point_test.cc",
"../src/transform/vertex_pulling_test.cc", "../src/transform/vertex_pulling_test.cc",
"../src/utils/command_test.cc", "../src/utils/command_test.cc",
"../src/utils/get_or_create_test.cc", "../src/utils/get_or_create_test.cc",
@ -491,7 +492,6 @@ tint_unittests_source_set("tint_unittests_wgsl_writer_src") {
"../src/writer/wgsl/generator_impl_constructor_test.cc", "../src/writer/wgsl/generator_impl_constructor_test.cc",
"../src/writer/wgsl/generator_impl_continue_test.cc", "../src/writer/wgsl/generator_impl_continue_test.cc",
"../src/writer/wgsl/generator_impl_discard_test.cc", "../src/writer/wgsl/generator_impl_discard_test.cc",
"../src/writer/wgsl/generator_impl_entry_point_test.cc",
"../src/writer/wgsl/generator_impl_fallthrough_test.cc", "../src/writer/wgsl/generator_impl_fallthrough_test.cc",
"../src/writer/wgsl/generator_impl_function_test.cc", "../src/writer/wgsl/generator_impl_function_test.cc",
"../src/writer/wgsl/generator_impl_global_decl_test.cc", "../src/writer/wgsl/generator_impl_global_decl_test.cc",