From ce8f86881503c256a68dfe052d581cf3243f34ff Mon Sep 17 00:00:00 2001 From: James Price Date: Wed, 19 May 2021 08:15:18 +0000 Subject: [PATCH] Move workgroup_size property into sem::Function The workgroup size should not be a property of the function in the AST, and this lays the groundwork for allowing both literals and module-scope constants to be used for this attribute. Bug: tint:713 Change-Id: I014be879e2adb81cfc5b0ea0e221035fae626223 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51261 Auto-Submit: James Price Commit-Queue: Ben Clayton Kokoro: Kokoro Reviewed-by: Ben Clayton --- src/ast/function.cc | 7 ------- src/ast/function.h | 4 ---- src/ast/function_test.cc | 25 ---------------------- src/inspector/inspector.cc | 12 +++++++++-- src/resolver/resolver.cc | 16 +++++++++++++- src/resolver/resolver.h | 2 ++ src/resolver/resolver_test.cc | 35 +++++++++++++++++++++++++++++++ src/sem/function.cc | 6 ++++-- src/sem/function.h | 21 ++++++++++++++++++- src/writer/hlsl/generator_impl.cc | 15 +++++++++---- src/writer/spirv/builder.cc | 17 ++++++++++----- 11 files changed, 109 insertions(+), 51 deletions(-) diff --git a/src/ast/function.cc b/src/ast/function.cc index 24b0227210..0896e7ac4f 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -58,13 +58,6 @@ Function::Function(Function&&) = default; Function::~Function() = default; -std::tuple Function::workgroup_size() const { - if (auto* workgroup = GetDecoration(decorations_)) { - return workgroup->values(); - } - return {1, 1, 1}; -} - PipelineStage Function::pipeline_stage() const { if (auto* stage = GetDecoration(decorations_)) { return stage->value(); diff --git a/src/ast/function.h b/src/ast/function.h index fd91c7c29b..397e1c4389 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -66,10 +66,6 @@ class Function : public Castable { /// @returns the decorations attached to this function const DecorationList& decorations() const { return decorations_; } - /// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be - /// return if no workgroup size was set. - std::tuple workgroup_size() const; - /// @returns the functions pipeline stage or None if not set PipelineStage pipeline_stage() const; diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index 60c1525edc..44ba686b7e 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -225,31 +225,6 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) { EXPECT_EQ(f->get_last_statement(), nullptr); } -TEST_F(FunctionTest, WorkgroupSize_NoneSet) { - auto* f = Func("func", VariableList{}, ty.void_(), StatementList{}, - DecorationList{}); - uint32_t x = 0; - uint32_t y = 0; - uint32_t z = 0; - std::tie(x, y, z) = f->workgroup_size(); - EXPECT_EQ(x, 1u); - EXPECT_EQ(y, 1u); - EXPECT_EQ(z, 1u); -} - -TEST_F(FunctionTest, WorkgroupSize) { - auto* f = Func("func", VariableList{}, ty.void_(), StatementList{}, - DecorationList{create(2u, 4u, 6u)}); - - uint32_t x = 0; - uint32_t y = 0; - uint32_t z = 0; - std::tie(x, y, z) = f->workgroup_size(); - EXPECT_EQ(x, 2u); - EXPECT_EQ(y, 4u); - EXPECT_EQ(z, 6u); -} - using FunctionListTest = TestHelper; TEST_F(FunctionListTest, FindSymbol) { diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index b8ac9b454e..b687b22260 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -198,8 +198,16 @@ std::vector Inspector::GetEntryPoints() { entry_point.name = program_->Symbols().NameFor(func->symbol()); entry_point.remapped_name = program_->Symbols().NameFor(func->symbol()); entry_point.stage = func->pipeline_stage(); - std::tie(entry_point.workgroup_size_x, entry_point.workgroup_size_y, - entry_point.workgroup_size_z) = func->workgroup_size(); + + auto wgsize = sem->workgroup_size(); + entry_point.workgroup_size_x = wgsize[0].value; + entry_point.workgroup_size_y = wgsize[1].value; + entry_point.workgroup_size_z = wgsize[2].value; + if (wgsize[0].overridable_const || wgsize[1].overridable_const || + wgsize[2].overridable_const) { + // TODO(crbug.com/tint/713): Handle overridable constants. + TINT_ASSERT(false); + } for (auto* param : sem->Parameters()) { AddEntryPointInOutVariables( diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index a80472b180..93266aa193 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1287,6 +1287,20 @@ bool Resolver::Function(ast::Function* func) { Mark(deco); } + // Set work-group size defaults. + for (int i = 0; i < 3; i++) { + info->workgroup_size[i].value = 1; + info->workgroup_size[i].overridable_const = nullptr; + } + + if (auto* workgroup = + ast::GetDecoration(func->decorations())) { + // TODO(crbug.com/tint/713): Handle non-literals. + info->workgroup_size[0].value = std::get<0>(workgroup->values()); + info->workgroup_size[1].value = std::get<1>(workgroup->values()); + info->workgroup_size[2].value = std::get<2>(workgroup->values()); + } + if (!ValidateFunction(func, info)) { return false; } @@ -2517,7 +2531,7 @@ void Resolver::CreateSemanticNodes() const { info->declaration, const_cast(info->return_type), remap_vars(info->parameters), remap_vars(info->referenced_module_vars), remap_vars(info->local_referenced_module_vars), info->return_statements, - ancestor_entry_points[func->symbol()]); + ancestor_entry_points[func->symbol()], info->workgroup_size); func_info_to_sem_func.emplace(info, sem_func); sem.Add(func, sem_func); } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 76a722f281..ad428a8ef5 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -26,6 +26,7 @@ #include "src/scope_stack.h" #include "src/sem/binding_point.h" #include "src/sem/block_statement.h" +#include "src/sem/function.h" #include "src/sem/struct.h" #include "src/utils/unique_vector.h" @@ -112,6 +113,7 @@ class Resolver { std::vector return_statements; sem::Type* return_type = nullptr; std::string return_type_name; + std::array workgroup_size; // List of transitive calls this function makes UniqueVector transitive_calls; diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index e81be395d8..109b31a279 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -32,6 +32,7 @@ #include "src/ast/switch_statement.h" #include "src/ast/unary_op_expression.h" #include "src/ast/variable_decl_statement.h" +#include "src/ast/workgroup_decoration.h" #include "src/resolver/resolver_test_helper.h" #include "src/sem/call.h" #include "src/sem/function.h" @@ -887,6 +888,40 @@ TEST_F(ResolverTest, Function_ReturnStatements) { EXPECT_TRUE(func_sem->ReturnType()->Is()); } +TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) { + auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {}); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->workgroup_size()[0].value, 1u); + EXPECT_EQ(func_sem->workgroup_size()[1].value, 1u); + EXPECT_EQ(func_sem->workgroup_size()[2].value, 1u); + EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr); + EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr); + EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr); +} + +TEST_F(ResolverTest, Function_WorkgroupSize_Literals) { + auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), + create(8, 2, 3)}); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u); + EXPECT_EQ(func_sem->workgroup_size()[1].value, 2u); + EXPECT_EQ(func_sem->workgroup_size()[2].value, 3u); + EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr); + EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr); + EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr); +} + TEST_F(ResolverTest, Expr_MemberAccessor_Struct) { auto* st = Structure("S", {Member("first_member", ty.i32()), Member("second_member", ty.f32())}); diff --git a/src/sem/function.cc b/src/sem/function.cc index 0c5e8443ca..f33fc07258 100644 --- a/src/sem/function.cc +++ b/src/sem/function.cc @@ -46,14 +46,16 @@ Function::Function(ast::Function* declaration, std::vector referenced_module_vars, std::vector local_referenced_module_vars, std::vector return_statements, - std::vector ancestor_entry_points) + std::vector ancestor_entry_points, + std::array workgroup_size) : Base(return_type, GetParameters(parameters)), declaration_(declaration), parameters_(std::move(parameters)), referenced_module_vars_(std::move(referenced_module_vars)), local_referenced_module_vars_(std::move(local_referenced_module_vars)), return_statements_(std::move(return_statements)), - ancestor_entry_points_(std::move(ancestor_entry_points)) {} + ancestor_entry_points_(std::move(ancestor_entry_points)), + workgroup_size_(std::move(workgroup_size)) {} Function::~Function() = default; diff --git a/src/sem/function.h b/src/sem/function.h index a29e318251..94cc02e0ed 100644 --- a/src/sem/function.h +++ b/src/sem/function.h @@ -15,6 +15,7 @@ #ifndef SRC_SEM_FUNCTION_H_ #define SRC_SEM_FUNCTION_H_ +#include #include #include @@ -37,6 +38,16 @@ namespace sem { class Variable; +/// WorkgroupDimension describes the size of a single dimension of an entry +/// point's workgroup size. +struct WorkgroupDimension { + /// The size of this dimension. + uint32_t value; + /// A pipeline-overridable constant that overrides the size, or nullptr if + /// this dimension is not overridable. + const ast::Variable* overridable_const = nullptr; +}; + /// Function holds the semantic information for function nodes. class Function : public Castable { public: @@ -53,13 +64,15 @@ class Function : public Castable { /// @param return_statements the function return statements /// variables /// @param ancestor_entry_points the ancestor entry points + /// @param workgroup_size the workgroup size Function(ast::Function* declaration, Type* return_type, std::vector parameters, std::vector referenced_module_vars, std::vector local_referenced_module_vars, std::vector return_statements, - std::vector ancestor_entry_points); + std::vector ancestor_entry_points, + std::array workgroup_size); /// Destructor ~Function() override; @@ -148,6 +161,11 @@ class Function : public Castable { /// @returns true if `sym` is an ancestor entry point of this function bool HasAncestorEntryPoint(Symbol sym) const; + /// @returns the workgroup size {x, y, z} for the function. + const std::array& workgroup_size() const { + return workgroup_size_; + } + private: VariableBindings ReferencedSamplerVariablesImpl(ast::SamplerKind kind) const; VariableBindings ReferencedSampledTextureVariablesImpl( @@ -159,6 +177,7 @@ class Function : public Castable { std::vector const local_referenced_module_vars_; std::vector const return_statements_; std::vector const ancestor_entry_points_; + std::array workgroup_size_; }; } // namespace sem diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index e894106798..b5165338ab 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1989,12 +1989,19 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, make_indent(out); current_ep_sym_ = func->symbol(); + auto* func_sem = builder_.Sem().Get(func); if (func->pipeline_stage() == ast::PipelineStage::kCompute) { - uint32_t x = 0; - uint32_t y = 0; - uint32_t z = 0; - std::tie(x, y, z) = func->workgroup_size(); + auto wgsize = func_sem->workgroup_size(); + if (wgsize[0].overridable_const || wgsize[1].overridable_const || + wgsize[2].overridable_const) { + // TODO(crbug.com/tint/713): Handle overridable constants. + TINT_UNIMPLEMENTED(builder_.Diagnostics()) + << "pipeline-overridable workgroup sizes are not implemented"; + } + uint32_t x = wgsize[0].value; + uint32_t y = wgsize[1].value; + uint32_t z = wgsize[2].value; out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y) << ", " << std::to_string(z) << ")]" << std::endl; make_indent(out); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 22f4a6a71a..60c159c52f 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -435,23 +435,30 @@ bool Builder::GenerateEntryPoint(ast::Function* func, uint32_t id) { } bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) { + auto* func_sem = builder_.Sem().Get(func); + // WGSL fragment shader origin is upper left if (func->pipeline_stage() == ast::PipelineStage::kFragment) { push_execution_mode( spv::Op::OpExecutionMode, {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)}); } else if (func->pipeline_stage() == ast::PipelineStage::kCompute) { - uint32_t x = 0; - uint32_t y = 0; - uint32_t z = 0; - std::tie(x, y, z) = func->workgroup_size(); + auto& wgsize = func_sem->workgroup_size(); + if (wgsize[0].overridable_const || wgsize[1].overridable_const || + wgsize[2].overridable_const) { + // TODO(crbug.com/tint/713): Handle overridable constants. + TINT_UNIMPLEMENTED(builder_.Diagnostics()) + << "pipeline-overridable workgroup sizes are not implemented"; + } + uint32_t x = wgsize[0].value; + uint32_t y = wgsize[1].value; + uint32_t z = wgsize[2].value; push_execution_mode( spv::Op::OpExecutionMode, {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize), Operand::Int(x), Operand::Int(y), Operand::Int(z)}); } - auto* func_sem = builder_.Sem().Get(func); for (auto builtin : func_sem->ReferencedBuiltinVariables()) { if (builtin.second->value() == ast::Builtin::kFragDepth) { push_execution_mode(