dawn-cmake/src/tint/transform/vertex_pulling.cc
Zhaoming Jiang f42d92a1d5 Tint: Refactor transform VertexPulling and its unit tests
This CL prepares the VertexPulling transform and its unit tests for
implementing f16 pipeline IO. This CL distinguishes the vertex format type
from the WGSL variable type of a vertex shader attribute (location input) in
the VertexPulling transform, as both `f32` and `f16` WGSL types would be
mapped to the float vertex format. This CL also splits the VertexPulling unit
tests by base vertex format (SInt, UInt and Float), making it easier to add
`f16` tests.

Bugs: tint:1473, tint:1502
Change-Id: I649deb61e8eb8dac6ebd653bf77ef96475334a56
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112520
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
2022-12-02 18:11:57 +00:00

946 lines
40 KiB
C++

// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/vertex_pulling.h"
#include <algorithm>
#include <utility>
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/math.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling);
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
namespace {
/// Scalar element type of a vertex shader location input, as declared in WGSL.
/// The WGSL type of an attribute is either this scalar type or a vector whose
/// elements are of this scalar type.
enum class BaseWGSLType {
    /// Not one of the element types listed below
    kInvalid,
    /// WGSL `u32`
    kU32,
    /// WGSL `i32`
    kI32,
    /// WGSL `f32`
    kF32,
};
/// Component data type produced when a vertex format is read from a buffer.
/// The type of a vertex format is either this type or a vector of this type.
enum class VertexDataType {
    /// Not a recognised vertex data type
    kInvalid,
    /// Unsigned integer
    kUInt,
    /// Signed integer
    kSInt,
    /// Unsigned normalized, signed normalized, and float
    kFloat,
};
/// Streams the lower-case textual name of a VertexFormat (for example "uint8x2").
/// A value that matches none of the enumerators is written as "<unknown>".
/// @param out the std::ostream to write to
/// @param format the VertexFormat to write
/// @returns out so calls can be chained
std::ostream& operator<<(std::ostream& out, VertexFormat format) {
    // Resolve the name first so that exactly one insertion into `out` is performed.
    auto name_of = [](VertexFormat f) -> const char* {
        switch (f) {
            // 8-bit components
            case VertexFormat::kUint8x2:
                return "uint8x2";
            case VertexFormat::kUint8x4:
                return "uint8x4";
            case VertexFormat::kSint8x2:
                return "sint8x2";
            case VertexFormat::kSint8x4:
                return "sint8x4";
            case VertexFormat::kUnorm8x2:
                return "unorm8x2";
            case VertexFormat::kUnorm8x4:
                return "unorm8x4";
            case VertexFormat::kSnorm8x2:
                return "snorm8x2";
            case VertexFormat::kSnorm8x4:
                return "snorm8x4";
            // 16-bit components
            case VertexFormat::kUint16x2:
                return "uint16x2";
            case VertexFormat::kUint16x4:
                return "uint16x4";
            case VertexFormat::kSint16x2:
                return "sint16x2";
            case VertexFormat::kSint16x4:
                return "sint16x4";
            case VertexFormat::kUnorm16x2:
                return "unorm16x2";
            case VertexFormat::kUnorm16x4:
                return "unorm16x4";
            case VertexFormat::kSnorm16x2:
                return "snorm16x2";
            case VertexFormat::kSnorm16x4:
                return "snorm16x4";
            case VertexFormat::kFloat16x2:
                return "float16x2";
            case VertexFormat::kFloat16x4:
                return "float16x4";
            // 32-bit components
            case VertexFormat::kFloat32:
                return "float32";
            case VertexFormat::kFloat32x2:
                return "float32x2";
            case VertexFormat::kFloat32x3:
                return "float32x3";
            case VertexFormat::kFloat32x4:
                return "float32x4";
            case VertexFormat::kUint32:
                return "uint32";
            case VertexFormat::kUint32x2:
                return "uint32x2";
            case VertexFormat::kUint32x3:
                return "uint32x3";
            case VertexFormat::kUint32x4:
                return "uint32x4";
            case VertexFormat::kSint32:
                return "sint32";
            case VertexFormat::kSint32x2:
                return "sint32x2";
            case VertexFormat::kSint32x3:
                return "sint32x3";
            case VertexFormat::kSint32x4:
                return "sint32x4";
        }
        return "<unknown>";
    };
    return out << name_of(format);
}
/// Type information of a vertex input attribute.
/// This describes the WGSL-side type of a location input: its scalar element
/// type plus the number of components. WGSLTypeOf() builds values of this type
/// and yields {BaseWGSLType::kInvalid, 0} for types it does not recognise.
struct AttributeWGSLType {
/// Scalar element type of the attribute
BaseWGSLType base_type;
/// Number of components of the attribute
uint32_t width; // 1 for scalar, 2+ for a vector
};
/// Type information of a vertex format.
/// This describes the buffer-side type of an attribute: the component data type
/// plus the number of components. VertexFormatTypeOf() builds values of this
/// type and yields {VertexDataType::kInvalid, 0} for formats it does not recognise.
struct VertexFormatType {
/// Component data type of the vertex format
VertexDataType base_type;
/// Number of components of the vertex format
uint32_t width; // 1 for scalar, 2+ for a vector
};
/// Checks whether the base types match between the WGSL variable and the vertex format.
/// Only the base types are compared; the component counts (`width`) are not considered here.
/// @param wgslType type information of the WGSL location input
/// @param vertexFormatType type information of the vertex format
/// @returns true if `f32` is paired with a float format, `u32` with an unsigned integer format,
/// or `i32` with a signed integer format; false for every other combination
bool IsTypeCompatible(AttributeWGSLType wgslType, VertexFormatType vertexFormatType) {
    // Work out which vertex data type the WGSL element type requires...
    VertexDataType required = VertexDataType::kInvalid;
    switch (wgslType.base_type) {
        case BaseWGSLType::kF32:
            required = VertexDataType::kFloat;
            break;
        case BaseWGSLType::kU32:
            required = VertexDataType::kUInt;
            break;
        case BaseWGSLType::kI32:
            required = VertexDataType::kSInt;
            break;
        default:
            // An unrecognised WGSL element type is never compatible.
            return false;
    }
    // ...then compare it with what the vertex format provides.
    return vertexFormatType.base_type == required;
}
/// Classifies a semantic type as a vertex attribute WGSL type.
/// @param ty the semantic type to classify
/// @returns the element type and component count for `i32`, `u32`, `f32` and vectors;
/// {BaseWGSLType::kInvalid, 0} for any other type. A vector whose element type is not one of
/// the three scalars is reported with base type kInvalid and the vector's own width.
AttributeWGSLType WGSLTypeOf(const sem::Type* ty) {
    // Vectors: classify the element type recursively and report the vector width.
    if (auto* vec = ty->As<sem::Vector>()) {
        const BaseWGSLType element = WGSLTypeOf(vec->type()).base_type;
        return {element, vec->Width()};
    }

    // Scalars: width 1 when recognised, width 0 otherwise.
    BaseWGSLType scalar = BaseWGSLType::kInvalid;
    if (ty->Is<sem::I32>()) {
        scalar = BaseWGSLType::kI32;
    } else if (ty->Is<sem::U32>()) {
        scalar = BaseWGSLType::kU32;
    } else if (ty->Is<sem::F32>()) {
        scalar = BaseWGSLType::kF32;
    }
    const uint32_t width = (scalar == BaseWGSLType::kInvalid) ? 0u : 1u;
    return {scalar, width};
}
/// Classifies a vertex format by component data type and component count.
/// Normalized (unorm / snorm) and float formats are all classified as VertexDataType::kFloat.
/// @param format the vertex format to classify
/// @returns the data type and width of the format, or {VertexDataType::kInvalid, 0} if the
/// value matches none of the VertexFormat enumerators
VertexFormatType VertexFormatTypeOf(VertexFormat format) {
    // Step 1: component data type, determined by the format family.
    VertexDataType data_type = VertexDataType::kInvalid;
    switch (format) {
        case VertexFormat::kUint8x2:
        case VertexFormat::kUint8x4:
        case VertexFormat::kUint16x2:
        case VertexFormat::kUint16x4:
        case VertexFormat::kUint32:
        case VertexFormat::kUint32x2:
        case VertexFormat::kUint32x3:
        case VertexFormat::kUint32x4:
            data_type = VertexDataType::kUInt;
            break;
        case VertexFormat::kSint8x2:
        case VertexFormat::kSint8x4:
        case VertexFormat::kSint16x2:
        case VertexFormat::kSint16x4:
        case VertexFormat::kSint32:
        case VertexFormat::kSint32x2:
        case VertexFormat::kSint32x3:
        case VertexFormat::kSint32x4:
            data_type = VertexDataType::kSInt;
            break;
        case VertexFormat::kUnorm8x2:
        case VertexFormat::kUnorm8x4:
        case VertexFormat::kSnorm8x2:
        case VertexFormat::kSnorm8x4:
        case VertexFormat::kUnorm16x2:
        case VertexFormat::kUnorm16x4:
        case VertexFormat::kSnorm16x2:
        case VertexFormat::kSnorm16x4:
        case VertexFormat::kFloat16x2:
        case VertexFormat::kFloat16x4:
        case VertexFormat::kFloat32:
        case VertexFormat::kFloat32x2:
        case VertexFormat::kFloat32x3:
        case VertexFormat::kFloat32x4:
            data_type = VertexDataType::kFloat;
            break;
    }

    // Step 2: component count, determined by the format's "xN" suffix (none means scalar).
    uint32_t width = 0;
    switch (format) {
        case VertexFormat::kUint32:
        case VertexFormat::kSint32:
        case VertexFormat::kFloat32:
            width = 1;
            break;
        case VertexFormat::kUint8x2:
        case VertexFormat::kUint16x2:
        case VertexFormat::kUint32x2:
        case VertexFormat::kSint8x2:
        case VertexFormat::kSint16x2:
        case VertexFormat::kSint32x2:
        case VertexFormat::kUnorm8x2:
        case VertexFormat::kSnorm8x2:
        case VertexFormat::kUnorm16x2:
        case VertexFormat::kSnorm16x2:
        case VertexFormat::kFloat16x2:
        case VertexFormat::kFloat32x2:
            width = 2;
            break;
        case VertexFormat::kUint32x3:
        case VertexFormat::kSint32x3:
        case VertexFormat::kFloat32x3:
            width = 3;
            break;
        case VertexFormat::kUint8x4:
        case VertexFormat::kUint16x4:
        case VertexFormat::kUint32x4:
        case VertexFormat::kSint8x4:
        case VertexFormat::kSint16x4:
        case VertexFormat::kSint32x4:
        case VertexFormat::kUnorm8x4:
        case VertexFormat::kSnorm8x4:
        case VertexFormat::kUnorm16x4:
        case VertexFormat::kSnorm16x4:
        case VertexFormat::kFloat16x4:
        case VertexFormat::kFloat32x4:
            width = 4;
            break;
    }

    // An unmatched value leaves both locals at their initial {kInvalid, 0} state.
    return {data_type, width};
}
} // namespace
/// PIMPL state for the transform
struct VertexPulling::State {
/// Constructor
/// Stores the source program pointer in `src` and a copy of the config in `cfg`.
/// The remaining members (`b`, `ctx`, ...) are set up by their default member initializers.
/// @param program the source program
/// @param c the VertexPulling config
State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {}
/// Runs the transform.
/// The source program must contain exactly one vertex-stage entry point. If it contains none,
/// or more than one, an error is added to the builder's diagnostics and the program is built
/// from the builder without any further processing.
/// @returns the program built from the target ProgramBuilder
ApplyResult Run() {
    // Shared failure path: record the error, then build the program from the builder as-is.
    auto fail = [&](const char* message) -> ApplyResult {
        b.Diagnostics().add_error(diag::System::Transform, message);
        return Program(std::move(b));
    };

    // Locate the one and only vertex entry point.
    const ast::Function* vertex_entry_point = nullptr;
    for (auto* candidate : src->AST().Functions()) {
        if (candidate->PipelineStage() != ast::PipelineStage::kVertex) {
            continue;
        }
        if (vertex_entry_point) {
            return fail("VertexPulling found more than one vertex entry point");
        }
        vertex_entry_point = candidate;
    }
    if (!vertex_entry_point) {
        return fail("Vertex stage entry point not found");
    }

    // Declare the storage buffers, rewrite the entry point, then clone the rest of the program.
    AddVertexStorageBuffers();
    Process(vertex_entry_point);
    ctx.Clone();
    return Program(std::move(b));
}
private:
/// LocationReplacement describes an ast::Variable replacement for a location input.
/// NOTE(review): no use of this struct is visible in this file - confirm whether it is still
/// needed.
struct LocationReplacement {
/// The variable to replace in the source Program
ast::Variable* from;
/// The replacement to use in the target ProgramBuilder
ast::Variable* to;
};
/// LocationInfo describes an input location
/// Entries are recorded per shader location in `location_info` while the entry point parameters
/// are processed, and are read back when the vertex pulling preamble is generated.
struct LocationInfo {
/// A builder that builds the expression that resolves to the (transformed) input location
/// Each call is expected to build a new expression node; the preamble calls it to obtain the
/// left-hand side of the assignment that stores the fetched attribute value.
std::function<const ast::Expression*()> expr;
/// The store type of the location variable
const sem::Type* type;
};
/// The source program
const Program* const src;
/// The transform config
VertexPulling::Config const cfg;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Map of shader location to the information needed to assign the pulled value to it
std::unordered_map<uint32_t, LocationInfo> location_info;
/// Builds an expression for the vertex index; empty until an existing or newly added
/// vertex_index builtin has been recorded
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
/// Builds an expression for the instance index; empty until an existing or newly added
/// instance_index builtin has been recorded
std::function<const ast::Expression*()> instance_index_expr = nullptr;
/// NOTE(review): no use of this symbol is visible in this file - confirm whether it is still
/// needed.
Symbol pulling_position_name;
/// Lazily created symbol of the array member of the vertex data struct, see GetStructBufferName()
Symbol struct_buffer_name;
/// Map of vertex buffer index to the symbol of its storage buffer variable, see
/// GetVertexBufferName()
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
/// The parameters of the rewritten entry point function
utils::Vector<const ast::Parameter*, 8> new_function_parameters;
/// Returns the symbol used for the storage buffer variable of a vertex buffer.
/// The lookup goes through utils::GetOrCreate() on `vertex_buffer_names`, with a factory that
/// makes a new symbol named "tint_pulling_vertex_buffer_<index>".
/// @param index index to append to buffer name
/// @returns the symbol for the vertex buffer with the given index
Symbol GetVertexBufferName(uint32_t index) {
    auto make_symbol = [&] {
        std::string name = "tint_pulling_vertex_buffer_";
        name += std::to_string(index);
        return b.Symbols().New(name);
    };
    return utils::GetOrCreate(vertex_buffer_names, index, make_symbol);
}
/// Returns the symbol of the array member of the vertex data struct.
/// The symbol is created from the name "tint_vertex_data" on the first call and stored in
/// `struct_buffer_name`; later calls return the stored symbol.
/// @returns the structure buffer symbol
Symbol GetStructBufferName() {
    if (struct_buffer_name.IsValid()) {
        return struct_buffer_name;
    }
    struct_buffer_name = b.Symbols().New("tint_vertex_data");
    return struct_buffer_name;
}
/// Declares the module-scope variables through which the vertex buffers are read.
/// One struct type named "TintVertexData", holding a single `array<u32>` member, is declared
/// and shared by all buffers. For each entry of `cfg.vertex_state` a read-only storage variable
/// of that struct type is declared, using the buffer index as its binding number and
/// `cfg.pulling_group` as its group.
void AddVertexStorageBuffers() {
    // The struct type shared by every vertex buffer variable.
    auto* vertex_data_struct = b.Structure(b.Symbols().New("TintVertexData"),
                                           utils::Vector{
                                               b.Member(GetStructBufferName(), b.ty.array<u32>()),
                                           });

    // One storage variable per configured vertex buffer.
    const auto buffer_count = cfg.vertex_state.size();
    for (uint32_t binding = 0; binding < buffer_count; ++binding) {
        b.GlobalVar(GetVertexBufferName(binding), b.ty.Of(vertex_data_struct),
                    ast::AddressSpace::kStorage, ast::Access::kRead, b.Binding(AInt(binding)),
                    b.Group(AInt(cfg.pulling_group)));
    }
}
/// Creates and returns the assignment to the variables from the buffers
/// For every vertex buffer in the config a `let` holding the base array index of the current
/// vertex / instance is emitted, followed by one assignment per attribute whose shader location
/// was recorded in `location_info`. Attributes with no recorded location are skipped.
/// @returns a block holding the generated statements, or nullptr if no statement was generated
/// or if an error was raised (array stride not a multiple of 4, or vertex format incompatible
/// with the WGSL type of the location input)
const ast::BlockStatement* CreateVertexPullingPreamble() {
// Assign by looking at the vertex descriptor to find attributes with
// matching location.
utils::Vector<const ast::Statement*, 8> stmts;
for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
// The buffer is addressed as an array of u32, so the stride has to be a whole number of
// 4-byte elements.
if ((buffer_layout.array_stride & 3) != 0) {
b.Diagnostics().add_error(
diag::System::Transform,
"WebGPU requires that vertex stride must be a multiple of 4 bytes, "
"but VertexPulling array stride for buffer " +
std::to_string(buffer_idx) + " was " +
std::to_string(buffer_layout.array_stride) + " bytes");
return nullptr;
}
// Pick the index that steps this buffer: the vertex index for kVertex step mode, the
// instance index for any other step mode.
auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
? vertex_index_expr()
: instance_index_expr();
// buffer_array_base is the base array offset for all the vertex
// attributes. These are units of uint (4 bytes).
auto buffer_array_base =
b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
// With a stride of exactly 4 bytes the index is already the array offset, so the
// multiplication is omitted.
auto* attribute_offset = index_expr;
if (buffer_layout.array_stride != 4) {
attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u));
}
// let buffer_array_base_n = <attribute_offset>
stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset)));
for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) {
// Skip attributes whose location is not an input of the entry point.
auto it = location_info.find(attribute_desc.shader_location);
if (it == location_info.end()) {
continue;
}
auto& var = it->second;
// Data type of the target WGSL variable
auto var_dt = WGSLTypeOf(var.type);
// Data type of the vertex stream attribute
auto fmt_dt = VertexFormatTypeOf(attribute_desc.format);
// Base types must match between the vertex stream and the WGSL variable
if (!IsTypeCompatible(var_dt, fmt_dt)) {
std::stringstream err;
err << "VertexAttributeDescriptor for location "
<< std::to_string(attribute_desc.shader_location) << " has format "
<< attribute_desc.format << " but shader expects "
<< var.type->FriendlyName(src->Symbols());
b.Diagnostics().add_error(diag::System::Transform, err.str());
return nullptr;
}
// Load the attribute value according to vertex format and convert the element type
// of result to match target WGSL variable. The result of `Fetch` should be of WGSL
// types `f32`, `i32`, `u32`, and their vectors.
auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx,
attribute_desc.format);
// The attribute value may not be of the desired vector width. If it is not, we'll
// need to either reduce the width with a swizzle, or append 0's and / or a 1.
auto* value = fetch;
if (var_dt.width < fmt_dt.width) {
// WGSL variable vector width is smaller than the loaded vector width
switch (var_dt.width) {
case 1:
value = b.MemberAccessor(fetch, "x");
break;
case 2:
value = b.MemberAccessor(fetch, "xy");
break;
case 3:
value = b.MemberAccessor(fetch, "xyz");
break;
default:
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width;
return nullptr;
}
} else if (var_dt.width > fmt_dt.width) {
// WGSL variable vector width is wider than the loaded vector width, do padding.
// The components of result vector variable, initialized with type-converted
// loaded data vector.
utils::Vector<const ast::Expression*, 8> values{fetch};
// Add padding elements. The result must be of vector types of signed/unsigned
// integer or float, so use the abstract integer or abstract float value to do
// padding.
// The fourth component (index 3) is padded with 1, every other component with 0.
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
if (var_dt.base_type == BaseWGSLType::kI32 ||
var_dt.base_type == BaseWGSLType::kU32) {
values.Push(b.Expr((i == 3) ? 1_a : 0_a));
} else {
values.Push(b.Expr((i == 3) ? 1.0_a : 0.0_a));
}
}
const ast::Type* target_ty = CreateASTTypeFor(ctx, var.type);
value = b.Construct(target_ty, values);
}
// Assign the value to the WGSL variable
stmts.Push(b.Assign(var.expr(), value));
}
}
if (stmts.IsEmpty()) {
return nullptr;
}
return b.Block(std::move(stmts));
}
/// Generates an expression reading a specific vertex format from a buffer. Any vertex format of
/// signed normalized, unsigned normalized, or float will result in `f32` or `vecN<f32>` WGSL
/// type.
/// @param array_base the symbol of the variable holding the base array offset
/// of the vertex array (each index is 4-bytes).
/// @param offset the byte offset of the data from `array_base`
/// @param buffer the index of the vertex buffer
/// @param format the vertex format to read
/// @returns the expression that reads the value, or nullptr (after raising TINT_UNREACHABLE)
/// if `format` matches none of the handled vertex formats
const ast::Expression* Fetch(Symbol array_base,
uint32_t offset,
uint32_t buffer,
VertexFormat format) {
// Returns a u32 loaded from byte `offset` relative to `array_base`.
auto load_u32 = [&] {
return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
};
// Returns a i32 loaded from byte `offset` relative to `array_base`.
auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); };
// Returns a u32 loaded from byte `offset + 4` relative to `array_base`.
auto load_next_u32 = [&] {
return LoadPrimitive(array_base, offset + 4, buffer, VertexFormat::kUint32);
};
// Returns a i32 loaded from byte `offset + 4` relative to `array_base`.
auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); };
// Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
// The low 16 bits are 0.
// The 4-byte aligned u32 that contains the first byte is loaded, and the switch on
// `offset & 3` selects the shift / mask that moves the two wanted bytes into the high half.
// The byte at the lower offset is treated as the less significant one.
auto load_u16_h = [&] {
auto low_u32_offset = offset & ~3u;
auto* low_u32 =
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
switch (offset & 3) {
case 0:
// Wanted bytes are the low half: shifting left also clears the low 16 bits.
return b.Shl(low_u32, 16_u);
case 1:
// Wanted bytes are the middle two: shift up by one byte, then mask off the rest.
return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u);
case 2:
// Wanted bytes are already the high half: only mask off the low half.
return b.And(low_u32, 0xffff0000_u);
default: { // 3:
// The value straddles two u32s: take the top byte of the low u32 and the bottom
// byte of the following u32.
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
VertexFormat::kUint32);
auto* shr = b.Shr(low_u32, 8_u);
auto* shl = b.Shl(high_u32, 24_u);
return b.And(b.Or(shl, shr), 0xffff0000_u);
}
}
};
// Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
// The high 16 bits are 0.
// Same approach as load_u16_h, with the two wanted bytes moved into the low half instead.
auto load_u16_l = [&] {
auto low_u32_offset = offset & ~3u;
auto* low_u32 =
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
switch (offset & 3) {
case 0:
// Wanted bytes are already the low half: only mask off the high half.
return b.And(low_u32, 0xffff_u);
case 1:
// Wanted bytes are the middle two: shift down by one byte, then mask off the rest.
return b.And(b.Shr(low_u32, 8_u), 0xffff_u);
case 2:
// Wanted bytes are the high half: shifting right also clears the high 16 bits.
return b.Shr(low_u32, 16_u);
default: { // 3:
// The value straddles two u32s: take the top byte of the low u32 and the bottom
// byte of the following u32.
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
VertexFormat::kUint32);
auto* shr = b.Shr(low_u32, 24_u);
auto* shl = b.Shl(high_u32, 8_u);
return b.And(b.Or(shl, shr), 0xffff_u);
}
}
};
// Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
// The low 16 bits are 0.
auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); };
// Assumptions are made that alignment must be at least as large as the size
// of a single component.
// In the bit-layout comments below each letter pair is one byte, written from the most
// significant byte to the least significant byte, and `s` marks sign-extension bits.
switch (format) {
// Basic primitives
case VertexFormat::kUint32:
case VertexFormat::kSint32:
case VertexFormat::kFloat32:
return LoadPrimitive(array_base, offset, buffer, format);
// Vectors of basic primitives
case VertexFormat::kUint32x2:
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2);
case VertexFormat::kUint32x3:
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3);
case VertexFormat::kUint32x4:
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4);
case VertexFormat::kSint32x2:
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2);
case VertexFormat::kSint32x3:
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3);
case VertexFormat::kSint32x4:
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4);
case VertexFormat::kFloat32x2:
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
2);
case VertexFormat::kFloat32x3:
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
3);
case VertexFormat::kFloat32x4:
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
4);
// Packed integer formats: splat the packed word, shift each lane left so that the wanted
// component becomes the top bits, then shift right to bring it down to the bottom bits.
case VertexFormat::kUint8x2: {
// yyxx0000, yyxx0000
auto* u16s = b.vec2<u32>(load_u16_h());
// xx000000, yyxx0000
auto* shl = b.Shl(u16s, b.vec2<u32>(8_u, 0_u));
// 000000xx, 000000yy
return b.Shr(shl, b.vec2<u32>(24_u));
}
case VertexFormat::kUint8x4: {
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
auto* u32s = b.vec4<u32>(load_u32());
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
auto* shl = b.Shl(u32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
// 000000xx, 000000yy, 000000zz, 000000ww
return b.Shr(shl, b.vec4<u32>(24_u));
}
case VertexFormat::kUint16x2: {
// yyyyxxxx, yyyyxxxx
auto* u32s = b.vec2<u32>(load_u32());
// xxxx0000, yyyyxxxx
auto* shl = b.Shl(u32s, b.vec2<u32>(16_u, 0_u));
// 0000xxxx, 0000yyyy
return b.Shr(shl, b.vec2<u32>(16_u));
}
case VertexFormat::kUint16x4: {
// yyyyxxxx, wwwwzzzz
auto* u32s = b.vec2<u32>(load_u32(), load_next_u32());
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
auto* xxyy = b.MemberAccessor(u32s, "xxyy");
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
// 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
return b.Shr(shl, b.vec4<u32>(16_u));
}
// Signed variants: same scheme on i32 lanes, so the right shift fills with sign bits
// (the `s` in the layout comments).
case VertexFormat::kSint8x2: {
// yyxx0000, yyxx0000
auto* i16s = b.vec2<i32>(load_i16_h());
// xx000000, yyxx0000
auto* shl = b.Shl(i16s, b.vec2<u32>(8_u, 0_u));
// ssssssxx, ssssssyy
return b.Shr(shl, b.vec2<u32>(24_u));
}
case VertexFormat::kSint8x4: {
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
auto* i32s = b.vec4<i32>(load_i32());
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
auto* shl = b.Shl(i32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
// ssssssxx, ssssssyy, sssssszz, ssssssww
return b.Shr(shl, b.vec4<u32>(24_u));
}
case VertexFormat::kSint16x2: {
// yyyyxxxx, yyyyxxxx
auto* i32s = b.vec2<i32>(load_i32());
// xxxx0000, yyyyxxxx
auto* shl = b.Shl(i32s, b.vec2<u32>(16_u, 0_u));
// ssssxxxx, ssssyyyy
return b.Shr(shl, b.vec2<u32>(16_u));
}
case VertexFormat::kSint16x4: {
// yyyyxxxx, wwwwzzzz
auto* i32s = b.vec2<i32>(load_i32(), load_next_i32());
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
auto* xxyy = b.MemberAccessor(i32s, "xxyy");
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
// ssssxxxx, ssssyyyy, sssszzzz, sssswwww
return b.Shr(shl, b.vec4<u32>(16_u));
}
// Normalized and half-float formats are decoded with the WGSL unpack builtins.
// The two-component 8-bit formats unpack a u32 whose low 16 bits hold the data and keep
// only the first two components of the result.
case VertexFormat::kUnorm8x2:
return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy");
case VertexFormat::kSnorm8x2:
return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy");
case VertexFormat::kUnorm8x4:
return b.Call("unpack4x8unorm", load_u32());
case VertexFormat::kSnorm8x4:
return b.Call("unpack4x8snorm", load_u32());
case VertexFormat::kUnorm16x2:
return b.Call("unpack2x16unorm", load_u32());
case VertexFormat::kSnorm16x2:
return b.Call("unpack2x16snorm", load_u32());
case VertexFormat::kFloat16x2:
return b.Call("unpack2x16float", load_u32());
// Four 16-bit components span two consecutive u32s; each u32 is unpacked separately and the
// two results are combined into one vec4<f32>.
case VertexFormat::kUnorm16x4:
return b.vec4<f32>(b.Call("unpack2x16unorm", load_u32()),
b.Call("unpack2x16unorm", load_next_u32()));
case VertexFormat::kSnorm16x4:
return b.vec4<f32>(b.Call("unpack2x16snorm", load_u32()),
b.Call("unpack2x16snorm", load_next_u32()));
case VertexFormat::kFloat16x4:
return b.vec4<f32>(b.Call("unpack2x16float", load_u32()),
b.Call("unpack2x16float", load_next_u32()));
}
TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format);
return nullptr;
}
/// Generates an expression reading one 32-bit basic type (u32, i32, f32) from a vertex buffer.
/// An `offset` that is a multiple of 4 becomes a single array element read. Any other `offset`
/// is assembled from the two array elements that the value straddles, using shifts and a
/// bitwise-or.
/// @param array_base the symbol of the variable holding the base array offset
/// of the vertex array (each index is 4-bytes).
/// @param offset the byte offset of the data from `array_base`
/// @param buffer the index of the vertex buffer
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
/// VertexFormat::kFloat32
/// @returns the load expression, or nullptr (after raising TINT_UNREACHABLE) for any other
/// format
const ast::Expression* LoadPrimitive(Symbol array_base,
                                     uint32_t offset,
                                     uint32_t buffer,
                                     VertexFormat format) {
    // Number of bytes by which `offset` overshoots the previous 4-byte boundary.
    const uint32_t byte_in_word = offset & 3u;

    // The raw 32 bits, as a u32 expression.
    const ast::Expression* word = nullptr;
    if (byte_in_word != 0) {
        // Unaligned load: combine the upper bytes of the element containing `offset` with the
        // lower bytes of the element that follows it.
        const uint32_t word_offset = offset - byte_in_word;
        auto* lo = LoadPrimitive(array_base, word_offset, buffer, VertexFormat::kUint32);
        auto* hi = LoadPrimitive(array_base, word_offset + 4u, buffer, VertexFormat::kUint32);
        const uint32_t lo_shift = 8u * byte_in_word;
        auto* lo_bits = b.Shr(lo, u32(lo_shift));
        auto* hi_bits = b.Shl(hi, u32(32u - lo_shift));
        word = b.Or(lo_bits, hi_bits);
    } else {
        // Aligned load: index the array directly, omitting the addition for offset 0.
        const ast::Expression* element_index = nullptr;
        if (offset == 0) {
            element_index = b.Expr(array_base);
        } else {
            element_index = b.Add(array_base, u32(offset / 4));
        }
        auto* buffer_array =
            b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName());
        word = b.IndexAccessor(buffer_array, element_index);
    }

    // Reinterpret the bits as the requested type.
    if (format == VertexFormat::kUint32) {
        return word;
    }
    if (format == VertexFormat::kSint32) {
        return b.Bitcast(b.ty.i32(), word);
    }
    if (format == VertexFormat::kFloat32) {
        return b.Bitcast(b.ty.f32(), word);
    }
    TINT_UNREACHABLE(Transform, b.Diagnostics())
        << "invalid format for LoadPrimitive" << static_cast<int>(format);
    return nullptr;
}
/// Generates an expression that constructs a vector from `count` consecutive primitive loads.
/// @param array_base the symbol of the variable holding the base array offset
/// of the vertex array (each index is 4-bytes).
/// @param offset the byte offset of the first component from `array_base`
/// @param buffer the index of the vertex buffer
/// @param element_stride stride between elements, in bytes
/// @param base_type underlying AST type
/// @param base_format underlying vertex format
/// @param count how many elements the vector has
/// @returns the vector constructor expression
const ast::Expression* LoadVec(Symbol array_base,
                               uint32_t offset,
                               uint32_t buffer,
                               uint32_t element_stride,
                               const ast::Type* base_type,
                               VertexFormat base_format,
                               uint32_t count) {
    utils::Vector<const ast::Expression*, 8> components;
    // Walk the read position forward by one stride per component.
    uint32_t component_offset = offset;
    for (uint32_t remaining = count; remaining > 0; --remaining) {
        components.Push(LoadPrimitive(array_base, component_offset, buffer, base_format));
        component_offset += element_stride;
    }
    auto* vec_type = b.create<ast::Vector>(base_type, count);
    return b.Construct(vec_type, std::move(components));
}
/// Process a non-struct entry point parameter.
/// A parameter with a location attribute is replaced by a function-scope variable of the same
/// name and type, and is recorded in `location_info`. A parameter with a builtin attribute is
/// kept as a parameter of the new function; vertex_index and instance_index builtins are also
/// recorded so that the preamble can reference them. Any other parameter raises an ICE.
/// @param func the entry point function
/// @param param the parameter to process
void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) {
    if (ast::HasAttribute<ast::LocationAttribute>(param->attributes)) {
        // Declare the function-scope variable that stands in for the parameter.
        auto var_sym = ctx.Clone(param->symbol);
        auto* var_type = ctx.Clone(param->type);
        auto* replacement = b.Var(var_sym, var_type);
        ctx.InsertFront(func->body->statements, b.Decl(replacement));

        // Remember how to reference the variable, keyed by its location.
        LocationInfo info;
        info.expr = [this, replacement]() { return b.Expr(replacement); };
        auto* param_sem = src->Sem().Get<sem::Parameter>(param);
        info.type = param_sem->Type();
        if (!param_sem->Location().has_value()) {
            TINT_ICE(Transform, b.Diagnostics()) << "Location missing value";
            return;
        }
        location_info[param_sem->Location().value()] = info;
        return;
    }

    auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes);
    if (!builtin) {
        TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
        return;
    }

    // Reuse existing vertex_index and instance_index builtins instead of adding new ones.
    auto param_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
    switch (builtin->builtin) {
        case ast::BuiltinValue::kVertexIndex:
            vertex_index_expr = param_expr;
            break;
        case ast::BuiltinValue::kInstanceIndex:
            instance_index_expr = param_expr;
            break;
        default:
            break;
    }
    new_function_parameters.Push(ctx.Clone(param));
}
/// Process a struct entry point parameter.
/// If the struct has members with location attributes, push the parameter to
/// a function-scope variable and create a new struct parameter without those
/// attributes. Record expressions for members that are vertex_index and
/// instance_index builtins.
/// If the struct has no member with a location attribute, the parameter is cloned unchanged
/// into the new parameter list. A member that has neither a location nor a builtin attribute
/// raises an ICE.
/// @param func the entry point function
/// @param param the parameter to process
/// @param struct_ty the structure type
void ProcessStructParameter(const ast::Function* func,
const ast::Parameter* param,
const ast::Struct* struct_ty) {
auto param_sym = ctx.Clone(param->symbol);
// Process the struct members.
bool has_locations = false;
// Builtin members, which have to be carried over to the replacement struct.
utils::Vector<const ast::StructMember*, 8> members_to_clone;
for (auto* member : struct_ty->members) {
auto member_sym = ctx.Clone(member->symbol);
// Builds `<param>.<member>`; param_sym is reused below as the name of the function-scope
// variable that replaces the parameter.
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
return b.MemberAccessor(param_sym, member_sym);
};
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
// Capture mapping from location to struct member.
LocationInfo info;
info.expr = member_expr;
auto* sem = src->Sem().Get(member);
info.type = sem->Type();
TINT_ASSERT(Transform, sem->Location().has_value());
location_info[sem->Location().value()] = info;
has_locations = true;
} else if (auto* builtin =
ast::GetAttribute<ast::BuiltinAttribute>(member->attributes)) {
// Check for existing vertex_index and instance_index builtins.
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) {
vertex_index_expr = member_expr;
} else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) {
instance_index_expr = member_expr;
}
members_to_clone.Push(member);
} else {
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
}
}
if (!has_locations) {
// Nothing to do.
new_function_parameters.Push(ctx.Clone(param));
return;
}
// Create a function-scope variable to replace the parameter.
auto* func_var = b.Var(param_sym, ctx.Clone(param->type));
ctx.InsertFront(func->body->statements, b.Decl(func_var));
// If the struct only had location members, no replacement parameter is needed at all.
if (!members_to_clone.IsEmpty()) {
// Create a new struct without the location attributes.
utils::Vector<const ast::StructMember*, 8> new_members;
for (auto* member : members_to_clone) {
auto member_sym = ctx.Clone(member->symbol);
auto* member_type = ctx.Clone(member->type);
auto member_attrs = ctx.Clone(member->attributes);
new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs)));
}
auto* new_struct = b.Structure(b.Sym(), new_members);
// Create a new function parameter with this struct.
auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct));
new_function_parameters.Push(new_param);
// Copy values from the new parameter to the function-scope variable.
for (auto* member : members_to_clone) {
auto member_name = ctx.Clone(member->symbol);
ctx.InsertFront(func->body->statements,
b.Assign(b.MemberAccessor(func_var, member_name),
b.MemberAccessor(new_param, member_name)));
}
}
}
/// Process an entry point function.
/// Location inputs are turned into function-scope variables, vertex_index / instance_index
/// parameters are added when the configured buffers need them and the entry point does not
/// already declare them, the vertex pulling preamble is inserted at the front of the body, and
/// the function is replaced by one with the new parameter list. A function with an empty body
/// is left untouched.
/// @param func the entry point function
void Process(const ast::Function* func) {
    if (func->body->Empty()) {
        return;
    }

    // Process entry point parameters.
    for (auto* param : func->params) {
        auto* param_type = src->Sem().Get(param)->Type();
        if (auto* str = param_type->As<sem::Struct>()) {
            ProcessStructParameter(func, param, str->Declaration());
            continue;
        }
        ProcessNonStructParameter(func, param);
    }

    // Adds a new u32 builtin parameter for `builtin` if no index expression has been recorded
    // yet and at least one vertex buffer uses `step_mode`.
    using ExprBuilder = std::function<const ast::Expression*()>;
    auto ensure_index_param = [&](ExprBuilder& index_expr, VertexStepMode step_mode,
                                  const char* param_name, ast::BuiltinValue builtin) {
        if (index_expr) {
            return;
        }
        for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
            if (layout.step_mode != step_mode) {
                continue;
            }
            auto name = b.Symbols().New(param_name);
            new_function_parameters.Push(
                b.Param(name, b.ty.u32(), utils::Vector{b.Builtin(builtin)}));
            index_expr = [this, name]() { return b.Expr(name); };
            return;
        }
    };

    // Insert new parameters for vertex_index and instance_index if needed.
    ensure_index_param(vertex_index_expr, VertexStepMode::kVertex, "tint_pulling_vertex_index",
                       ast::BuiltinValue::kVertexIndex);
    ensure_index_param(instance_index_expr, VertexStepMode::kInstance,
                       "tint_pulling_instance_index", ast::BuiltinValue::kInstanceIndex);

    // Generate vertex pulling preamble.
    if (auto* preamble = CreateVertexPullingPreamble()) {
        ctx.InsertFront(func->body->statements, preamble);
    }

    // Rewrite the function header with the new parameters.
    auto new_sym = ctx.Clone(func->symbol);
    auto* new_ret_type = ctx.Clone(func->return_type);
    auto* new_body = ctx.Clone(func->body);
    auto new_attrs = ctx.Clone(func->attributes);
    auto new_ret_attrs = ctx.Clone(func->return_type_attributes);
    ctx.Replace(func, b.create<ast::Function>(func->source, new_sym, new_function_parameters,
                                              new_ret_type, new_body, std::move(new_attrs),
                                              std::move(new_ret_attrs)));
}
};
// Defaulted constructor and destructor, defined out of line in this translation unit.
VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default;
// Runs the transform on `src`.
// A Config found in `inputs` takes precedence over the transform's own `cfg_`. The output
// DataMap is not used.
Transform::ApplyResult VertexPulling::Apply(const Program* src,
                                            const DataMap& inputs,
                                            DataMap&) const {
    auto* supplied = inputs.Get<Config>();
    const Config& effective = supplied ? *supplied : cfg_;
    return State{src, effective}.Run();
}
// Defaulted special member functions of VertexPulling::Config, defined out of line in this
// translation unit.
VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
VertexPulling::Config& VertexPulling::Config::operator=(const Config&) = default;
// Special member functions of VertexBufferLayoutDescriptor, defined out of line in this
// translation unit.
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
// Member-wise constructor: stores the stride and step mode, and takes ownership of the
// attribute list by moving it into `attributes`.
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
uint32_t in_array_stride,
VertexStepMode in_step_mode,
std::vector<VertexAttributeDescriptor> in_attributes)
: array_stride(in_array_stride),
step_mode(in_step_mode),
attributes(std::move(in_attributes)) {}
// Defaulted copy constructor, copy assignment and destructor.
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
const VertexBufferLayoutDescriptor& other) = default;
VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
const VertexBufferLayoutDescriptor& other) = default;
VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
} // namespace tint::transform