From 93cd23c01b59591731817a4c9a4bd23da9635b60 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Tue, 13 Apr 2021 23:04:07 +0000 Subject: [PATCH] transform/VertexPulling: Use SymbolTable::New() And clean up some code in the process. Avoids potential symbol collisions. Simplifies the logic. Bug: tint:712 Change-Id: Ibce5ccbd4c7fd45d5bf29906b5a83b3637b6cdcc Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47633 Commit-Queue: Ben Clayton Reviewed-by: Antonio Maiorano Reviewed-by: James Price --- src/transform/vertex_pulling.cc | 722 +++++++++++++-------------- src/transform/vertex_pulling.h | 99 ---- src/transform/vertex_pulling_test.cc | 34 +- 3 files changed, 372 insertions(+), 483 deletions(-) diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index 951f4fd934..df00a17a10 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -22,6 +22,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/program_builder.h" #include "src/semantic/variable.h" +#include "src/utils/get_or_create.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config); @@ -29,12 +30,360 @@ namespace tint { namespace transform { namespace { -static const char kVertexBufferNamePrefix[] = "_tint_pulling_vertex_buffer_"; -static const char kStructBufferName[] = "_tint_vertex_data"; -static const char kStructName[] = "TintVertexData"; -static const char kPullingPosVarName[] = "_tint_pulling_pos"; -static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index"; -static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index"; +struct State { + State(CloneContext& context, const VertexPulling::Config& c) + : ctx(context), cfg(c) {} + State(const State&) = default; + ~State() = default; + + /// LocationReplacement describes an ast::Variable replacement for a + /// location input. + struct LocationReplacement { + /// The variable to replace in the source Program + ast::Variable* from; + /// The replacement to use in the target ProgramBuilder + ast::Variable* to; + }; + + CloneContext& ctx; + VertexPulling::Config const cfg; + std::unordered_map location_to_var; + std::vector location_replacements; + Symbol vertex_index_name; + Symbol instance_index_name; + Symbol pulling_position_name; + Symbol struct_buffer_name; + std::unordered_map vertex_buffer_names; + + /// Generate the vertex buffer binding name + /// @param index index to append to buffer name + Symbol GetVertexBufferName(uint32_t index) { + return utils::GetOrCreate(vertex_buffer_names, index, [&] { + static const char kVertexBufferNamePrefix[] = + "_tint_pulling_vertex_buffer_"; + return ctx.dst->Symbols().New(kVertexBufferNamePrefix + + std::to_string(index)); + }); + } + + /// Lazily generates the pulling position symbol + Symbol GetPullingPositionName() { + if (!pulling_position_name.IsValid()) { + static const char kPullingPosVarName[] = "_tint_pulling_pos"; + pulling_position_name = ctx.dst->Symbols().New(kPullingPosVarName); + } + return pulling_position_name; + } + + /// Lazily generates the structure buffer symbol + Symbol GetStructBufferName() { + if (!struct_buffer_name.IsValid()) { + static const char kStructBufferName[] = "_tint_vertex_data"; + struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName); + } + return struct_buffer_name; + } + + /// Inserts vertex_index binding, or finds the existing one + void FindOrInsertVertexIndexIfUsed() { + bool uses_vertex_step_mode = false; + for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) { + if (buffer_layout.step_mode == InputStepMode::kVertex) { + uses_vertex_step_mode = true; + break; + } + } + if (!uses_vertex_step_mode) { + return; + } + + // Look for an existing vertex index builtin + for (auto* v : ctx.src->AST().GlobalVariables()) { + auto* sem = ctx.src->Sem().Get(v); + if (sem->StorageClass() != ast::StorageClass::kInput) { + continue; + } + + for (auto* d : v->decorations()) { + if (auto* builtin = d->As()) { + if (builtin->value() == ast::Builtin::kVertexIndex) { + vertex_index_name = ctx.Clone(v->symbol()); + return; + } + } + } + } + + // We didn't find a vertex index builtin, so create one + static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index"; + vertex_index_name = ctx.dst->Symbols().New(kDefaultVertexIndexName); + + ctx.dst->Global( + vertex_index_name, ctx.dst->ty.u32(), ast::StorageClass::kInput, + nullptr, + ast::DecorationList{ + ctx.dst->create(ast::Builtin::kVertexIndex), + }); + } + + /// Inserts instance_index binding, or finds the existing one + void FindOrInsertInstanceIndexIfUsed() { + bool uses_instance_step_mode = false; + for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) { + if (buffer_layout.step_mode == InputStepMode::kInstance) { + uses_instance_step_mode = true; + break; + } + } + if (!uses_instance_step_mode) { + return; + } + + // Look for an existing instance index builtin + for (auto* v : ctx.src->AST().GlobalVariables()) { + auto* sem = ctx.src->Sem().Get(v); + if (sem->StorageClass() != ast::StorageClass::kInput) { + continue; + } + + for (auto* d : v->decorations()) { + if (auto* builtin = d->As()) { + if (builtin->value() == ast::Builtin::kInstanceIndex) { + instance_index_name = ctx.Clone(v->symbol()); + return; + } + } + } + } + + // We didn't find an instance index builtin, so create one + static const char kDefaultInstanceIndexName[] = + "_tint_pulling_instance_index"; + instance_index_name = ctx.dst->Symbols().New(kDefaultInstanceIndexName); + + ctx.dst->Global(instance_index_name, ctx.dst->ty.u32(), + ast::StorageClass::kInput, nullptr, + ast::DecorationList{ + ctx.dst->create( + ast::Builtin::kInstanceIndex), + }); + } + + /// Converts var with a location decoration to var + void ConvertVertexInputVariablesToPrivate() { + for (auto* v : ctx.src->AST().GlobalVariables()) { + auto* sem = ctx.src->Sem().Get(v); + if (sem->StorageClass() != ast::StorageClass::kInput) { + continue; + } + + for (auto* d : v->decorations()) { + if (auto* l = d->As()) { + uint32_t location = l->value(); + // This is where the replacement is created. Expressions use + // identifier strings instead of pointers, so we don't need to update + // any other place in the AST. + auto* replacement = ctx.dst->Var(ctx.Clone(v->symbol()), + ctx.Clone(v->declared_type()), + ast::StorageClass::kPrivate); + location_to_var[location] = replacement; + location_replacements.emplace_back( + LocationReplacement{v, replacement}); + break; + } + } + } + } + + /// Adds storage buffer decorated variables for the vertex buffers + void AddVertexStorageBuffers() { + // TODO(idanr): Make this readonly + // https://github.com/gpuweb/gpuweb/issues/935 + + // Creating the struct type + static const char kStructName[] = "TintVertexData"; + auto* struct_type = ctx.dst->Structure( + ctx.dst->Symbols().New(kStructName), + { + ctx.dst->Member(GetStructBufferName(), + ctx.dst->ty.array(4)), + }, + { + ctx.dst->create(), + }); + + for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { + // The decorated variable with struct type + ctx.dst->Global( + GetVertexBufferName(i), struct_type, ast::StorageClass::kStorage, + nullptr, + ast::DecorationList{ + ctx.dst->create(i), + ctx.dst->create(cfg.pulling_group), + }); + } + } + + /// Creates and returns the assignment to the variables from the buffers + ast::BlockStatement* CreateVertexPullingPreamble() { + // Assign by looking at the vertex descriptor to find attributes with + // matching location. + + ast::StatementList stmts; + + // Declare the pulling position variable in the shader + stmts.emplace_back(ctx.dst->create( + ctx.dst->Var(GetPullingPositionName(), ctx.dst->ty.u32(), + ast::StorageClass::kFunction))); + + for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { + const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i]; + + for (const VertexAttributeDescriptor& attribute_desc : + buffer_layout.attributes) { + auto it = location_to_var.find(attribute_desc.shader_location); + if (it == location_to_var.end()) { + continue; + } + auto* v = it->second; + + auto name = buffer_layout.step_mode == InputStepMode::kVertex + ? vertex_index_name + : instance_index_name; + + // An expression for the start of the read in the buffer in bytes + auto* pos_value = ctx.dst->Add( + ctx.dst->Mul(name, + static_cast(buffer_layout.array_stride)), + static_cast(attribute_desc.offset)); + + // Update position of the read + auto* set_pos_expr = ctx.dst->create( + ctx.dst->Expr(GetPullingPositionName()), pos_value); + stmts.emplace_back(set_pos_expr); + + stmts.emplace_back(ctx.dst->create( + ctx.dst->create(v->symbol()), + AccessByFormat(i, attribute_desc.format))); + } + } + + return ctx.dst->create(stmts); + } + + /// Generates an expression reading from a buffer a specific format. + /// This reads the value wherever `kPullingPosVarName` points to at the time + /// of the read. + /// @param buffer the index of the vertex buffer + /// @param format the format to read + ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) { + // TODO(idanr): this doesn't account for the format of the attribute in the + // shader. ex: vec in shader, and attribute claims VertexFormat::Float4 + // right now, we would try to assign a vec4 to this attribute, but we + // really need to assign a vec4 by casting. + // We could split this function to first do memory accesses and unpacking + // into int/uint/float1-4/etc, then convert that variable to a var with + // the conversion defined in the WebGPU spec. + switch (format) { + case VertexFormat::kU32: + return AccessU32(buffer, ctx.dst->Expr(GetPullingPositionName())); + case VertexFormat::kI32: + return AccessI32(buffer, ctx.dst->Expr(GetPullingPositionName())); + case VertexFormat::kF32: + return AccessF32(buffer, ctx.dst->Expr(GetPullingPositionName())); + case VertexFormat::kVec2F32: + return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 2); + case VertexFormat::kVec3F32: + return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 3); + case VertexFormat::kVec4F32: + return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 4); + default: + return nullptr; + } + } + + /// Generates an expression reading a uint32 from a vertex buffer + /// @param buffer the index of the vertex buffer + /// @param pos an expression for the position of the access, in bytes + ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) { + // Here we divide by 4, since the buffer is uint32 not uint8. The input + // buffer has byte offsets for each attribute, and we will convert it to u32 + // indexes by dividing. Then, that element is going to be read, and if + // needed, unpacked into an appropriate variable. All reads should end up + // here as a base case. + return ctx.dst->create( + ctx.dst->MemberAccessor(GetVertexBufferName(buffer), + GetStructBufferName()), + ctx.dst->Div(pos, 4u)); + } + + /// Generates an expression reading an int32 from a vertex buffer + /// @param buffer the index of the vertex buffer + /// @param pos an expression for the position of the access, in bytes + ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) { + // as reinterprets bits + return ctx.dst->create(ctx.dst->ty.i32(), + AccessU32(buffer, pos)); + } + + /// Generates an expression reading a float from a vertex buffer + /// @param buffer the index of the vertex buffer + /// @param pos an expression for the position of the access, in bytes + ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) { + // as reinterprets bits + return ctx.dst->create(ctx.dst->ty.f32(), + AccessU32(buffer, pos)); + } + + /// Generates an expression reading a basic type (u32, i32, f32) from a + /// vertex buffer + /// @param buffer the index of the vertex buffer + /// @param pos an expression for the position of the access, in bytes + /// @param format the underlying vertex format + ast::Expression* AccessPrimitive(uint32_t buffer, + ast::Expression* pos, + VertexFormat format) { + // This function uses a position expression to read, rather than using the + // position variable. This allows us to read from offset positions relative + // to |kPullingPosVarName|. We can't call AccessByFormat because it reads + // only from the position variable. + switch (format) { + case VertexFormat::kU32: + return AccessU32(buffer, pos); + case VertexFormat::kI32: + return AccessI32(buffer, pos); + case VertexFormat::kF32: + return AccessF32(buffer, pos); + default: + return nullptr; + } + } + + /// Generates an expression reading a vec2/3/4 from a vertex buffer. + /// This reads the value wherever `kPullingPosVarName` points to at the time + /// of the read. + /// @param buffer the index of the vertex buffer + /// @param element_stride stride between elements, in bytes + /// @param base_type underlying AST type + /// @param base_format underlying vertex format + /// @param count how many elements the vector has + ast::Expression* AccessVec(uint32_t buffer, + uint32_t element_stride, + type::Type* base_type, + VertexFormat base_format, + uint32_t count) { + ast::ExpressionList expr_list; + for (uint32_t i = 0; i < count; ++i) { + // Offset read position by element_stride for each component + auto* cur_pos = + ctx.dst->Add(GetPullingPositionName(), element_stride * i); + expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format)); + } + + return ctx.dst->create( + ctx.dst->create(base_type, count), std::move(expr_list)); + } +}; } // namespace @@ -93,367 +442,6 @@ VertexPulling::Config::~Config() = default; VertexPulling::Config& VertexPulling::Config::operator=(const Config&) = default; -VertexPulling::State::State(CloneContext& context, const Config& c) - : ctx(context), cfg(c) {} -VertexPulling::State::State(const State&) = default; -VertexPulling::State::~State() = default; - -std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const { - return kVertexBufferNamePrefix + std::to_string(index); -} - -void VertexPulling::State::FindOrInsertVertexIndexIfUsed() { - bool uses_vertex_step_mode = false; - for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) { - if (buffer_layout.step_mode == InputStepMode::kVertex) { - uses_vertex_step_mode = true; - break; - } - } - if (!uses_vertex_step_mode) { - return; - } - - // Look for an existing vertex index builtin - for (auto* v : ctx.src->AST().GlobalVariables()) { - auto* sem = ctx.src->Sem().Get(v); - if (sem->StorageClass() != ast::StorageClass::kInput) { - continue; - } - - for (auto* d : v->decorations()) { - if (auto* builtin = d->As()) { - if (builtin->value() == ast::Builtin::kVertexIndex) { - vertex_index_name = ctx.src->Symbols().NameFor(v->symbol()); - return; - } - } - } - } - - // We didn't find a vertex index builtin, so create one - vertex_index_name = kDefaultVertexIndexName; - - auto* var = ctx.dst->create( - Source{}, // source - ctx.dst->Symbols().Register(vertex_index_name), // symbol - ast::StorageClass::kInput, // storage_class - GetU32Type(), // type - false, // is_const - nullptr, // constructor - ast::DecorationList{ - ctx.dst->create(Source{}, - ast::Builtin::kVertexIndex), - }); - - ctx.dst->AST().AddGlobalVariable(var); -} - -void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { - bool uses_instance_step_mode = false; - for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) { - if (buffer_layout.step_mode == InputStepMode::kInstance) { - uses_instance_step_mode = true; - break; - } - } - if (!uses_instance_step_mode) { - return; - } - - // Look for an existing instance index builtin - for (auto* v : ctx.src->AST().GlobalVariables()) { - auto* sem = ctx.src->Sem().Get(v); - if (sem->StorageClass() != ast::StorageClass::kInput) { - continue; - } - - for (auto* d : v->decorations()) { - if (auto* builtin = d->As()) { - if (builtin->value() == ast::Builtin::kInstanceIndex) { - instance_index_name = ctx.src->Symbols().NameFor(v->symbol()); - return; - } - } - } - } - - // We didn't find an instance index builtin, so create one - instance_index_name = kDefaultInstanceIndexName; - - auto* var = ctx.dst->create( - Source{}, // source - ctx.dst->Symbols().Register(instance_index_name), // symbol - ast::StorageClass::kInput, // storage_class - GetU32Type(), // type - false, // is_const - nullptr, // constructor - ast::DecorationList{ - ctx.dst->create(Source{}, - ast::Builtin::kInstanceIndex), - }); - ctx.dst->AST().AddGlobalVariable(var); -} - -void VertexPulling::State::ConvertVertexInputVariablesToPrivate() { - for (auto* v : ctx.src->AST().GlobalVariables()) { - auto* sem = ctx.src->Sem().Get(v); - if (sem->StorageClass() != ast::StorageClass::kInput) { - continue; - } - - for (auto* d : v->decorations()) { - if (auto* l = d->As()) { - uint32_t location = l->value(); - // This is where the replacement is created. Expressions use identifier - // strings instead of pointers, so we don't need to update any other - // place in the AST. - auto name = ctx.src->Symbols().NameFor(v->symbol()); - auto* replacement = ctx.dst->create( - Source{}, // source - ctx.dst->Symbols().Register(name), // symbol - ast::StorageClass::kPrivate, // storage_class - ctx.Clone(v->declared_type()), // type - false, // is_const - nullptr, // constructor - ast::DecorationList{}); // decorations - location_to_var[location] = replacement; - location_replacements.emplace_back(LocationReplacement{v, replacement}); - break; - } - } - } -} - -void VertexPulling::State::AddVertexStorageBuffers() { - // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935 - // The array inside the struct definition - auto* internal_array_type = ctx.dst->create( - GetU32Type(), 0, - ast::DecorationList{ - ctx.dst->create(Source{}, 4u), - }); - - // Creating the struct type - ast::StructMemberList members; - members.push_back(ctx.dst->create( - Source{}, ctx.dst->Symbols().Register(kStructBufferName), - internal_array_type, ast::DecorationList{})); - - ast::DecorationList decos; - decos.push_back(ctx.dst->create(Source{})); - - auto* struct_type = ctx.dst->create( - ctx.dst->Symbols().Register(kStructName), - ctx.dst->create(Source{}, std::move(members), - std::move(decos))); - - for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { - // The decorated variable with struct type - std::string name = GetVertexBufferName(i); - auto* var = ctx.dst->create( - Source{}, // source - ctx.dst->Symbols().Register(name), // symbol - ast::StorageClass::kStorage, // storage_class - struct_type, // type - false, // is_const - nullptr, // constructor - ast::DecorationList{ - ctx.dst->create(Source{}, i), - ctx.dst->create(Source{}, cfg.pulling_group), - }); - ctx.dst->AST().AddGlobalVariable(var); - } - ctx.dst->AST().AddConstructedType(struct_type); -} - -ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const { - // Assign by looking at the vertex descriptor to find attributes with matching - // location. - - ast::StatementList stmts; - - // Declare the |kPullingPosVarName| variable in the shader - auto* pos_declaration = ctx.dst->create( - Source{}, ctx.dst->create( - Source{}, // source - ctx.dst->Symbols().Register(kPullingPosVarName), // symbol - ast::StorageClass::kFunction, // storage_class - GetU32Type(), // type - false, // is_const - nullptr, // constructor - ast::DecorationList{})); // decorations - - // |kPullingPosVarName| refers to the byte location of the current read. We - // declare a variable in the shader to avoid having to reuse Expression - // objects. - stmts.emplace_back(pos_declaration); - - for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { - const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i]; - - for (const VertexAttributeDescriptor& attribute_desc : - buffer_layout.attributes) { - auto it = location_to_var.find(attribute_desc.shader_location); - if (it == location_to_var.end()) { - continue; - } - auto* v = it->second; - - auto name = buffer_layout.step_mode == InputStepMode::kVertex - ? vertex_index_name - : instance_index_name; - // Identifier to index by - auto* index_identifier = ctx.dst->create( - Source{}, ctx.dst->Symbols().Register(name)); - - // An expression for the start of the read in the buffer in bytes - auto* pos_value = ctx.dst->create( - Source{}, ast::BinaryOp::kAdd, - ctx.dst->create( - Source{}, ast::BinaryOp::kMultiply, index_identifier, - GenUint(static_cast(buffer_layout.array_stride))), - GenUint(static_cast(attribute_desc.offset))); - - // Update position of the read - auto* set_pos_expr = ctx.dst->create( - Source{}, CreatePullingPositionIdent(), pos_value); - stmts.emplace_back(set_pos_expr); - - stmts.emplace_back(ctx.dst->create( - Source{}, - ctx.dst->create(Source{}, v->symbol()), - AccessByFormat(i, attribute_desc.format))); - } - } - - return ctx.dst->create(Source{}, stmts); -} - -ast::Expression* VertexPulling::State::GenUint(uint32_t value) const { - return ctx.dst->create( - Source{}, - ctx.dst->create(Source{}, GetU32Type(), value)); -} - -ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const { - return ctx.dst->create( - Source{}, ctx.dst->Symbols().Register(kPullingPosVarName)); -} - -ast::Expression* VertexPulling::State::AccessByFormat( - uint32_t buffer, - VertexFormat format) const { - // TODO(idanr): this doesn't account for the format of the attribute in the - // shader. ex: vec in shader, and attribute claims VertexFormat::Float4 - // right now, we would try to assign a vec4 to this attribute, but we - // really need to assign a vec4 by casting. - // We could split this function to first do memory accesses and unpacking into - // int/uint/float1-4/etc, then convert that variable to a var with the - // conversion defined in the WebGPU spec. - switch (format) { - case VertexFormat::kU32: - return AccessU32(buffer, CreatePullingPositionIdent()); - case VertexFormat::kI32: - return AccessI32(buffer, CreatePullingPositionIdent()); - case VertexFormat::kF32: - return AccessF32(buffer, CreatePullingPositionIdent()); - case VertexFormat::kVec2F32: - return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 2); - case VertexFormat::kVec3F32: - return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 3); - case VertexFormat::kVec4F32: - return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 4); - default: - return nullptr; - } -} - -ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer, - ast::Expression* pos) const { - // Here we divide by 4, since the buffer is uint32 not uint8. The input buffer - // has byte offsets for each attribute, and we will convert it to u32 indexes - // by dividing. Then, that element is going to be read, and if needed, - // unpacked into an appropriate variable. All reads should end up here as a - // base case. - auto vbuf_name = GetVertexBufferName(buffer); - return ctx.dst->create( - Source{}, - ctx.dst->create( - Source{}, - ctx.dst->create( - Source{}, ctx.dst->Symbols().Register(vbuf_name)), - ctx.dst->create( - Source{}, ctx.dst->Symbols().Register(kStructBufferName))), - ctx.dst->create(Source{}, ast::BinaryOp::kDivide, - pos, GenUint(4))); -} - -ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer, - ast::Expression* pos) const { - // as reinterprets bits - return ctx.dst->create(Source{}, GetI32Type(), - AccessU32(buffer, pos)); -} - -ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer, - ast::Expression* pos) const { - // as reinterprets bits - return ctx.dst->create(Source{}, GetF32Type(), - AccessU32(buffer, pos)); -} - -ast::Expression* VertexPulling::State::AccessPrimitive( - uint32_t buffer, - ast::Expression* pos, - VertexFormat format) const { - // This function uses a position expression to read, rather than using the - // position variable. This allows us to read from offset positions relative to - // |kPullingPosVarName|. We can't call AccessByFormat because it reads only - // from the position variable. - switch (format) { - case VertexFormat::kU32: - return AccessU32(buffer, pos); - case VertexFormat::kI32: - return AccessI32(buffer, pos); - case VertexFormat::kF32: - return AccessF32(buffer, pos); - default: - return nullptr; - } -} - -ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer, - uint32_t element_stride, - type::Type* base_type, - VertexFormat base_format, - uint32_t count) const { - ast::ExpressionList expr_list; - for (uint32_t i = 0; i < count; ++i) { - // Offset read position by element_stride for each component - auto* cur_pos = ctx.dst->create( - Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(), - GenUint(element_stride * i)); - expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format)); - } - - return ctx.dst->create( - Source{}, ctx.dst->create(base_type, count), - std::move(expr_list)); -} - -type::Type* VertexPulling::State::GetU32Type() const { - return ctx.dst->create(); -} - -type::Type* VertexPulling::State::GetI32Type() const { - return ctx.dst->create(); -} - -type::Type* VertexPulling::State::GetF32Type() const { - return ctx.dst->create(); -} - VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default; VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor( diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h index 8be1812bf5..576bc09c09 100644 --- a/src/transform/vertex_pulling.h +++ b/src/transform/vertex_pulling.h @@ -176,105 +176,6 @@ class VertexPulling : public Transform { private: Config cfg_; - - struct State { - State(CloneContext& ctx, const Config& c); - explicit State(const State&); - ~State(); - - /// Generate the vertex buffer binding name - /// @param index index to append to buffer name - std::string GetVertexBufferName(uint32_t index) const; - - /// Inserts vertex_index binding, or finds the existing one - void FindOrInsertVertexIndexIfUsed(); - - /// Inserts instance_index binding, or finds the existing one - void FindOrInsertInstanceIndexIfUsed(); - - /// Converts var with a location decoration to var - void ConvertVertexInputVariablesToPrivate(); - - /// Adds storage buffer decorated variables for the vertex buffers - void AddVertexStorageBuffers(); - - /// Creates and returns the assignment to the variables from the buffers - ast::BlockStatement* CreateVertexPullingPreamble() const; - - /// Generates an expression holding a constant uint - /// @param value uint value - ast::Expression* GenUint(uint32_t value) const; - - /// Generates an expression to read the shader value `kPullingPosVarName` - ast::Expression* CreatePullingPositionIdent() const; - - /// Generates an expression reading from a buffer a specific format. - /// This reads the value wherever `kPullingPosVarName` points to at the time - /// of the read. - /// @param buffer the index of the vertex buffer - /// @param format the format to read - ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) const; - - /// Generates an expression reading a uint32 from a vertex buffer - /// @param buffer the index of the vertex buffer - /// @param pos an expression for the position of the access, in bytes - ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) const; - - /// Generates an expression reading an int32 from a vertex buffer - /// @param buffer the index of the vertex buffer - /// @param pos an expression for the position of the access, in bytes - ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) const; - - /// Generates an expression reading a float from a vertex buffer - /// @param buffer the index of the vertex buffer - /// @param pos an expression for the position of the access, in bytes - ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) const; - - /// Generates an expression reading a basic type (u32, i32, f32) from a - /// vertex buffer - /// @param buffer the index of the vertex buffer - /// @param pos an expression for the position of the access, in bytes - /// @param format the underlying vertex format - ast::Expression* AccessPrimitive(uint32_t buffer, - ast::Expression* pos, - VertexFormat format) const; - - /// Generates an expression reading a vec2/3/4 from a vertex buffer. - /// This reads the value wherever `kPullingPosVarName` points to at the time - /// of the read. - /// @param buffer the index of the vertex buffer - /// @param element_stride stride between elements, in bytes - /// @param base_type underlying AST type - /// @param base_format underlying vertex format - /// @param count how many elements the vector has - ast::Expression* AccessVec(uint32_t buffer, - uint32_t element_stride, - type::Type* base_type, - VertexFormat base_format, - uint32_t count) const; - - // Used to grab corresponding types from the type manager - type::Type* GetU32Type() const; - type::Type* GetI32Type() const; - type::Type* GetF32Type() const; - - CloneContext& ctx; - Config const cfg; - - /// LocationReplacement describes an ast::Variable replacement for a - /// location input. - struct LocationReplacement { - /// The variable to replace in the source Program - ast::Variable* from; - /// The replacement to use in the target ProgramBuilder - ast::Variable* to; - }; - - std::unordered_map location_to_var; - std::vector location_replacements; - std::string vertex_index_name; - std::string instance_index_name; - }; }; } // namespace transform diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index e1222e6f1f..ce01f87ef8 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -113,13 +113,13 @@ fn main() {} auto* expect = R"( [[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; - [[block]] struct TintVertexData { _tint_vertex_data : [[stride(4)]] array; }; +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + var var_a : f32; [[stage(vertex)]] @@ -155,13 +155,13 @@ fn main() {} auto* expect = R"( [[builtin(instance_index)]] var _tint_pulling_instance_index : u32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; - [[block]] struct TintVertexData { _tint_vertex_data : [[stride(4)]] array; }; +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + var var_a : f32; [[stage(vertex)]] @@ -197,13 +197,13 @@ fn main() {} auto* expect = R"( [[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; -[[binding(0), group(5)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; - [[block]] struct TintVertexData { _tint_vertex_data : [[stride(4)]] array; }; +[[binding(0), group(5)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + var var_a : f32; [[stage(vertex)]] @@ -242,15 +242,15 @@ fn main() {} )"; auto* expect = R"( -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; - -[[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; - [[block]] struct TintVertexData { _tint_vertex_data : [[stride(4)]] array; }; +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + +[[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; + var var_a : f32; var var_b : f32; @@ -305,13 +305,13 @@ fn main() {} auto* expect = R"( [[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; - [[block]] struct TintVertexData { _tint_vertex_data : [[stride(4)]] array; }; +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + var var_a : f32; var var_b : vec4; @@ -355,17 +355,17 @@ fn main() {} auto* expect = R"( [[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; +[[block]] +struct TintVertexData { + _tint_vertex_data : [[stride(4)]] array; +}; + [[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; [[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; [[binding(2), group(4)]] var _tint_pulling_vertex_buffer_2 : TintVertexData; -[[block]] -struct TintVertexData { - _tint_vertex_data : [[stride(4)]] array; -}; - var var_a : vec2; var var_b : vec3;