diff --git a/BUILD.gn b/BUILD.gn index 98337098da..a37e2a0629 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -369,6 +369,8 @@ source_set("libtint_core_src") { "src/ast/variable_decl_statement.h", "src/ast/variable_decoration.cc", "src/ast/variable_decoration.h", + "src/ast/workgroup_decoration.cc", + "src/ast/workgroup_decoration.h", "src/context.cc", "src/context.h", "src/reader/reader.cc", @@ -744,6 +746,7 @@ source_set("tint_unittests_core_src") { "src/ast/unary_op_expression_test.cc", "src/ast/variable_decl_statement_test.cc", "src/ast/variable_test.cc", + "src/ast/workgroup_decoration_test.cc", "src/scope_stack_test.cc", "src/type_determiner_test.cc", "src/type_manager_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f821429b03..04fbea405d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -190,6 +190,8 @@ set(TINT_LIB_SRCS ast/variable_decoration.h ast/variable_decl_statement.cc ast/variable_decl_statement.h + ast/workgroup_decoration.cc + ast/workgroup_decoration.h context.h context.cc reader/reader.cc @@ -354,6 +356,7 @@ set(TINT_TEST_SRCS ast/unary_op_expression_test.cc ast/variable_decl_statement_test.cc ast/variable_test.cc + ast/workgroup_decoration_test.cc scope_stack_test.cc type_determiner_test.cc type_manager_test.cc diff --git a/src/ast/function.cc b/src/ast/function.cc index 0e52bf9489..8e8d853ccf 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -17,6 +17,7 @@ #include #include "src/ast/decorated_variable.h" +#include "src/ast/workgroup_decoration.h" namespace tint { namespace ast { @@ -46,6 +47,15 @@ Function::Function(Function&&) = default; Function::~Function() = default; +std::tuple Function::workgroup_size() const { + for (const auto& deco : decorations_) { + if (deco->IsWorkgroup()) { + return deco->AsWorkgroup()->values(); + } + } + return {1, 1, 1}; +} + void Function::add_referenced_module_variable(Variable* var) { for (const auto* v : referenced_module_vars_) { if (v->name() == var->name()) { diff --git a/src/ast/function.h b/src/ast/function.h index b292464fb2..d64c0236b8 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -90,6 +90,10 @@ class Function : public Node { /// @returns the decorations attached to this function const FunctionDecorationList& decorations() const { return decorations_; } + /// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be + /// return if no workgroup size was set. + std::tuple workgroup_size() const; + /// Adds the given variable to the list of referenced module variables if it /// is not already included. /// @param var the module variable to add diff --git a/src/ast/function_decoration.cc b/src/ast/function_decoration.cc index a44729735b..c246b97786 100644 --- a/src/ast/function_decoration.cc +++ b/src/ast/function_decoration.cc @@ -14,6 +14,10 @@ #include "src/ast/function_decoration.h" +#include + +#include "src/ast/workgroup_decoration.h" + namespace tint { namespace ast { @@ -21,5 +25,14 @@ FunctionDecoration::FunctionDecoration() = default; FunctionDecoration::~FunctionDecoration() = default; +bool FunctionDecoration::IsWorkgroup() const { + return false; +} + +const WorkgroupDecoration* FunctionDecoration::AsWorkgroup() const { + assert(IsWorkgroup()); + return static_cast(this); +} + } // namespace ast } // namespace tint diff --git a/src/ast/function_decoration.h b/src/ast/function_decoration.h index 4b75f305cd..461a037ad2 100644 --- a/src/ast/function_decoration.h +++ b/src/ast/function_decoration.h @@ -22,11 +22,19 @@ namespace tint { namespace ast { +class WorkgroupDecoration; + /// A decoration attached to a function class FunctionDecoration { public: virtual ~FunctionDecoration(); + /// @returns true if this is a workgroup decoration + virtual bool IsWorkgroup() const; + + /// @returns the decoration as a workgroup decoration + const WorkgroupDecoration* AsWorkgroup() const; + /// Outputs the function decoration to the given stream /// @param out the stream to output too virtual void to_str(std::ostream& out) const = 0; diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index a75a619bcb..20d37edb94 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -24,7 +24,7 @@ #include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" -// #include "src/ast/workgroup_decoration.h" +#include "src/ast/workgroup_decoration.h" namespace tint { namespace ast { @@ -298,27 +298,27 @@ TEST_F(FunctionTest, ToStr) { )"); } -// TEST_F(FunctionTest, ToStr_WithDecoration) { -// type::VoidType void_type; -// type::I32Type i32; +TEST_F(FunctionTest, ToStr_WithDecoration) { + type::VoidType void_type; + type::I32Type i32; -// auto block = std::make_unique(); -// block->append(std::make_unique()); + auto block = std::make_unique(); + block->append(std::make_unique()); -// Function f("func", {}, &void_type); -// f.set_body(std::move(block)); -// f.add_decoration(std::make_unique(2, 4, 6)); + Function f("func", {}, &void_type); + f.set_body(std::move(block)); + f.add_decoration(std::make_unique(2, 4, 6)); -// std::ostringstream out; -// f.to_str(out, 2); -// EXPECT_EQ(out.str(), R"( Function func -> __void -// workgroup_size 2 4 6 -// () -// { -// Discard{} -// } -// )"); -// } + std::ostringstream out; + f.to_str(out, 2); + EXPECT_EQ(out.str(), R"( Function func -> __void + WorkgroupDecoration{2 4 6} + () + { + Discard{} + } +)"); +} TEST_F(FunctionTest, ToStr_WithParams) { type::VoidType void_type; @@ -396,6 +396,33 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) { EXPECT_EQ(f.get_last_statement(), nullptr); } + +TEST_F(FunctionTest, WorkgroupSize_NoneSet) { + type::VoidType void_type; + Function f("f", {}, &void_type); + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = f.workgroup_size(); + EXPECT_EQ(x, 1u); + EXPECT_EQ(y, 1u); + EXPECT_EQ(z, 1u); +} + +TEST_F(FunctionTest, WorkgroupSize) { + type::VoidType void_type; + Function f("f", {}, &void_type); + f.add_decoration(std::make_unique(2u, 4u, 6u)); + + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = f.workgroup_size(); + EXPECT_EQ(x, 2u); + EXPECT_EQ(y, 4u); + EXPECT_EQ(z, 6u); +} + } // namespace } // namespace ast } // namespace tint diff --git a/src/ast/workgroup_decoration.cc b/src/ast/workgroup_decoration.cc new file mode 100644 index 0000000000..fd44db4cc5 --- /dev/null +++ b/src/ast/workgroup_decoration.cc @@ -0,0 +1,40 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/workgroup_decoration.h" + +namespace tint { +namespace ast { + +WorkgroupDecoration::WorkgroupDecoration(uint32_t x) : x_(x) {} + +WorkgroupDecoration::WorkgroupDecoration(uint32_t x, uint32_t y) + : x_(x), y_(y) {} + +WorkgroupDecoration::WorkgroupDecoration(uint32_t x, uint32_t y, uint32_t z) + : x_(x), y_(y), z_(z) {} + +WorkgroupDecoration::~WorkgroupDecoration() = default; + +bool WorkgroupDecoration::IsWorkgroup() const { + return true; +} + +void WorkgroupDecoration::to_str(std::ostream& out) const { + out << "WorkgroupDecoration{" << x_ << " " << y_ << " " << z_ << "}" + << std::endl; +} + +} // namespace ast +} // namespace tint diff --git a/src/ast/workgroup_decoration.h b/src/ast/workgroup_decoration.h new file mode 100644 index 0000000000..04678d65c0 --- /dev/null +++ b/src/ast/workgroup_decoration.h @@ -0,0 +1,65 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_AST_WORKGROUP_DECORATION_H_ +#define SRC_AST_WORKGROUP_DECORATION_H_ + +#include + +#include + +#include "src/ast/function_decoration.h" + +namespace tint { +namespace ast { + +/// A workgroup decoration +class WorkgroupDecoration : public FunctionDecoration { + public: + /// constructor + /// @param x the workgroup x dimension size + explicit WorkgroupDecoration(uint32_t x); + /// constructor + /// @param x the workgroup x dimension size + /// @param y the workgroup x dimension size + WorkgroupDecoration(uint32_t x, uint32_t y); + /// constructor + /// @param x the workgroup x dimension size + /// @param y the workgroup x dimension size + /// @param z the workgroup x dimension size + WorkgroupDecoration(uint32_t x, uint32_t y, uint32_t z); + ~WorkgroupDecoration() override; + + /// @returns true if this is a workgroup decoration + bool IsWorkgroup() const override; + + /// @returns the workgroup dimensions + std::tuple values() const { + return {x_, y_, z_}; + } + + /// Outputs the decoration to the given stream + /// @param out the stream to output too + void to_str(std::ostream& out) const override; + + private: + uint32_t x_ = 1; + uint32_t y_ = 1; + uint32_t z_ = 1; +}; + +} // namespace ast +} // namespace tint + +#endif // SRC_AST_WORKGROUP_DECORATION_H_ diff --git a/src/ast/workgroup_decoration_test.cc b/src/ast/workgroup_decoration_test.cc new file mode 100644 index 0000000000..750d351925 --- /dev/null +++ b/src/ast/workgroup_decoration_test.cc @@ -0,0 +1,74 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/workgroup_decoration.h" + +#include + +#include "gtest/gtest.h" + +namespace tint { +namespace ast { +namespace { + +using WorkgroupDecorationTest = testing::Test; + +TEST_F(WorkgroupDecorationTest, Creation_1param) { + WorkgroupDecoration d{2}; + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = d.values(); + EXPECT_EQ(x, 2u); + EXPECT_EQ(y, 1u); + EXPECT_EQ(z, 1u); +} +TEST_F(WorkgroupDecorationTest, Creation_2param) { + WorkgroupDecoration d{2, 4}; + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = d.values(); + EXPECT_EQ(x, 2u); + EXPECT_EQ(y, 4u); + EXPECT_EQ(z, 1u); +} + +TEST_F(WorkgroupDecorationTest, Creation_3param) { + WorkgroupDecoration d{2, 4, 6}; + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = d.values(); + EXPECT_EQ(x, 2u); + EXPECT_EQ(y, 4u); + EXPECT_EQ(z, 6u); +} + +TEST_F(WorkgroupDecorationTest, Is) { + WorkgroupDecoration d{2, 4, 6}; + EXPECT_TRUE(d.IsWorkgroup()); +} + +TEST_F(WorkgroupDecorationTest, ToStr) { + WorkgroupDecoration d{2, 4, 6}; + std::ostringstream out; + d.to_str(out); + EXPECT_EQ(out.str(), R"(WorkgroupDecoration{2 4 6} +)"); +} + +} // namespace +} // namespace ast +} // namespace tint diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 43270f93dc..d34e572a20 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1428,9 +1428,12 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, } if (ep->stage() == ast::PipelineStage::kCompute) { - // TODO(dsinclair): When we have a way to set the thread group size this - // should be updated. - out << "[numthreads(1, 1, 1)]" << std::endl; + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = func->workgroup_size(); + out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y) + << ", " << std::to_string(z) << ")]" << std::endl; make_indent(out); } diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 80f5b82c76..2056427ea0 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -39,6 +39,7 @@ #include "src/ast/type/void_type.h" #include "src/ast/variable.h" #include "src/ast/variable_decl_statement.h" +#include "src/ast/workgroup_decoration.h" #include "src/context.h" #include "src/type_determiner.h" #include "src/writer/hlsl/test_helper.h" @@ -1223,6 +1224,35 @@ void main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_Function_EntryPoint_Compute_WithWorkgroup) { + ast::type::VoidType void_type; + + ast::VariableList params; + auto func = std::make_unique("comp_main", std::move(params), + &void_type); + func->add_decoration(std::make_unique(2u, 4u, 6u)); + + auto body = std::make_unique(); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kCompute, + "main", "comp_main"); + mod()->AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"([numthreads(2, 4, 6)] +void main() { + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { ast::type::F32Type f32; ast::type::ArrayType ary(&f32, 5); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 9754fb2c77..879ded399d 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -359,10 +359,15 @@ bool Builder::GenerateExecutionModes(ast::EntryPoint* ep) { spv::Op::OpExecutionMode, {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)}); } else if (ep->stage() == ast::PipelineStage::kCompute) { - // TODO(dsinclair): Support LocalSize other then (1, 1, 1) + auto* func = func_name_to_func_[ep->function_name()]; + + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = func->workgroup_size(); push_preamble(spv::Op::OpExecutionMode, {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize), - Operand::Int(1), Operand::Int(1), Operand::Int(1)}); + Operand::Int(x), Operand::Int(y), Operand::Int(z)}); } return true; diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index 04176c1654..85af7cfbb2 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -25,6 +25,7 @@ #include "src/ast/type/f32_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/ast/workgroup_decoration.h" #include "src/context.h" #include "src/type_determiner.h" #include "src/writer/spirv/builder.h" @@ -264,6 +265,23 @@ TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize) { )"); } +TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize_WithWorkgroup) { + ast::type::VoidType void_type; + + ast::Function func("main", {}, &void_type); + func.add_decoration(std::make_unique(2u, 4u, 6u)); + ast::EntryPoint ep(ast::PipelineStage::kCompute, "main", "main"); + + ast::Module mod; + Builder b(&mod); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateExecutionModes(&ep)); + + EXPECT_EQ(DumpInstructions(b.preamble()), + R"(OpExecutionMode %3 LocalSize 2 4 6 +)"); +} + } // namespace } // namespace spirv } // namespace writer diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 11d36904e7..08afc35e10 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -63,6 +63,7 @@ #include "src/ast/uint_literal.h" #include "src/ast/unary_op_expression.h" #include "src/ast/variable_decl_statement.h" +#include "src/ast/workgroup_decoration.h" namespace tint { namespace writer { @@ -422,8 +423,21 @@ bool GeneratorImpl::EmitImport(const ast::Import* import) { } bool GeneratorImpl::EmitFunction(ast::Function* func) { - make_indent(); + for (auto& deco : func->decorations()) { + make_indent(); + out_ << "[["; + if (deco->IsWorkgroup()) { + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + std::tie(x, y, z) = deco->AsWorkgroup()->values(); + out_ << "workgroup_size(" << std::to_string(x) << ", " + << std::to_string(y) << ", " << std::to_string(z) << ")"; + } + out_ << "]]" << std::endl; + } + make_indent(); out_ << "fn " << func->name() << "("; bool first = true; diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index 0240a6277a..25bac24ed0 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -20,6 +20,7 @@ #include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/ast/workgroup_decoration.h" #include "src/writer/wgsl/generator_impl.h" namespace tint { @@ -77,6 +78,28 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { )"); } +TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecorations) { + auto body = std::make_unique(); + body->append(std::make_unique()); + body->append(std::make_unique()); + + ast::type::VoidType void_type; + ast::Function func("my_func", {}, &void_type); + func.add_decoration(std::make_unique(2u, 4u, 6u)); + func.set_body(std::move(body)); + + GeneratorImpl g; + g.increment_indent(); + + ASSERT_TRUE(g.EmitFunction(&func)); + EXPECT_EQ(g.result(), R"( [[workgroup_size(2, 4, 6)]] + fn my_func() -> void { + discard; + return; + } +)"); +} + } // namespace } // namespace wgsl } // namespace writer