// Copyright 2020 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/transform/vertex_pulling.h" #include #include #include "src/tint/ast/assignment_statement.h" #include "src/tint/ast/bitcast_expression.h" #include "src/tint/ast/variable_decl_statement.h" #include "src/tint/program_builder.h" #include "src/tint/sem/variable.h" #include "src/tint/utils/map.h" #include "src/tint/utils/math.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling); TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config); using namespace tint::number_suffixes; // NOLINT namespace tint::transform { namespace { /// The base type of a component. /// The format type is either this type or a vector of this type. enum class BaseType { kInvalid, kU32, kI32, kF32, }; /// Writes the BaseType to the std::ostream. /// @param out the std::ostream to write to /// @param format the BaseType to write /// @returns out so calls can be chained std::ostream& operator<<(std::ostream& out, BaseType format) { switch (format) { case BaseType::kInvalid: return out << "invalid"; case BaseType::kU32: return out << "u32"; case BaseType::kI32: return out << "i32"; case BaseType::kF32: return out << "f32"; } return out << ""; } /// Writes the VertexFormat to the std::ostream. 
/// @param out the std::ostream to write to /// @param format the VertexFormat to write /// @returns out so calls can be chained std::ostream& operator<<(std::ostream& out, VertexFormat format) { switch (format) { case VertexFormat::kUint8x2: return out << "uint8x2"; case VertexFormat::kUint8x4: return out << "uint8x4"; case VertexFormat::kSint8x2: return out << "sint8x2"; case VertexFormat::kSint8x4: return out << "sint8x4"; case VertexFormat::kUnorm8x2: return out << "unorm8x2"; case VertexFormat::kUnorm8x4: return out << "unorm8x4"; case VertexFormat::kSnorm8x2: return out << "snorm8x2"; case VertexFormat::kSnorm8x4: return out << "snorm8x4"; case VertexFormat::kUint16x2: return out << "uint16x2"; case VertexFormat::kUint16x4: return out << "uint16x4"; case VertexFormat::kSint16x2: return out << "sint16x2"; case VertexFormat::kSint16x4: return out << "sint16x4"; case VertexFormat::kUnorm16x2: return out << "unorm16x2"; case VertexFormat::kUnorm16x4: return out << "unorm16x4"; case VertexFormat::kSnorm16x2: return out << "snorm16x2"; case VertexFormat::kSnorm16x4: return out << "snorm16x4"; case VertexFormat::kFloat16x2: return out << "float16x2"; case VertexFormat::kFloat16x4: return out << "float16x4"; case VertexFormat::kFloat32: return out << "float32"; case VertexFormat::kFloat32x2: return out << "float32x2"; case VertexFormat::kFloat32x3: return out << "float32x3"; case VertexFormat::kFloat32x4: return out << "float32x4"; case VertexFormat::kUint32: return out << "uint32"; case VertexFormat::kUint32x2: return out << "uint32x2"; case VertexFormat::kUint32x3: return out << "uint32x3"; case VertexFormat::kUint32x4: return out << "uint32x4"; case VertexFormat::kSint32: return out << "sint32"; case VertexFormat::kSint32x2: return out << "sint32x2"; case VertexFormat::kSint32x3: return out << "sint32x3"; case VertexFormat::kSint32x4: return out << "sint32x4"; } return out << ""; } /// A vertex attribute data format. 
struct DataType { BaseType base_type; uint32_t width; // 1 for scalar, 2+ for a vector }; DataType DataTypeOf(const sem::Type* ty) { if (ty->Is()) { return {BaseType::kI32, 1}; } if (ty->Is()) { return {BaseType::kU32, 1}; } if (ty->Is()) { return {BaseType::kF32, 1}; } if (auto* vec = ty->As()) { return {DataTypeOf(vec->type()).base_type, vec->Width()}; } return {BaseType::kInvalid, 0}; } DataType DataTypeOf(VertexFormat format) { switch (format) { case VertexFormat::kUint32: return {BaseType::kU32, 1}; case VertexFormat::kUint8x2: case VertexFormat::kUint16x2: case VertexFormat::kUint32x2: return {BaseType::kU32, 2}; case VertexFormat::kUint32x3: return {BaseType::kU32, 3}; case VertexFormat::kUint8x4: case VertexFormat::kUint16x4: case VertexFormat::kUint32x4: return {BaseType::kU32, 4}; case VertexFormat::kSint32: return {BaseType::kI32, 1}; case VertexFormat::kSint8x2: case VertexFormat::kSint16x2: case VertexFormat::kSint32x2: return {BaseType::kI32, 2}; case VertexFormat::kSint32x3: return {BaseType::kI32, 3}; case VertexFormat::kSint8x4: case VertexFormat::kSint16x4: case VertexFormat::kSint32x4: return {BaseType::kI32, 4}; case VertexFormat::kFloat32: return {BaseType::kF32, 1}; case VertexFormat::kUnorm8x2: case VertexFormat::kSnorm8x2: case VertexFormat::kUnorm16x2: case VertexFormat::kSnorm16x2: case VertexFormat::kFloat16x2: case VertexFormat::kFloat32x2: return {BaseType::kF32, 2}; case VertexFormat::kFloat32x3: return {BaseType::kF32, 3}; case VertexFormat::kUnorm8x4: case VertexFormat::kSnorm8x4: case VertexFormat::kUnorm16x4: case VertexFormat::kSnorm16x4: case VertexFormat::kFloat16x4: case VertexFormat::kFloat32x4: return {BaseType::kF32, 4}; } return {BaseType::kInvalid, 0}; } } // namespace /// PIMPL state for the transform struct VertexPulling::State { /// Constructor /// @param program the source program /// @param c the VertexPulling config State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {} /// Runs the 
transform /// @returns the new program or SkipTransform if the transform is not required ApplyResult Run() { // Find entry point const ast::Function* func = nullptr; for (auto* fn : src->AST().Functions()) { if (fn->PipelineStage() == ast::PipelineStage::kVertex) { if (func != nullptr) { b.Diagnostics().add_error( diag::System::Transform, "VertexPulling found more than one vertex entry point"); return Program(std::move(b)); } func = fn; } } if (func == nullptr) { b.Diagnostics().add_error(diag::System::Transform, "Vertex stage entry point not found"); return Program(std::move(b)); } AddVertexStorageBuffers(); Process(func); ctx.Clone(); return Program(std::move(b)); } private: /// LocationReplacement describes an ast::Variable replacement for a location input. struct LocationReplacement { /// The variable to replace in the source Program ast::Variable* from; /// The replacement to use in the target ProgramBuilder ast::Variable* to; }; /// LocationInfo describes an input location struct LocationInfo { /// A builder that builds the expression that resolves to the (transformed) input location std::function expr; /// The store type of the location variable const sem::Type* type; }; /// The source program const Program* const src; /// The transform config VertexPulling::Config const cfg; /// The target program builder ProgramBuilder b; /// The clone context CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; std::unordered_map location_info; std::function vertex_index_expr = nullptr; std::function instance_index_expr = nullptr; Symbol pulling_position_name; Symbol struct_buffer_name; std::unordered_map vertex_buffer_names; utils::Vector new_function_parameters; /// Generate the vertex buffer binding name /// @param index index to append to buffer name Symbol GetVertexBufferName(uint32_t index) { return utils::GetOrCreate(vertex_buffer_names, index, [&] { static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_"; return 
b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); }); } /// Lazily generates the structure buffer symbol Symbol GetStructBufferName() { if (!struct_buffer_name.IsValid()) { static const char kStructBufferName[] = "tint_vertex_data"; struct_buffer_name = b.Symbols().New(kStructBufferName); } return struct_buffer_name; } /// Adds storage buffer decorated variables for the vertex buffers void AddVertexStorageBuffers() { // Creating the struct type static const char kStructName[] = "TintVertexData"; auto* struct_type = b.Structure(b.Symbols().New(kStructName), utils::Vector{ b.Member(GetStructBufferName(), b.ty.array()), }); for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { // The decorated variable with struct type b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage, ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group))); } } /// Creates and returns the assignment to the variables from the buffers const ast::BlockStatement* CreateVertexPullingPreamble() { // Assign by looking at the vertex descriptor to find attributes with // matching location. utils::Vector stmts; for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) { const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx]; if ((buffer_layout.array_stride & 3) != 0) { b.Diagnostics().add_error( diag::System::Transform, "WebGPU requires that vertex stride must be a multiple of 4 bytes, " "but VertexPulling array stride for buffer " + std::to_string(buffer_idx) + " was " + std::to_string(buffer_layout.array_stride) + " bytes"); return nullptr; } auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex ? vertex_index_expr() : instance_index_expr(); // buffer_array_base is the base array offset for all the vertex // attributes. These are units of uint (4 bytes). 
auto buffer_array_base = b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); auto* attribute_offset = index_expr; if (buffer_layout.array_stride != 4) { attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u)); } // let pulling_offset_n = stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset))); for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) { auto it = location_info.find(attribute_desc.shader_location); if (it == location_info.end()) { continue; } auto& var = it->second; // Data type of the target WGSL variable auto var_dt = DataTypeOf(var.type); // Data type of the vertex stream attribute auto fmt_dt = DataTypeOf(attribute_desc.format); // Base types must match between the vertex stream and the WGSL variable if (var_dt.base_type != fmt_dt.base_type) { std::stringstream err; err << "VertexAttributeDescriptor for location " << std::to_string(attribute_desc.shader_location) << " has format " << attribute_desc.format << " but shader expects " << var.type->FriendlyName(src->Symbols()); b.Diagnostics().add_error(diag::System::Transform, err.str()); return nullptr; } // Load the attribute value auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx, attribute_desc.format); // The attribute value may not be of the desired vector width. If it is // not, we'll need to either reduce the width with a swizzle, or append // 0's and / or a 1. 
auto* value = fetch; if (var_dt.width < fmt_dt.width) { // WGSL variable vector width is smaller than the loaded vector width switch (var_dt.width) { case 1: value = b.MemberAccessor(fetch, "x"); break; case 2: value = b.MemberAccessor(fetch, "xy"); break; case 3: value = b.MemberAccessor(fetch, "xyz"); break; default: TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width; return nullptr; } } else if (var_dt.width > fmt_dt.width) { // WGSL variable vector width is wider than the loaded vector width const ast::Type* ty = nullptr; utils::Vector values{fetch}; switch (var_dt.base_type) { case BaseType::kI32: ty = b.ty.i32(); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { values.Push(b.Expr((i == 3) ? 1_i : 0_i)); } break; case BaseType::kU32: ty = b.ty.u32(); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { values.Push(b.Expr((i == 3) ? 1_u : 0_u)); } break; case BaseType::kF32: ty = b.ty.f32(); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { values.Push(b.Expr((i == 3) ? 1_f : 0_f)); } break; default: TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type; return nullptr; } value = b.Construct(b.ty.vec(ty, var_dt.width), values); } // Assign the value to the WGSL variable stmts.Push(b.Assign(var.expr(), value)); } } if (stmts.IsEmpty()) { return nullptr; } return b.Block(std::move(stmts)); } /// Generates an expression reading from a buffer a specific format. /// @param array_base the symbol of the variable holding the base array offset /// of the vertex array (each index is 4-bytes). /// @param offset the byte offset of the data from `buffer_base` /// @param buffer the index of the vertex buffer /// @param format the format to read const ast::Expression* Fetch(Symbol array_base, uint32_t offset, uint32_t buffer, VertexFormat format) { // Returns a u32 loaded from buffer_base + offset. 
auto load_u32 = [&] { return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32); }; // Returns a i32 loaded from buffer_base + offset. auto load_i32 = [&] { return b.Bitcast(load_u32()); }; // Returns a u32 loaded from buffer_base + offset + 4. auto load_next_u32 = [&] { return LoadPrimitive(array_base, offset + 4, buffer, VertexFormat::kUint32); }; // Returns a i32 loaded from buffer_base + offset + 4. auto load_next_i32 = [&] { return b.Bitcast(load_next_u32()); }; // Returns a u16 loaded from offset, packed in the high 16 bits of a u32. // The low 16 bits are 0. // `min_alignment` must be a power of two. // `offset` must be `min_alignment` bytes aligned. auto load_u16_h = [&] { auto low_u32_offset = offset & ~3u; auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); switch (offset & 3) { case 0: return b.Shl(low_u32, 16_u); case 1: return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u); case 2: return b.And(low_u32, 0xffff0000_u); default: { // 3: auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, VertexFormat::kUint32); auto* shr = b.Shr(low_u32, 8_u); auto* shl = b.Shl(high_u32, 24_u); return b.And(b.Or(shl, shr), 0xffff0000_u); } } }; // Returns a u16 loaded from offset, packed in the low 16 bits of a u32. // The high 16 bits are 0. auto load_u16_l = [&] { auto low_u32_offset = offset & ~3u; auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); switch (offset & 3) { case 0: return b.And(low_u32, 0xffff_u); case 1: return b.And(b.Shr(low_u32, 8_u), 0xffff_u); case 2: return b.Shr(low_u32, 16_u); default: { // 3: auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, VertexFormat::kUint32); auto* shr = b.Shr(low_u32, 24_u); auto* shl = b.Shl(high_u32, 8_u); return b.And(b.Or(shl, shr), 0xffff_u); } } }; // Returns a i16 loaded from offset, packed in the high 16 bits of a u32. // The low 16 bits are 0. 
auto load_i16_h = [&] { return b.Bitcast(load_u16_h()); }; // Assumptions are made that alignment must be at least as large as the size // of a single component. switch (format) { // Basic primitives case VertexFormat::kUint32: case VertexFormat::kSint32: case VertexFormat::kFloat32: return LoadPrimitive(array_base, offset, buffer, format); // Vectors of basic primitives case VertexFormat::kUint32x2: return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2); case VertexFormat::kUint32x3: return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3); case VertexFormat::kUint32x4: return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4); case VertexFormat::kSint32x2: return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2); case VertexFormat::kSint32x3: return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3); case VertexFormat::kSint32x4: return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4); case VertexFormat::kFloat32x2: return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, 2); case VertexFormat::kFloat32x3: return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, 3); case VertexFormat::kFloat32x4: return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, 4); case VertexFormat::kUint8x2: { // yyxx0000, yyxx0000 auto* u16s = b.vec2(load_u16_h()); // xx000000, yyxx0000 auto* shl = b.Shl(u16s, b.vec2(8_u, 0_u)); // 000000xx, 000000yy return b.Shr(shl, b.vec2(24_u)); } case VertexFormat::kUint8x4: { // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx auto* u32s = b.vec4(load_u32()); // xx000000, yyxx0000, zzyyxx00, wwzzyyxx auto* shl = b.Shl(u32s, b.vec4(24_u, 16_u, 8_u, 0_u)); // 000000xx, 000000yy, 000000zz, 000000ww return b.Shr(shl, b.vec4(24_u)); } case VertexFormat::kUint16x2: { // yyyyxxxx, yyyyxxxx auto* u32s = b.vec2(load_u32()); // 
xxxx0000, yyyyxxxx auto* shl = b.Shl(u32s, b.vec2(16_u, 0_u)); // 0000xxxx, 0000yyyy return b.Shr(shl, b.vec2(16_u)); } case VertexFormat::kUint16x4: { // yyyyxxxx, wwwwzzzz auto* u32s = b.vec2(load_u32(), load_next_u32()); // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz auto* xxyy = b.MemberAccessor(u32s, "xxyy"); // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz auto* shl = b.Shl(xxyy, b.vec4(16_u, 0_u, 16_u, 0_u)); // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww return b.Shr(shl, b.vec4(16_u)); } case VertexFormat::kSint8x2: { // yyxx0000, yyxx0000 auto* i16s = b.vec2(load_i16_h()); // xx000000, yyxx0000 auto* shl = b.Shl(i16s, b.vec2(8_u, 0_u)); // ssssssxx, ssssssyy return b.Shr(shl, b.vec2(24_u)); } case VertexFormat::kSint8x4: { // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx auto* i32s = b.vec4(load_i32()); // xx000000, yyxx0000, zzyyxx00, wwzzyyxx auto* shl = b.Shl(i32s, b.vec4(24_u, 16_u, 8_u, 0_u)); // ssssssxx, ssssssyy, sssssszz, ssssssww return b.Shr(shl, b.vec4(24_u)); } case VertexFormat::kSint16x2: { // yyyyxxxx, yyyyxxxx auto* i32s = b.vec2(load_i32()); // xxxx0000, yyyyxxxx auto* shl = b.Shl(i32s, b.vec2(16_u, 0_u)); // ssssxxxx, ssssyyyy return b.Shr(shl, b.vec2(16_u)); } case VertexFormat::kSint16x4: { // yyyyxxxx, wwwwzzzz auto* i32s = b.vec2(load_i32(), load_next_i32()); // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz auto* xxyy = b.MemberAccessor(i32s, "xxyy"); // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz auto* shl = b.Shl(xxyy, b.vec4(16_u, 0_u, 16_u, 0_u)); // ssssxxxx, ssssyyyy, sssszzzz, sssswwww return b.Shr(shl, b.vec4(16_u)); } case VertexFormat::kUnorm8x2: return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy"); case VertexFormat::kSnorm8x2: return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy"); case VertexFormat::kUnorm8x4: return b.Call("unpack4x8unorm", load_u32()); case VertexFormat::kSnorm8x4: return b.Call("unpack4x8snorm", load_u32()); case VertexFormat::kUnorm16x2: return b.Call("unpack2x16unorm", load_u32()); case 
VertexFormat::kSnorm16x2: return b.Call("unpack2x16snorm", load_u32()); case VertexFormat::kFloat16x2: return b.Call("unpack2x16float", load_u32()); case VertexFormat::kUnorm16x4: return b.vec4(b.Call("unpack2x16unorm", load_u32()), b.Call("unpack2x16unorm", load_next_u32())); case VertexFormat::kSnorm16x4: return b.vec4(b.Call("unpack2x16snorm", load_u32()), b.Call("unpack2x16snorm", load_next_u32())); case VertexFormat::kFloat16x4: return b.vec4(b.Call("unpack2x16float", load_u32()), b.Call("unpack2x16float", load_next_u32())); } TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast(format); return nullptr; } /// Generates an expression reading an aligned basic type (u32, i32, f32) from /// a vertex buffer. /// @param array_base the symbol of the variable holding the base array offset /// of the vertex array (each index is 4-bytes). /// @param offset the byte offset of the data from `buffer_base` /// @param buffer the index of the vertex buffer /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or /// VertexFormat::kFloat32 const ast::Expression* LoadPrimitive(Symbol array_base, uint32_t offset, uint32_t buffer, VertexFormat format) { const ast::Expression* u = nullptr; if ((offset & 3) == 0) { // Aligned load. 
const ast ::Expression* index = nullptr; if (offset > 0) { index = b.Add(array_base, u32(offset / 4)); } else { index = b.Expr(array_base); } u = b.IndexAccessor( b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); } else { // Unaligned load uint32_t offset_aligned = offset & ~3u; auto* low = LoadPrimitive(array_base, offset_aligned, buffer, VertexFormat::kUint32); auto* high = LoadPrimitive(array_base, offset_aligned + 4u, buffer, VertexFormat::kUint32); uint32_t shift = 8u * (offset & 3u); auto* low_shr = b.Shr(low, u32(shift)); auto* high_shl = b.Shl(high, u32(32u - shift)); u = b.Or(low_shr, high_shl); } switch (format) { case VertexFormat::kUint32: return u; case VertexFormat::kSint32: return b.Bitcast(b.ty.i32(), u); case VertexFormat::kFloat32: return b.Bitcast(b.ty.f32(), u); default: break; } TINT_UNREACHABLE(Transform, b.Diagnostics()) << "invalid format for LoadPrimitive" << static_cast(format); return nullptr; } /// Generates an expression reading a vec2/3/4 from a vertex buffer. /// @param array_base the symbol of the variable holding the base array offset /// of the vertex array (each index is 4-bytes). 
/// @param offset the byte offset of the data from `buffer_base` /// @param buffer the index of the vertex buffer /// @param element_stride stride between elements, in bytes /// @param base_type underlying AST type /// @param base_format underlying vertex format /// @param count how many elements the vector has const ast::Expression* LoadVec(Symbol array_base, uint32_t offset, uint32_t buffer, uint32_t element_stride, const ast::Type* base_type, VertexFormat base_format, uint32_t count) { utils::Vector expr_list; for (uint32_t i = 0; i < count; ++i) { // Offset read position by element_stride for each component uint32_t primitive_offset = offset + element_stride * i; expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format)); } return b.Construct(b.create(base_type, count), std::move(expr_list)); } /// Process a non-struct entry point parameter. /// Generate function-scope variables for location parameters, and record /// vertex_index and instance_index builtins if present. /// @param func the entry point function /// @param param the parameter to process void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) { if (ast::HasAttribute(param->attributes)) { // Create a function-scope variable to replace the parameter. auto func_var_sym = ctx.Clone(param->symbol); auto* func_var_type = ctx.Clone(param->type); auto* func_var = b.Var(func_var_sym, func_var_type); ctx.InsertFront(func->body->statements, b.Decl(func_var)); // Capture mapping from location to the new variable. LocationInfo info; info.expr = [this, func_var]() { return b.Expr(func_var); }; auto* sem = src->Sem().Get(param); info.type = sem->Type(); if (!sem->Location().has_value()) { TINT_ICE(Transform, b.Diagnostics()) << "Location missing value"; return; } location_info[sem->Location().value()] = info; } else if (auto* builtin = ast::GetAttribute(param->attributes)) { // Check for existing vertex_index and instance_index builtins. 
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; } else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) { instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; } new_function_parameters.Push(ctx.Clone(param)); } else { TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; } } /// Process a struct entry point parameter. /// If the struct has members with location attributes, push the parameter to /// a function-scope variable and create a new struct parameter without those /// attributes. Record expressions for members that are vertex_index and /// instance_index builtins. /// @param func the entry point function /// @param param the parameter to process /// @param struct_ty the structure type void ProcessStructParameter(const ast::Function* func, const ast::Parameter* param, const ast::Struct* struct_ty) { auto param_sym = ctx.Clone(param->symbol); // Process the struct members. bool has_locations = false; utils::Vector members_to_clone; for (auto* member : struct_ty->members) { auto member_sym = ctx.Clone(member->symbol); std::function member_expr = [this, param_sym, member_sym]() { return b.MemberAccessor(param_sym, member_sym); }; if (ast::HasAttribute(member->attributes)) { // Capture mapping from location to struct member. LocationInfo info; info.expr = member_expr; auto* sem = src->Sem().Get(member); info.type = sem->Type(); TINT_ASSERT(Transform, sem->Location().has_value()); location_info[sem->Location().value()] = info; has_locations = true; } else if (auto* builtin = ast::GetAttribute(member->attributes)) { // Check for existing vertex_index and instance_index builtins. 
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { vertex_index_expr = member_expr; } else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) { instance_index_expr = member_expr; } members_to_clone.Push(member); } else { TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; } } if (!has_locations) { // Nothing to do. new_function_parameters.Push(ctx.Clone(param)); return; } // Create a function-scope variable to replace the parameter. auto* func_var = b.Var(param_sym, ctx.Clone(param->type)); ctx.InsertFront(func->body->statements, b.Decl(func_var)); if (!members_to_clone.IsEmpty()) { // Create a new struct without the location attributes. utils::Vector new_members; for (auto* member : members_to_clone) { auto member_sym = ctx.Clone(member->symbol); auto* member_type = ctx.Clone(member->type); auto member_attrs = ctx.Clone(member->attributes); new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs))); } auto* new_struct = b.Structure(b.Sym(), new_members); // Create a new function parameter with this struct. auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct)); new_function_parameters.Push(new_param); // Copy values from the new parameter to the function-scope variable. for (auto* member : members_to_clone) { auto member_name = ctx.Clone(member->symbol); ctx.InsertFront(func->body->statements, b.Assign(b.MemberAccessor(func_var, member_name), b.MemberAccessor(new_param, member_name))); } } } /// Process an entry point function. /// @param func the entry point function void Process(const ast::Function* func) { if (func->body->Empty()) { return; } // Process entry point parameters. for (auto* param : func->params) { auto* sem = src->Sem().Get(param); if (auto* str = sem->Type()->As()) { ProcessStructParameter(func, param, str->Declaration()); } else { ProcessNonStructParameter(func, param); } } // Insert new parameters for vertex_index and instance_index if needed. 
if (!vertex_index_expr) { for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { if (layout.step_mode == VertexStepMode::kVertex) { auto name = b.Symbols().New("tint_pulling_vertex_index"); new_function_parameters.Push( b.Param(name, b.ty.u32(), utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)})); vertex_index_expr = [this, name]() { return b.Expr(name); }; break; } } } if (!instance_index_expr) { for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { if (layout.step_mode == VertexStepMode::kInstance) { auto name = b.Symbols().New("tint_pulling_instance_index"); new_function_parameters.Push( b.Param(name, b.ty.u32(), utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)})); instance_index_expr = [this, name]() { return b.Expr(name); }; break; } } } // Generate vertex pulling preamble. if (auto* block = CreateVertexPullingPreamble()) { ctx.InsertFront(func->body->statements, block); } // Rewrite the function header with the new parameters. auto func_sym = ctx.Clone(func->symbol); auto* ret_type = ctx.Clone(func->return_type); auto* body = ctx.Clone(func->body); auto attrs = ctx.Clone(func->attributes); auto ret_attrs = ctx.Clone(func->return_type_attributes); auto* new_func = b.create(func->source, func_sym, new_function_parameters, ret_type, body, std::move(attrs), std::move(ret_attrs)); ctx.Replace(func, new_func); } }; VertexPulling::VertexPulling() = default; VertexPulling::~VertexPulling() = default; Transform::ApplyResult VertexPulling::Apply(const Program* src, const DataMap& inputs, DataMap&) const { auto cfg = cfg_; if (auto* cfg_data = inputs.Get()) { cfg = *cfg_data; } return State{src, cfg}.Run(); } VertexPulling::Config::Config() = default; VertexPulling::Config::Config(const Config&) = default; VertexPulling::Config::~Config() = default; VertexPulling::Config& VertexPulling::Config::operator=(const Config&) = default; VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default; 
/// Constructs a layout descriptor from its three fields; the attribute list is
/// moved into the descriptor.
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
    uint32_t in_array_stride,
    VertexStepMode in_step_mode,
    std::vector<VertexAttributeDescriptor> in_attributes)
    : array_stride(in_array_stride),
      step_mode(in_step_mode),
      attributes(std::move(in_attributes)) {}

VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
    const VertexBufferLayoutDescriptor& other) = default;

VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
    const VertexBufferLayoutDescriptor& other) = default;

VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;

}  // namespace tint::transform