mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-07-04 12:16:10 +00:00
Move workgroup_size property into sem::Function
The workgroup size should not be a property of the function in the AST, and this lays the groundwork for allowing both literals and module-scope constants to be used for this attribute. Bug: tint:713 Change-Id: I014be879e2adb81cfc5b0ea0e221035fae626223 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51261 Auto-Submit: James Price <jrprice@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
594075a2f0
commit
ce8f868815
@ -58,13 +58,6 @@ Function::Function(Function&&) = default;
|
|||||||
|
|
||||||
Function::~Function() = default;
|
Function::~Function() = default;
|
||||||
|
|
||||||
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
|
|
||||||
if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
|
|
||||||
return workgroup->values();
|
|
||||||
}
|
|
||||||
return {1, 1, 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
PipelineStage Function::pipeline_stage() const {
|
PipelineStage Function::pipeline_stage() const {
|
||||||
if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
|
if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
|
||||||
return stage->value();
|
return stage->value();
|
||||||
|
@ -66,10 +66,6 @@ class Function : public Castable<Function, Node> {
|
|||||||
/// @returns the decorations attached to this function
|
/// @returns the decorations attached to this function
|
||||||
const DecorationList& decorations() const { return decorations_; }
|
const DecorationList& decorations() const { return decorations_; }
|
||||||
|
|
||||||
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
|
|
||||||
/// return if no workgroup size was set.
|
|
||||||
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
|
|
||||||
|
|
||||||
/// @returns the functions pipeline stage or None if not set
|
/// @returns the functions pipeline stage or None if not set
|
||||||
PipelineStage pipeline_stage() const;
|
PipelineStage pipeline_stage() const;
|
||||||
|
|
||||||
|
@ -225,31 +225,6 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) {
|
|||||||
EXPECT_EQ(f->get_last_statement(), nullptr);
|
EXPECT_EQ(f->get_last_statement(), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
|
|
||||||
auto* f = Func("func", VariableList{}, ty.void_(), StatementList{},
|
|
||||||
DecorationList{});
|
|
||||||
uint32_t x = 0;
|
|
||||||
uint32_t y = 0;
|
|
||||||
uint32_t z = 0;
|
|
||||||
std::tie(x, y, z) = f->workgroup_size();
|
|
||||||
EXPECT_EQ(x, 1u);
|
|
||||||
EXPECT_EQ(y, 1u);
|
|
||||||
EXPECT_EQ(z, 1u);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(FunctionTest, WorkgroupSize) {
|
|
||||||
auto* f = Func("func", VariableList{}, ty.void_(), StatementList{},
|
|
||||||
DecorationList{create<WorkgroupDecoration>(2u, 4u, 6u)});
|
|
||||||
|
|
||||||
uint32_t x = 0;
|
|
||||||
uint32_t y = 0;
|
|
||||||
uint32_t z = 0;
|
|
||||||
std::tie(x, y, z) = f->workgroup_size();
|
|
||||||
EXPECT_EQ(x, 2u);
|
|
||||||
EXPECT_EQ(y, 4u);
|
|
||||||
EXPECT_EQ(z, 6u);
|
|
||||||
}
|
|
||||||
|
|
||||||
using FunctionListTest = TestHelper;
|
using FunctionListTest = TestHelper;
|
||||||
|
|
||||||
TEST_F(FunctionListTest, FindSymbol) {
|
TEST_F(FunctionListTest, FindSymbol) {
|
||||||
|
@ -198,8 +198,16 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||||||
entry_point.name = program_->Symbols().NameFor(func->symbol());
|
entry_point.name = program_->Symbols().NameFor(func->symbol());
|
||||||
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol());
|
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol());
|
||||||
entry_point.stage = func->pipeline_stage();
|
entry_point.stage = func->pipeline_stage();
|
||||||
std::tie(entry_point.workgroup_size_x, entry_point.workgroup_size_y,
|
|
||||||
entry_point.workgroup_size_z) = func->workgroup_size();
|
auto wgsize = sem->workgroup_size();
|
||||||
|
entry_point.workgroup_size_x = wgsize[0].value;
|
||||||
|
entry_point.workgroup_size_y = wgsize[1].value;
|
||||||
|
entry_point.workgroup_size_z = wgsize[2].value;
|
||||||
|
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
|
||||||
|
wgsize[2].overridable_const) {
|
||||||
|
// TODO(crbug.com/tint/713): Handle overridable constants.
|
||||||
|
TINT_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
for (auto* param : sem->Parameters()) {
|
for (auto* param : sem->Parameters()) {
|
||||||
AddEntryPointInOutVariables(
|
AddEntryPointInOutVariables(
|
||||||
|
@ -1287,6 +1287,20 @@ bool Resolver::Function(ast::Function* func) {
|
|||||||
Mark(deco);
|
Mark(deco);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set work-group size defaults.
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
info->workgroup_size[i].value = 1;
|
||||||
|
info->workgroup_size[i].overridable_const = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto* workgroup =
|
||||||
|
ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
|
||||||
|
// TODO(crbug.com/tint/713): Handle non-literals.
|
||||||
|
info->workgroup_size[0].value = std::get<0>(workgroup->values());
|
||||||
|
info->workgroup_size[1].value = std::get<1>(workgroup->values());
|
||||||
|
info->workgroup_size[2].value = std::get<2>(workgroup->values());
|
||||||
|
}
|
||||||
|
|
||||||
if (!ValidateFunction(func, info)) {
|
if (!ValidateFunction(func, info)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -2517,7 +2531,7 @@ void Resolver::CreateSemanticNodes() const {
|
|||||||
info->declaration, const_cast<sem::Type*>(info->return_type),
|
info->declaration, const_cast<sem::Type*>(info->return_type),
|
||||||
remap_vars(info->parameters), remap_vars(info->referenced_module_vars),
|
remap_vars(info->parameters), remap_vars(info->referenced_module_vars),
|
||||||
remap_vars(info->local_referenced_module_vars), info->return_statements,
|
remap_vars(info->local_referenced_module_vars), info->return_statements,
|
||||||
ancestor_entry_points[func->symbol()]);
|
ancestor_entry_points[func->symbol()], info->workgroup_size);
|
||||||
func_info_to_sem_func.emplace(info, sem_func);
|
func_info_to_sem_func.emplace(info, sem_func);
|
||||||
sem.Add(func, sem_func);
|
sem.Add(func, sem_func);
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@
|
|||||||
#include "src/scope_stack.h"
|
#include "src/scope_stack.h"
|
||||||
#include "src/sem/binding_point.h"
|
#include "src/sem/binding_point.h"
|
||||||
#include "src/sem/block_statement.h"
|
#include "src/sem/block_statement.h"
|
||||||
|
#include "src/sem/function.h"
|
||||||
#include "src/sem/struct.h"
|
#include "src/sem/struct.h"
|
||||||
#include "src/utils/unique_vector.h"
|
#include "src/utils/unique_vector.h"
|
||||||
|
|
||||||
@ -112,6 +113,7 @@ class Resolver {
|
|||||||
std::vector<const ast::ReturnStatement*> return_statements;
|
std::vector<const ast::ReturnStatement*> return_statements;
|
||||||
sem::Type* return_type = nullptr;
|
sem::Type* return_type = nullptr;
|
||||||
std::string return_type_name;
|
std::string return_type_name;
|
||||||
|
std::array<sem::WorkgroupDimension, 3> workgroup_size;
|
||||||
|
|
||||||
// List of transitive calls this function makes
|
// List of transitive calls this function makes
|
||||||
UniqueVector<FunctionInfo*> transitive_calls;
|
UniqueVector<FunctionInfo*> transitive_calls;
|
||||||
|
@ -32,6 +32,7 @@
|
|||||||
#include "src/ast/switch_statement.h"
|
#include "src/ast/switch_statement.h"
|
||||||
#include "src/ast/unary_op_expression.h"
|
#include "src/ast/unary_op_expression.h"
|
||||||
#include "src/ast/variable_decl_statement.h"
|
#include "src/ast/variable_decl_statement.h"
|
||||||
|
#include "src/ast/workgroup_decoration.h"
|
||||||
#include "src/resolver/resolver_test_helper.h"
|
#include "src/resolver/resolver_test_helper.h"
|
||||||
#include "src/sem/call.h"
|
#include "src/sem/call.h"
|
||||||
#include "src/sem/function.h"
|
#include "src/sem/function.h"
|
||||||
@ -887,6 +888,40 @@ TEST_F(ResolverTest, Function_ReturnStatements) {
|
|||||||
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::F32>());
|
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::F32>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
|
||||||
|
auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
|
auto* func_sem = Sem().Get(func);
|
||||||
|
ASSERT_NE(func_sem, nullptr);
|
||||||
|
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[0].value, 1u);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[1].value, 1u);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[2].value, 1u);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
|
||||||
|
auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
|
||||||
|
{Stage(ast::PipelineStage::kCompute),
|
||||||
|
create<ast::WorkgroupDecoration>(8, 2, 3)});
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
|
auto* func_sem = Sem().Get(func);
|
||||||
|
ASSERT_NE(func_sem, nullptr);
|
||||||
|
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[1].value, 2u);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[2].value, 3u);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
|
||||||
|
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
|
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
|
||||||
auto* st = Structure("S", {Member("first_member", ty.i32()),
|
auto* st = Structure("S", {Member("first_member", ty.i32()),
|
||||||
Member("second_member", ty.f32())});
|
Member("second_member", ty.f32())});
|
||||||
|
@ -46,14 +46,16 @@ Function::Function(ast::Function* declaration,
|
|||||||
std::vector<const Variable*> referenced_module_vars,
|
std::vector<const Variable*> referenced_module_vars,
|
||||||
std::vector<const Variable*> local_referenced_module_vars,
|
std::vector<const Variable*> local_referenced_module_vars,
|
||||||
std::vector<const ast::ReturnStatement*> return_statements,
|
std::vector<const ast::ReturnStatement*> return_statements,
|
||||||
std::vector<Symbol> ancestor_entry_points)
|
std::vector<Symbol> ancestor_entry_points,
|
||||||
|
std::array<WorkgroupDimension, 3> workgroup_size)
|
||||||
: Base(return_type, GetParameters(parameters)),
|
: Base(return_type, GetParameters(parameters)),
|
||||||
declaration_(declaration),
|
declaration_(declaration),
|
||||||
parameters_(std::move(parameters)),
|
parameters_(std::move(parameters)),
|
||||||
referenced_module_vars_(std::move(referenced_module_vars)),
|
referenced_module_vars_(std::move(referenced_module_vars)),
|
||||||
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
|
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
|
||||||
return_statements_(std::move(return_statements)),
|
return_statements_(std::move(return_statements)),
|
||||||
ancestor_entry_points_(std::move(ancestor_entry_points)) {}
|
ancestor_entry_points_(std::move(ancestor_entry_points)),
|
||||||
|
workgroup_size_(std::move(workgroup_size)) {}
|
||||||
|
|
||||||
Function::~Function() = default;
|
Function::~Function() = default;
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#ifndef SRC_SEM_FUNCTION_H_
|
#ifndef SRC_SEM_FUNCTION_H_
|
||||||
#define SRC_SEM_FUNCTION_H_
|
#define SRC_SEM_FUNCTION_H_
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -37,6 +38,16 @@ namespace sem {
|
|||||||
|
|
||||||
class Variable;
|
class Variable;
|
||||||
|
|
||||||
|
/// WorkgroupDimension describes the size of a single dimension of an entry
|
||||||
|
/// point's workgroup size.
|
||||||
|
struct WorkgroupDimension {
|
||||||
|
/// The size of this dimension.
|
||||||
|
uint32_t value;
|
||||||
|
/// A pipeline-overridable constant that overrides the size, or nullptr if
|
||||||
|
/// this dimension is not overridable.
|
||||||
|
const ast::Variable* overridable_const = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
/// Function holds the semantic information for function nodes.
|
/// Function holds the semantic information for function nodes.
|
||||||
class Function : public Castable<Function, CallTarget> {
|
class Function : public Castable<Function, CallTarget> {
|
||||||
public:
|
public:
|
||||||
@ -53,13 +64,15 @@ class Function : public Castable<Function, CallTarget> {
|
|||||||
/// @param return_statements the function return statements
|
/// @param return_statements the function return statements
|
||||||
/// variables
|
/// variables
|
||||||
/// @param ancestor_entry_points the ancestor entry points
|
/// @param ancestor_entry_points the ancestor entry points
|
||||||
|
/// @param workgroup_size the workgroup size
|
||||||
Function(ast::Function* declaration,
|
Function(ast::Function* declaration,
|
||||||
Type* return_type,
|
Type* return_type,
|
||||||
std::vector<const Variable*> parameters,
|
std::vector<const Variable*> parameters,
|
||||||
std::vector<const Variable*> referenced_module_vars,
|
std::vector<const Variable*> referenced_module_vars,
|
||||||
std::vector<const Variable*> local_referenced_module_vars,
|
std::vector<const Variable*> local_referenced_module_vars,
|
||||||
std::vector<const ast::ReturnStatement*> return_statements,
|
std::vector<const ast::ReturnStatement*> return_statements,
|
||||||
std::vector<Symbol> ancestor_entry_points);
|
std::vector<Symbol> ancestor_entry_points,
|
||||||
|
std::array<WorkgroupDimension, 3> workgroup_size);
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
~Function() override;
|
~Function() override;
|
||||||
@ -148,6 +161,11 @@ class Function : public Castable<Function, CallTarget> {
|
|||||||
/// @returns true if `sym` is an ancestor entry point of this function
|
/// @returns true if `sym` is an ancestor entry point of this function
|
||||||
bool HasAncestorEntryPoint(Symbol sym) const;
|
bool HasAncestorEntryPoint(Symbol sym) const;
|
||||||
|
|
||||||
|
/// @returns the workgroup size {x, y, z} for the function.
|
||||||
|
const std::array<WorkgroupDimension, 3>& workgroup_size() const {
|
||||||
|
return workgroup_size_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
VariableBindings ReferencedSamplerVariablesImpl(ast::SamplerKind kind) const;
|
VariableBindings ReferencedSamplerVariablesImpl(ast::SamplerKind kind) const;
|
||||||
VariableBindings ReferencedSampledTextureVariablesImpl(
|
VariableBindings ReferencedSampledTextureVariablesImpl(
|
||||||
@ -159,6 +177,7 @@ class Function : public Castable<Function, CallTarget> {
|
|||||||
std::vector<const Variable*> const local_referenced_module_vars_;
|
std::vector<const Variable*> const local_referenced_module_vars_;
|
||||||
std::vector<const ast::ReturnStatement*> const return_statements_;
|
std::vector<const ast::ReturnStatement*> const return_statements_;
|
||||||
std::vector<Symbol> const ancestor_entry_points_;
|
std::vector<Symbol> const ancestor_entry_points_;
|
||||||
|
std::array<WorkgroupDimension, 3> workgroup_size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sem
|
} // namespace sem
|
||||||
|
@ -1989,12 +1989,19 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
|
|||||||
make_indent(out);
|
make_indent(out);
|
||||||
|
|
||||||
current_ep_sym_ = func->symbol();
|
current_ep_sym_ = func->symbol();
|
||||||
|
auto* func_sem = builder_.Sem().Get(func);
|
||||||
|
|
||||||
if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
|
if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
|
||||||
uint32_t x = 0;
|
auto wgsize = func_sem->workgroup_size();
|
||||||
uint32_t y = 0;
|
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
|
||||||
uint32_t z = 0;
|
wgsize[2].overridable_const) {
|
||||||
std::tie(x, y, z) = func->workgroup_size();
|
// TODO(crbug.com/tint/713): Handle overridable constants.
|
||||||
|
TINT_UNIMPLEMENTED(builder_.Diagnostics())
|
||||||
|
<< "pipeline-overridable workgroup sizes are not implemented";
|
||||||
|
}
|
||||||
|
uint32_t x = wgsize[0].value;
|
||||||
|
uint32_t y = wgsize[1].value;
|
||||||
|
uint32_t z = wgsize[2].value;
|
||||||
out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y)
|
out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y)
|
||||||
<< ", " << std::to_string(z) << ")]" << std::endl;
|
<< ", " << std::to_string(z) << ")]" << std::endl;
|
||||||
make_indent(out);
|
make_indent(out);
|
||||||
|
@ -435,23 +435,30 @@ bool Builder::GenerateEntryPoint(ast::Function* func, uint32_t id) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) {
|
bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) {
|
||||||
|
auto* func_sem = builder_.Sem().Get(func);
|
||||||
|
|
||||||
// WGSL fragment shader origin is upper left
|
// WGSL fragment shader origin is upper left
|
||||||
if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
|
if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
|
||||||
push_execution_mode(
|
push_execution_mode(
|
||||||
spv::Op::OpExecutionMode,
|
spv::Op::OpExecutionMode,
|
||||||
{Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
|
{Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
|
||||||
} else if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
|
} else if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
|
||||||
uint32_t x = 0;
|
auto& wgsize = func_sem->workgroup_size();
|
||||||
uint32_t y = 0;
|
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
|
||||||
uint32_t z = 0;
|
wgsize[2].overridable_const) {
|
||||||
std::tie(x, y, z) = func->workgroup_size();
|
// TODO(crbug.com/tint/713): Handle overridable constants.
|
||||||
|
TINT_UNIMPLEMENTED(builder_.Diagnostics())
|
||||||
|
<< "pipeline-overridable workgroup sizes are not implemented";
|
||||||
|
}
|
||||||
|
uint32_t x = wgsize[0].value;
|
||||||
|
uint32_t y = wgsize[1].value;
|
||||||
|
uint32_t z = wgsize[2].value;
|
||||||
push_execution_mode(
|
push_execution_mode(
|
||||||
spv::Op::OpExecutionMode,
|
spv::Op::OpExecutionMode,
|
||||||
{Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
|
{Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
|
||||||
Operand::Int(x), Operand::Int(y), Operand::Int(z)});
|
Operand::Int(x), Operand::Int(y), Operand::Int(z)});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* func_sem = builder_.Sem().Get(func);
|
|
||||||
for (auto builtin : func_sem->ReferencedBuiltinVariables()) {
|
for (auto builtin : func_sem->ReferencedBuiltinVariables()) {
|
||||||
if (builtin.second->value() == ast::Builtin::kFragDepth) {
|
if (builtin.second->value() == ast::Builtin::kFragDepth) {
|
||||||
push_execution_mode(
|
push_execution_mode(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user