mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-05-14 11:21:40 +00:00
Transforms are supposed to be immutable, operating on the DataMaps provided for input and output, so make the methods const. Add a ShouldRun() method which the Manager can use to skip over transforms that do not need to be run. Change-Id: I320ac964577e94ac988748d8aca85bd43ee8d3b5 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77120 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
963 lines
36 KiB
C++
963 lines
36 KiB
C++
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
#include "src/transform/vertex_pulling.h"
|
|
|
|
#include <algorithm>
|
|
#include <utility>
|
|
|
|
#include "src/ast/assignment_statement.h"
|
|
#include "src/ast/bitcast_expression.h"
|
|
#include "src/ast/variable_decl_statement.h"
|
|
#include "src/program_builder.h"
|
|
#include "src/sem/variable.h"
|
|
#include "src/utils/map.h"
|
|
#include "src/utils/math.h"
|
|
|
|
// Type-info instantiations for the transform and its configuration data.
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling);
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
|
|
|
|
namespace tint {
|
|
namespace transform {
|
|
|
|
namespace {
|
|
|
|
/// The base type of a component.
/// The format type is either this type or a vector of this type.
enum class BaseType {
  /// Not one of the supported component types.
  kInvalid,
  /// Unsigned 32-bit integer component.
  kU32,
  /// Signed 32-bit integer component.
  kI32,
  /// 32-bit floating point component.
  kF32,
};
|
|
|
|
/// Writes the BaseType to the std::ostream.
|
|
/// @param out the std::ostream to write to
|
|
/// @param format the BaseType to write
|
|
/// @returns out so calls can be chained
|
|
std::ostream& operator<<(std::ostream& out, BaseType format) {
|
|
switch (format) {
|
|
case BaseType::kInvalid:
|
|
return out << "invalid";
|
|
case BaseType::kU32:
|
|
return out << "u32";
|
|
case BaseType::kI32:
|
|
return out << "i32";
|
|
case BaseType::kF32:
|
|
return out << "f32";
|
|
}
|
|
return out << "<unknown>";
|
|
}
|
|
|
|
/// Writes the VertexFormat to the std::ostream.
|
|
/// @param out the std::ostream to write to
|
|
/// @param format the VertexFormat to write
|
|
/// @returns out so calls can be chained
|
|
std::ostream& operator<<(std::ostream& out, VertexFormat format) {
|
|
switch (format) {
|
|
case VertexFormat::kUint8x2:
|
|
return out << "uint8x2";
|
|
case VertexFormat::kUint8x4:
|
|
return out << "uint8x4";
|
|
case VertexFormat::kSint8x2:
|
|
return out << "sint8x2";
|
|
case VertexFormat::kSint8x4:
|
|
return out << "sint8x4";
|
|
case VertexFormat::kUnorm8x2:
|
|
return out << "unorm8x2";
|
|
case VertexFormat::kUnorm8x4:
|
|
return out << "unorm8x4";
|
|
case VertexFormat::kSnorm8x2:
|
|
return out << "snorm8x2";
|
|
case VertexFormat::kSnorm8x4:
|
|
return out << "snorm8x4";
|
|
case VertexFormat::kUint16x2:
|
|
return out << "uint16x2";
|
|
case VertexFormat::kUint16x4:
|
|
return out << "uint16x4";
|
|
case VertexFormat::kSint16x2:
|
|
return out << "sint16x2";
|
|
case VertexFormat::kSint16x4:
|
|
return out << "sint16x4";
|
|
case VertexFormat::kUnorm16x2:
|
|
return out << "unorm16x2";
|
|
case VertexFormat::kUnorm16x4:
|
|
return out << "unorm16x4";
|
|
case VertexFormat::kSnorm16x2:
|
|
return out << "snorm16x2";
|
|
case VertexFormat::kSnorm16x4:
|
|
return out << "snorm16x4";
|
|
case VertexFormat::kFloat16x2:
|
|
return out << "float16x2";
|
|
case VertexFormat::kFloat16x4:
|
|
return out << "float16x4";
|
|
case VertexFormat::kFloat32:
|
|
return out << "float32";
|
|
case VertexFormat::kFloat32x2:
|
|
return out << "float32x2";
|
|
case VertexFormat::kFloat32x3:
|
|
return out << "float32x3";
|
|
case VertexFormat::kFloat32x4:
|
|
return out << "float32x4";
|
|
case VertexFormat::kUint32:
|
|
return out << "uint32";
|
|
case VertexFormat::kUint32x2:
|
|
return out << "uint32x2";
|
|
case VertexFormat::kUint32x3:
|
|
return out << "uint32x3";
|
|
case VertexFormat::kUint32x4:
|
|
return out << "uint32x4";
|
|
case VertexFormat::kSint32:
|
|
return out << "sint32";
|
|
case VertexFormat::kSint32x2:
|
|
return out << "sint32x2";
|
|
case VertexFormat::kSint32x3:
|
|
return out << "sint32x3";
|
|
case VertexFormat::kSint32x4:
|
|
return out << "sint32x4";
|
|
}
|
|
return out << "<unknown>";
|
|
}
|
|
|
|
/// A vertex attribute data format.
|
|
struct DataType {
|
|
BaseType base_type;
|
|
uint32_t width; // 1 for scalar, 2+ for a vector
|
|
};
|
|
|
|
/// @param ty the semantic type to classify
/// @returns the DataType of `ty`: width 1 for i32, u32 and f32 scalars, the
/// element's base type and the vector width for vectors, and
/// `{BaseType::kInvalid, 0}` for any other type
DataType DataTypeOf(const sem::Type* ty) {
  if (ty->Is<sem::I32>()) {
    return {BaseType::kI32, 1};
  }
  if (ty->Is<sem::U32>()) {
    return {BaseType::kU32, 1};
  }
  if (ty->Is<sem::F32>()) {
    return {BaseType::kF32, 1};
  }
  if (auto* vec = ty->As<sem::Vector>()) {
    // Classify the element type recursively, keeping only its base type.
    return {DataTypeOf(vec->type()).base_type, vec->Width()};
  }
  return {BaseType::kInvalid, 0};
}
|
|
|
|
/// @param format the vertex stream format to classify
/// @returns the DataType that a value of `format` has once loaded into the
/// shader: uint/sint formats of any bit width map to u32/i32, and unorm,
/// snorm and float formats map to f32. Returns `{BaseType::kInvalid, 0}` for a
/// value not handled by the switch.
DataType DataTypeOf(VertexFormat format) {
  switch (format) {
    case VertexFormat::kUint32:
      return {BaseType::kU32, 1};
    case VertexFormat::kUint8x2:
    case VertexFormat::kUint16x2:
    case VertexFormat::kUint32x2:
      return {BaseType::kU32, 2};
    case VertexFormat::kUint32x3:
      return {BaseType::kU32, 3};
    case VertexFormat::kUint8x4:
    case VertexFormat::kUint16x4:
    case VertexFormat::kUint32x4:
      return {BaseType::kU32, 4};
    case VertexFormat::kSint32:
      return {BaseType::kI32, 1};
    case VertexFormat::kSint8x2:
    case VertexFormat::kSint16x2:
    case VertexFormat::kSint32x2:
      return {BaseType::kI32, 2};
    case VertexFormat::kSint32x3:
      return {BaseType::kI32, 3};
    case VertexFormat::kSint8x4:
    case VertexFormat::kSint16x4:
    case VertexFormat::kSint32x4:
      return {BaseType::kI32, 4};
    case VertexFormat::kFloat32:
      return {BaseType::kF32, 1};
    case VertexFormat::kUnorm8x2:
    case VertexFormat::kSnorm8x2:
    case VertexFormat::kUnorm16x2:
    case VertexFormat::kSnorm16x2:
    case VertexFormat::kFloat16x2:
    case VertexFormat::kFloat32x2:
      return {BaseType::kF32, 2};
    case VertexFormat::kFloat32x3:
      return {BaseType::kF32, 3};
    case VertexFormat::kUnorm8x4:
    case VertexFormat::kSnorm8x4:
    case VertexFormat::kUnorm16x4:
    case VertexFormat::kSnorm16x4:
    case VertexFormat::kFloat16x4:
    case VertexFormat::kFloat32x4:
      return {BaseType::kF32, 4};
  }
  return {BaseType::kInvalid, 0};
}
|
|
|
|
struct State {
|
|
State(CloneContext& context, const VertexPulling::Config& c)
|
|
: ctx(context), cfg(c) {}
|
|
State(const State&) = default;
|
|
~State() = default;
|
|
|
|
/// LocationReplacement describes an ast::Variable replacement for a
|
|
/// location input.
|
|
struct LocationReplacement {
|
|
/// The variable to replace in the source Program
|
|
ast::Variable* from;
|
|
/// The replacement to use in the target ProgramBuilder
|
|
ast::Variable* to;
|
|
};
|
|
|
|
struct LocationInfo {
|
|
std::function<const ast::Expression*()> expr;
|
|
const sem::Type* type;
|
|
};
|
|
|
|
CloneContext& ctx;
|
|
VertexPulling::Config const cfg;
|
|
std::unordered_map<uint32_t, LocationInfo> location_info;
|
|
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
|
|
std::function<const ast::Expression*()> instance_index_expr = nullptr;
|
|
Symbol pulling_position_name;
|
|
Symbol struct_buffer_name;
|
|
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
|
|
ast::VariableList new_function_parameters;
|
|
|
|
/// Generate the vertex buffer binding name
|
|
/// @param index index to append to buffer name
|
|
Symbol GetVertexBufferName(uint32_t index) {
|
|
return utils::GetOrCreate(vertex_buffer_names, index, [&] {
|
|
static const char kVertexBufferNamePrefix[] =
|
|
"tint_pulling_vertex_buffer_";
|
|
return ctx.dst->Symbols().New(kVertexBufferNamePrefix +
|
|
std::to_string(index));
|
|
});
|
|
}
|
|
|
|
/// Lazily generates the structure buffer symbol
|
|
Symbol GetStructBufferName() {
|
|
if (!struct_buffer_name.IsValid()) {
|
|
static const char kStructBufferName[] = "tint_vertex_data";
|
|
struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
|
|
}
|
|
return struct_buffer_name;
|
|
}
|
|
|
|
/// Adds storage buffer decorated variables for the vertex buffers
|
|
void AddVertexStorageBuffers() {
|
|
// Creating the struct type
|
|
static const char kStructName[] = "TintVertexData";
|
|
auto* struct_type = ctx.dst->Structure(
|
|
ctx.dst->Symbols().New(kStructName),
|
|
{
|
|
ctx.dst->Member(GetStructBufferName(),
|
|
ctx.dst->ty.array<ProgramBuilder::u32>(4)),
|
|
});
|
|
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
|
|
// The decorated variable with struct type
|
|
ctx.dst->Global(
|
|
GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
|
|
ast::StorageClass::kStorage, ast::Access::kRead,
|
|
ast::DecorationList{
|
|
ctx.dst->create<ast::BindingDecoration>(i),
|
|
ctx.dst->create<ast::GroupDecoration>(cfg.pulling_group),
|
|
});
|
|
}
|
|
}
|
|
|
|
/// Creates and returns the assignment to the variables from the buffers
|
|
ast::BlockStatement* CreateVertexPullingPreamble() {
|
|
// Assign by looking at the vertex descriptor to find attributes with
|
|
// matching location.
|
|
|
|
ast::StatementList stmts;
|
|
|
|
for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size();
|
|
++buffer_idx) {
|
|
const VertexBufferLayoutDescriptor& buffer_layout =
|
|
cfg.vertex_state[buffer_idx];
|
|
|
|
if ((buffer_layout.array_stride & 3) != 0) {
|
|
ctx.dst->Diagnostics().add_error(
|
|
diag::System::Transform,
|
|
"WebGPU requires that vertex stride must be a multiple of 4 bytes, "
|
|
"but VertexPulling array stride for buffer " +
|
|
std::to_string(buffer_idx) + " was " +
|
|
std::to_string(buffer_layout.array_stride) + " bytes");
|
|
return nullptr;
|
|
}
|
|
|
|
auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
|
|
? vertex_index_expr()
|
|
: instance_index_expr();
|
|
|
|
// buffer_array_base is the base array offset for all the vertex
|
|
// attributes. These are units of uint (4 bytes).
|
|
auto buffer_array_base = ctx.dst->Symbols().New(
|
|
"buffer_array_base_" + std::to_string(buffer_idx));
|
|
|
|
auto* attribute_offset = index_expr;
|
|
if (buffer_layout.array_stride != 4) {
|
|
attribute_offset =
|
|
ctx.dst->Mul(index_expr, buffer_layout.array_stride / 4u);
|
|
}
|
|
|
|
// let pulling_offset_n = <attribute_offset>
|
|
stmts.emplace_back(ctx.dst->Decl(
|
|
ctx.dst->Const(buffer_array_base, nullptr, attribute_offset)));
|
|
|
|
for (const VertexAttributeDescriptor& attribute_desc :
|
|
buffer_layout.attributes) {
|
|
auto it = location_info.find(attribute_desc.shader_location);
|
|
if (it == location_info.end()) {
|
|
continue;
|
|
}
|
|
auto& var = it->second;
|
|
|
|
// Data type of the target WGSL variable
|
|
auto var_dt = DataTypeOf(var.type);
|
|
// Data type of the vertex stream attribute
|
|
auto fmt_dt = DataTypeOf(attribute_desc.format);
|
|
|
|
// Base types must match between the vertex stream and the WGSL variable
|
|
if (var_dt.base_type != fmt_dt.base_type) {
|
|
std::stringstream err;
|
|
err << "VertexAttributeDescriptor for location "
|
|
<< std::to_string(attribute_desc.shader_location)
|
|
<< " has format " << attribute_desc.format
|
|
<< " but shader expects "
|
|
<< var.type->FriendlyName(ctx.src->Symbols());
|
|
ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
|
|
return nullptr;
|
|
}
|
|
|
|
// Load the attribute value
|
|
auto* fetch = Fetch(buffer_array_base, attribute_desc.offset,
|
|
buffer_idx, attribute_desc.format);
|
|
|
|
// The attribute value may not be of the desired vector width. If it is
|
|
// not, we'll need to either reduce the width with a swizzle, or append
|
|
// 0's and / or a 1.
|
|
auto* value = fetch;
|
|
if (var_dt.width < fmt_dt.width) {
|
|
// WGSL variable vector width is smaller than the loaded vector width
|
|
switch (var_dt.width) {
|
|
case 1:
|
|
value = ctx.dst->MemberAccessor(fetch, "x");
|
|
break;
|
|
case 2:
|
|
value = ctx.dst->MemberAccessor(fetch, "xy");
|
|
break;
|
|
case 3:
|
|
value = ctx.dst->MemberAccessor(fetch, "xyz");
|
|
break;
|
|
default:
|
|
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
|
<< var_dt.width;
|
|
return nullptr;
|
|
}
|
|
} else if (var_dt.width > fmt_dt.width) {
|
|
// WGSL variable vector width is wider than the loaded vector width
|
|
const ast::Type* ty = nullptr;
|
|
ast::ExpressionList values{fetch};
|
|
switch (var_dt.base_type) {
|
|
case BaseType::kI32:
|
|
ty = ctx.dst->ty.i32();
|
|
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
|
|
values.emplace_back(ctx.dst->Expr((i == 3) ? 1 : 0));
|
|
}
|
|
break;
|
|
case BaseType::kU32:
|
|
ty = ctx.dst->ty.u32();
|
|
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
|
|
values.emplace_back(ctx.dst->Expr((i == 3) ? 1u : 0u));
|
|
}
|
|
break;
|
|
case BaseType::kF32:
|
|
ty = ctx.dst->ty.f32();
|
|
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
|
|
values.emplace_back(ctx.dst->Expr((i == 3) ? 1.f : 0.f));
|
|
}
|
|
break;
|
|
default:
|
|
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
|
<< var_dt.base_type;
|
|
return nullptr;
|
|
}
|
|
value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
|
|
}
|
|
|
|
// Assign the value to the WGSL variable
|
|
stmts.emplace_back(ctx.dst->Assign(var.expr(), value));
|
|
}
|
|
}
|
|
|
|
if (stmts.empty()) {
|
|
return nullptr;
|
|
}
|
|
|
|
return ctx.dst->create<ast::BlockStatement>(stmts);
|
|
}
|
|
|
|
/// Generates an expression reading from a buffer a specific format.
|
|
/// @param array_base the symbol of the variable holding the base array offset
|
|
/// of the vertex array (each index is 4-bytes).
|
|
/// @param offset the byte offset of the data from `buffer_base`
|
|
/// @param buffer the index of the vertex buffer
|
|
/// @param format the format to read
|
|
const ast::Expression* Fetch(Symbol array_base,
|
|
uint32_t offset,
|
|
uint32_t buffer,
|
|
VertexFormat format) {
|
|
using u32 = ProgramBuilder::u32;
|
|
using i32 = ProgramBuilder::i32;
|
|
using f32 = ProgramBuilder::f32;
|
|
|
|
// Returns a u32 loaded from buffer_base + offset.
|
|
auto load_u32 = [&] {
|
|
return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
|
|
};
|
|
|
|
// Returns a i32 loaded from buffer_base + offset.
|
|
auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
|
|
|
|
// Returns a u32 loaded from buffer_base + offset + 4.
|
|
auto load_next_u32 = [&] {
|
|
return LoadPrimitive(array_base, offset + 4, buffer,
|
|
VertexFormat::kUint32);
|
|
};
|
|
|
|
// Returns a i32 loaded from buffer_base + offset + 4.
|
|
auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
|
|
|
|
// Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
|
|
// The low 16 bits are 0.
|
|
// `min_alignment` must be a power of two.
|
|
// `offset` must be `min_alignment` bytes aligned.
|
|
auto load_u16_h = [&] {
|
|
auto low_u32_offset = offset & ~3u;
|
|
auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
|
|
VertexFormat::kUint32);
|
|
switch (offset & 3) {
|
|
case 0:
|
|
return ctx.dst->Shl(low_u32, 16u);
|
|
case 1:
|
|
return ctx.dst->And(ctx.dst->Shl(low_u32, 8u), 0xffff0000u);
|
|
case 2:
|
|
return ctx.dst->And(low_u32, 0xffff0000u);
|
|
default: { // 3:
|
|
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
|
|
VertexFormat::kUint32);
|
|
auto* shr = ctx.dst->Shr(low_u32, 8u);
|
|
auto* shl = ctx.dst->Shl(high_u32, 24u);
|
|
return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000u);
|
|
}
|
|
}
|
|
};
|
|
|
|
// Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
|
|
// The high 16 bits are 0.
|
|
auto load_u16_l = [&] {
|
|
auto low_u32_offset = offset & ~3u;
|
|
auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
|
|
VertexFormat::kUint32);
|
|
switch (offset & 3) {
|
|
case 0:
|
|
return ctx.dst->And(low_u32, 0xffffu);
|
|
case 1:
|
|
return ctx.dst->And(ctx.dst->Shr(low_u32, 8u), 0xffffu);
|
|
case 2:
|
|
return ctx.dst->Shr(low_u32, 16u);
|
|
default: { // 3:
|
|
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
|
|
VertexFormat::kUint32);
|
|
auto* shr = ctx.dst->Shr(low_u32, 24u);
|
|
auto* shl = ctx.dst->Shl(high_u32, 8u);
|
|
return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffffu);
|
|
}
|
|
}
|
|
};
|
|
|
|
// Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
|
|
// The low 16 bits are 0.
|
|
auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
|
|
|
|
// Assumptions are made that alignment must be at least as large as the size
|
|
// of a single component.
|
|
switch (format) {
|
|
// Basic primitives
|
|
case VertexFormat::kUint32:
|
|
case VertexFormat::kSint32:
|
|
case VertexFormat::kFloat32:
|
|
return LoadPrimitive(array_base, offset, buffer, format);
|
|
|
|
// Vectors of basic primitives
|
|
case VertexFormat::kUint32x2:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
|
|
VertexFormat::kUint32, 2);
|
|
case VertexFormat::kUint32x3:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
|
|
VertexFormat::kUint32, 3);
|
|
case VertexFormat::kUint32x4:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
|
|
VertexFormat::kUint32, 4);
|
|
case VertexFormat::kSint32x2:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
|
|
VertexFormat::kSint32, 2);
|
|
case VertexFormat::kSint32x3:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
|
|
VertexFormat::kSint32, 3);
|
|
case VertexFormat::kSint32x4:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
|
|
VertexFormat::kSint32, 4);
|
|
case VertexFormat::kFloat32x2:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
|
|
VertexFormat::kFloat32, 2);
|
|
case VertexFormat::kFloat32x3:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
|
|
VertexFormat::kFloat32, 3);
|
|
case VertexFormat::kFloat32x4:
|
|
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
|
|
VertexFormat::kFloat32, 4);
|
|
|
|
case VertexFormat::kUint8x2: {
|
|
// yyxx0000, yyxx0000
|
|
auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
|
|
// xx000000, yyxx0000
|
|
auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8u, 0u));
|
|
// 000000xx, 000000yy
|
|
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
|
|
}
|
|
case VertexFormat::kUint8x4: {
|
|
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
|
|
auto* u32s = ctx.dst->vec4<u32>(load_u32());
|
|
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
|
|
auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
|
|
// 000000xx, 000000yy, 000000zz, 000000ww
|
|
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
|
|
}
|
|
case VertexFormat::kUint16x2: {
|
|
// yyyyxxxx, yyyyxxxx
|
|
auto* u32s = ctx.dst->vec2<u32>(load_u32());
|
|
// xxxx0000, yyyyxxxx
|
|
auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16u, 0u));
|
|
// 0000xxxx, 0000yyyy
|
|
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
|
|
}
|
|
case VertexFormat::kUint16x4: {
|
|
// yyyyxxxx, wwwwzzzz
|
|
auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
|
|
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
|
|
auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
|
|
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
|
|
auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
|
|
// 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
|
|
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
|
|
}
|
|
case VertexFormat::kSint8x2: {
|
|
// yyxx0000, yyxx0000
|
|
auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
|
|
// xx000000, yyxx0000
|
|
auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8u, 0u));
|
|
// ssssssxx, ssssssyy
|
|
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
|
|
}
|
|
case VertexFormat::kSint8x4: {
|
|
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
|
|
auto* i32s = ctx.dst->vec4<i32>(load_i32());
|
|
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
|
|
auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
|
|
// ssssssxx, ssssssyy, sssssszz, ssssssww
|
|
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
|
|
}
|
|
case VertexFormat::kSint16x2: {
|
|
// yyyyxxxx, yyyyxxxx
|
|
auto* i32s = ctx.dst->vec2<i32>(load_i32());
|
|
// xxxx0000, yyyyxxxx
|
|
auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16u, 0u));
|
|
// ssssxxxx, ssssyyyy
|
|
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
|
|
}
|
|
case VertexFormat::kSint16x4: {
|
|
// yyyyxxxx, wwwwzzzz
|
|
auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
|
|
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
|
|
auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
|
|
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
|
|
auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
|
|
// ssssxxxx, ssssyyyy, sssszzzz, sssswwww
|
|
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
|
|
}
|
|
case VertexFormat::kUnorm8x2:
|
|
return ctx.dst->MemberAccessor(
|
|
ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
|
|
case VertexFormat::kSnorm8x2:
|
|
return ctx.dst->MemberAccessor(
|
|
ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
|
|
case VertexFormat::kUnorm8x4:
|
|
return ctx.dst->Call("unpack4x8unorm", load_u32());
|
|
case VertexFormat::kSnorm8x4:
|
|
return ctx.dst->Call("unpack4x8snorm", load_u32());
|
|
case VertexFormat::kUnorm16x2:
|
|
return ctx.dst->Call("unpack2x16unorm", load_u32());
|
|
case VertexFormat::kSnorm16x2:
|
|
return ctx.dst->Call("unpack2x16snorm", load_u32());
|
|
case VertexFormat::kFloat16x2:
|
|
return ctx.dst->Call("unpack2x16float", load_u32());
|
|
case VertexFormat::kUnorm16x4:
|
|
return ctx.dst->vec4<f32>(
|
|
ctx.dst->Call("unpack2x16unorm", load_u32()),
|
|
ctx.dst->Call("unpack2x16unorm", load_next_u32()));
|
|
case VertexFormat::kSnorm16x4:
|
|
return ctx.dst->vec4<f32>(
|
|
ctx.dst->Call("unpack2x16snorm", load_u32()),
|
|
ctx.dst->Call("unpack2x16snorm", load_next_u32()));
|
|
case VertexFormat::kFloat16x4:
|
|
return ctx.dst->vec4<f32>(
|
|
ctx.dst->Call("unpack2x16float", load_u32()),
|
|
ctx.dst->Call("unpack2x16float", load_next_u32()));
|
|
}
|
|
|
|
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
|
<< "format " << static_cast<int>(format);
|
|
return nullptr;
|
|
}
|
|
|
|
/// Generates an expression reading an aligned basic type (u32, i32, f32) from
|
|
/// a vertex buffer.
|
|
/// @param array_base the symbol of the variable holding the base array offset
|
|
/// of the vertex array (each index is 4-bytes).
|
|
/// @param offset the byte offset of the data from `buffer_base`
|
|
/// @param buffer the index of the vertex buffer
|
|
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
|
|
/// VertexFormat::kFloat32
|
|
const ast::Expression* LoadPrimitive(Symbol array_base,
|
|
uint32_t offset,
|
|
uint32_t buffer,
|
|
VertexFormat format) {
|
|
const ast::Expression* u32 = nullptr;
|
|
if ((offset & 3) == 0) {
|
|
// Aligned load.
|
|
|
|
const ast ::Expression* index = nullptr;
|
|
if (offset > 0) {
|
|
index = ctx.dst->Add(array_base, offset / 4);
|
|
} else {
|
|
index = ctx.dst->Expr(array_base);
|
|
}
|
|
u32 = ctx.dst->IndexAccessor(
|
|
ctx.dst->MemberAccessor(GetVertexBufferName(buffer),
|
|
GetStructBufferName()),
|
|
index);
|
|
|
|
} else {
|
|
// Unaligned load
|
|
uint32_t offset_aligned = offset & ~3u;
|
|
auto* low = LoadPrimitive(array_base, offset_aligned, buffer,
|
|
VertexFormat::kUint32);
|
|
auto* high = LoadPrimitive(array_base, offset_aligned + 4u, buffer,
|
|
VertexFormat::kUint32);
|
|
|
|
uint32_t shift = 8u * (offset & 3u);
|
|
|
|
auto* low_shr = ctx.dst->Shr(low, shift);
|
|
auto* high_shl = ctx.dst->Shl(high, 32u - shift);
|
|
u32 = ctx.dst->Or(low_shr, high_shl);
|
|
}
|
|
|
|
switch (format) {
|
|
case VertexFormat::kUint32:
|
|
return u32;
|
|
case VertexFormat::kSint32:
|
|
return ctx.dst->Bitcast(ctx.dst->ty.i32(), u32);
|
|
case VertexFormat::kFloat32:
|
|
return ctx.dst->Bitcast(ctx.dst->ty.f32(), u32);
|
|
default:
|
|
break;
|
|
}
|
|
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
|
<< "invalid format for LoadPrimitive" << static_cast<int>(format);
|
|
return nullptr;
|
|
}
|
|
|
|
/// Generates an expression reading a vec2/3/4 from a vertex buffer.
|
|
/// @param array_base the symbol of the variable holding the base array offset
|
|
/// of the vertex array (each index is 4-bytes).
|
|
/// @param offset the byte offset of the data from `buffer_base`
|
|
/// @param buffer the index of the vertex buffer
|
|
/// @param element_stride stride between elements, in bytes
|
|
/// @param base_type underlying AST type
|
|
/// @param base_format underlying vertex format
|
|
/// @param count how many elements the vector has
|
|
const ast::Expression* LoadVec(Symbol array_base,
|
|
uint32_t offset,
|
|
uint32_t buffer,
|
|
uint32_t element_stride,
|
|
const ast::Type* base_type,
|
|
VertexFormat base_format,
|
|
uint32_t count) {
|
|
ast::ExpressionList expr_list;
|
|
for (uint32_t i = 0; i < count; ++i) {
|
|
// Offset read position by element_stride for each component
|
|
uint32_t primitive_offset = offset + element_stride * i;
|
|
expr_list.push_back(
|
|
LoadPrimitive(array_base, primitive_offset, buffer, base_format));
|
|
}
|
|
|
|
return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
|
|
std::move(expr_list));
|
|
}
|
|
|
|
/// Process a non-struct entry point parameter.
|
|
/// Generate function-scope variables for location parameters, and record
|
|
/// vertex_index and instance_index builtins if present.
|
|
/// @param func the entry point function
|
|
/// @param param the parameter to process
|
|
void ProcessNonStructParameter(const ast::Function* func,
|
|
const ast::Variable* param) {
|
|
if (auto* location =
|
|
ast::GetDecoration<ast::LocationDecoration>(param->decorations)) {
|
|
// Create a function-scope variable to replace the parameter.
|
|
auto func_var_sym = ctx.Clone(param->symbol);
|
|
auto* func_var_type = ctx.Clone(param->type);
|
|
auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
|
|
ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
|
|
// Capture mapping from location to the new variable.
|
|
LocationInfo info;
|
|
info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
|
|
info.type = ctx.src->Sem().Get(param)->Type();
|
|
location_info[location->value] = info;
|
|
} else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
|
|
param->decorations)) {
|
|
// Check for existing vertex_index and instance_index builtins.
|
|
if (builtin->builtin == ast::Builtin::kVertexIndex) {
|
|
vertex_index_expr = [this, param]() {
|
|
return ctx.dst->Expr(ctx.Clone(param->symbol));
|
|
};
|
|
} else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
|
|
instance_index_expr = [this, param]() {
|
|
return ctx.dst->Expr(ctx.Clone(param->symbol));
|
|
};
|
|
}
|
|
new_function_parameters.push_back(ctx.Clone(param));
|
|
} else {
|
|
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
|
<< "Invalid entry point parameter";
|
|
}
|
|
}
|
|
|
|
/// Process a struct entry point parameter.
|
|
/// If the struct has members with location attributes, push the parameter to
|
|
/// a function-scope variable and create a new struct parameter without those
|
|
/// attributes. Record expressions for members that are vertex_index and
|
|
/// instance_index builtins.
|
|
/// @param func the entry point function
|
|
/// @param param the parameter to process
|
|
/// @param struct_ty the structure type
|
|
void ProcessStructParameter(const ast::Function* func,
|
|
const ast::Variable* param,
|
|
const ast::Struct* struct_ty) {
|
|
auto param_sym = ctx.Clone(param->symbol);
|
|
|
|
// Process the struct members.
|
|
bool has_locations = false;
|
|
ast::StructMemberList members_to_clone;
|
|
for (auto* member : struct_ty->members) {
|
|
auto member_sym = ctx.Clone(member->symbol);
|
|
std::function<const ast::Expression*()> member_expr = [this, param_sym,
|
|
member_sym]() {
|
|
return ctx.dst->MemberAccessor(param_sym, member_sym);
|
|
};
|
|
|
|
if (auto* location = ast::GetDecoration<ast::LocationDecoration>(
|
|
member->decorations)) {
|
|
// Capture mapping from location to struct member.
|
|
LocationInfo info;
|
|
info.expr = member_expr;
|
|
info.type = ctx.src->Sem().Get(member)->Type();
|
|
location_info[location->value] = info;
|
|
has_locations = true;
|
|
} else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
|
|
member->decorations)) {
|
|
// Check for existing vertex_index and instance_index builtins.
|
|
if (builtin->builtin == ast::Builtin::kVertexIndex) {
|
|
vertex_index_expr = member_expr;
|
|
} else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
|
|
instance_index_expr = member_expr;
|
|
}
|
|
members_to_clone.push_back(member);
|
|
} else {
|
|
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
|
<< "Invalid entry point parameter";
|
|
}
|
|
}
|
|
|
|
if (!has_locations) {
|
|
// Nothing to do.
|
|
new_function_parameters.push_back(ctx.Clone(param));
|
|
return;
|
|
}
|
|
|
|
// Create a function-scope variable to replace the parameter.
|
|
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
|
|
ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
|
|
|
|
if (!members_to_clone.empty()) {
|
|
// Create a new struct without the location attributes.
|
|
ast::StructMemberList new_members;
|
|
for (auto* member : members_to_clone) {
|
|
auto member_sym = ctx.Clone(member->symbol);
|
|
auto* member_type = ctx.Clone(member->type);
|
|
auto member_decos = ctx.Clone(member->decorations);
|
|
new_members.push_back(
|
|
ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
|
|
}
|
|
auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
|
|
|
|
// Create a new function parameter with this struct.
|
|
auto* new_param =
|
|
ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
|
|
new_function_parameters.push_back(new_param);
|
|
|
|
// Copy values from the new parameter to the function-scope variable.
|
|
for (auto* member : members_to_clone) {
|
|
auto member_name = ctx.Clone(member->symbol);
|
|
ctx.InsertFront(
|
|
func->body->statements,
|
|
ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
|
|
ctx.dst->MemberAccessor(new_param, member_name)));
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Process an entry point function.
|
|
/// @param func the entry point function
|
|
void Process(const ast::Function* func) {
|
|
if (func->body->Empty()) {
|
|
return;
|
|
}
|
|
|
|
// Process entry point parameters.
|
|
for (auto* param : func->params) {
|
|
auto* sem = ctx.src->Sem().Get(param);
|
|
if (auto* str = sem->Type()->As<sem::Struct>()) {
|
|
ProcessStructParameter(func, param, str->Declaration());
|
|
} else {
|
|
ProcessNonStructParameter(func, param);
|
|
}
|
|
}
|
|
|
|
// Insert new parameters for vertex_index and instance_index if needed.
|
|
if (!vertex_index_expr) {
|
|
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
|
|
if (layout.step_mode == VertexStepMode::kVertex) {
|
|
auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
|
|
new_function_parameters.push_back(
|
|
ctx.dst->Param(name, ctx.dst->ty.u32(),
|
|
{ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
|
|
vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (!instance_index_expr) {
|
|
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
|
|
if (layout.step_mode == VertexStepMode::kInstance) {
|
|
auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
|
|
new_function_parameters.push_back(
|
|
ctx.dst->Param(name, ctx.dst->ty.u32(),
|
|
{ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
|
|
instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generate vertex pulling preamble.
|
|
if (auto* block = CreateVertexPullingPreamble()) {
|
|
ctx.InsertFront(func->body->statements, block);
|
|
}
|
|
|
|
// Rewrite the function header with the new parameters.
|
|
auto func_sym = ctx.Clone(func->symbol);
|
|
auto* ret_type = ctx.Clone(func->return_type);
|
|
auto* body = ctx.Clone(func->body);
|
|
auto decos = ctx.Clone(func->decorations);
|
|
auto ret_decos = ctx.Clone(func->return_type_decorations);
|
|
auto* new_func = ctx.dst->create<ast::Function>(
|
|
func->source, func_sym, new_function_parameters, ret_type, body,
|
|
std::move(decos), std::move(ret_decos));
|
|
ctx.Replace(func, new_func);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Constructor
VertexPulling::VertexPulling() = default;
/// Destructor
VertexPulling::~VertexPulling() = default;
|
|
|
|
/// Runs the transform on the vertex stage entry point named by the config.
/// A Config found in `inputs` takes precedence over the `cfg_` member.
/// Raises an error and returns without cloning if the entry point is not
/// found.
/// @param ctx the clone context holding the source program and the
/// destination builder
/// @param inputs optional transform inputs; may hold a Config
void VertexPulling::Run(CloneContext& ctx,
                        const DataMap& inputs,
                        DataMap&) const {
  // Select the config by pointer: State takes its own copy, so copying it
  // here as well would only duplicate the nested descriptor vectors.
  const Config* cfg = &cfg_;
  if (const auto* cfg_data = inputs.Get<Config>()) {
    cfg = cfg_data;
  }

  // Find entry point
  auto* func = ctx.src->AST().Functions().Find(
      ctx.src->Symbols().Get(cfg->entry_point_name),
      ast::PipelineStage::kVertex);
  if (func == nullptr) {
    ctx.dst->Diagnostics().add_error(diag::System::Transform,
                                     "Vertex stage entry point not found");
    return;
  }

  // TODO(idanr): Need to check shader locations in descriptor cover all
  // attributes

  // TODO(idanr): Make sure we covered all error cases, to guarantee the
  // following stages will pass

  State state{ctx, *cfg};
  state.AddVertexStorageBuffers();
  state.Process(func);

  ctx.Clone();
}
|
|
|
|
/// Constructor
VertexPulling::Config::Config() = default;
/// Copy constructor
VertexPulling::Config::Config(const Config&) = default;
/// Destructor
VertexPulling::Config::~Config() = default;
/// Copy assignment operator
VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
    default;
|
|
|
|
/// Constructor
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;

/// Constructor
/// @param in_array_stride the value for `array_stride`
/// @param in_step_mode the value for `step_mode`
/// @param in_attributes the attribute descriptors, moved into `attributes`
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
    uint32_t in_array_stride,
    VertexStepMode in_step_mode,
    std::vector<VertexAttributeDescriptor> in_attributes)
    : array_stride(in_array_stride),
      step_mode(in_step_mode),
      attributes(std::move(in_attributes)) {}

/// Copy constructor
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
    const VertexBufferLayoutDescriptor& other) = default;

/// Copy assignment operator
VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
    const VertexBufferLayoutDescriptor& other) = default;

/// Destructor
VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
|
|
|
|
} // namespace transform
|
|
} // namespace tint
|