Tint: Refactor transform VertexPulling and its unit tests

This CL prepare transform VertexPulling and its unit tests for
implementing f16 pipeline io. This CL distinguishes vertex format type
and WGSL variable type of a vertex shader attribute (location input) in
VertexPuilling transform as both `f32` and `f16` WGSL types would be
mapepd to float vertex format. This CL splits VertexPulling unit tests
by base veretx format (SInt, UInt and Float), make it easier to add
`f16` tests.

Bugs: tint:1473, tint:1502
Change-Id: I649deb61e8eb8dac6ebd653bf77ef96475334a56
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112520
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Zhaoming Jiang 2022-12-02 18:11:57 +00:00 committed by Dawn LUCI CQ
parent 3c33cf15e3
commit f42d92a1d5
2 changed files with 1045 additions and 427 deletions

View File

@ -34,32 +34,23 @@ namespace tint::transform {
namespace {
/// The base type of a component.
/// The base WGSL type of a component.
/// The format type is either this type or a vector of this type.
enum class BaseType {
enum class BaseWGSLType {
kInvalid,
kU32,
kI32,
kF32,
};
/// Writes the BaseType to the std::ostream.
/// @param out the std::ostream to write to
/// @param format the BaseType to write
/// @returns out so calls can be chained
std::ostream& operator<<(std::ostream& out, BaseType format) {
switch (format) {
case BaseType::kInvalid:
return out << "invalid";
case BaseType::kU32:
return out << "u32";
case BaseType::kI32:
return out << "i32";
case BaseType::kF32:
return out << "f32";
}
return out << "<unknown>";
}
/// The data type of a vertex format.
/// The format type is either this type or a vector of this type.
enum class VertexDataType {
kInvalid,
kUInt, // unsigned int
kSInt, // signed int
kFloat, // unsigned normalized, signed normalized, and float
};
/// Writes the VertexFormat to the std::ostream.
/// @param out the std::ostream to write to
@ -131,74 +122,94 @@ std::ostream& operator<<(std::ostream& out, VertexFormat format) {
return out << "<unknown>";
}
/// A vertex attribute data format.
struct DataType {
BaseType base_type;
/// Type information of a vertex input attribute.
struct AttributeWGSLType {
BaseWGSLType base_type;
uint32_t width; // 1 for scalar, 2+ for a vector
};
DataType DataTypeOf(const sem::Type* ty) {
if (ty->Is<sem::I32>()) {
return {BaseType::kI32, 1};
/// Type information of a vertex format.
struct VertexFormatType {
VertexDataType base_type;
uint32_t width; // 1 for scalar, 2+ for a vector
};
// Check if base types match between the WGSL variable and the vertex format
bool IsTypeCompatible(AttributeWGSLType wgslType, VertexFormatType vertexFormatType) {
switch (wgslType.base_type) {
case BaseWGSLType::kF32:
return (vertexFormatType.base_type == VertexDataType::kFloat);
case BaseWGSLType::kU32:
return (vertexFormatType.base_type == VertexDataType::kUInt);
case BaseWGSLType::kI32:
return (vertexFormatType.base_type == VertexDataType::kSInt);
default:
return false;
}
if (ty->Is<sem::U32>()) {
return {BaseType::kU32, 1};
}
if (ty->Is<sem::F32>()) {
return {BaseType::kF32, 1};
}
if (auto* vec = ty->As<sem::Vector>()) {
return {DataTypeOf(vec->type()).base_type, vec->Width()};
}
return {BaseType::kInvalid, 0};
}
DataType DataTypeOf(VertexFormat format) {
AttributeWGSLType WGSLTypeOf(const sem::Type* ty) {
if (ty->Is<sem::I32>()) {
return {BaseWGSLType::kI32, 1};
}
if (ty->Is<sem::U32>()) {
return {BaseWGSLType::kU32, 1};
}
if (ty->Is<sem::F32>()) {
return {BaseWGSLType::kF32, 1};
}
if (auto* vec = ty->As<sem::Vector>()) {
return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
}
return {BaseWGSLType::kInvalid, 0};
}
VertexFormatType VertexFormatTypeOf(VertexFormat format) {
switch (format) {
case VertexFormat::kUint32:
return {BaseType::kU32, 1};
return {VertexDataType::kUInt, 1};
case VertexFormat::kUint8x2:
case VertexFormat::kUint16x2:
case VertexFormat::kUint32x2:
return {BaseType::kU32, 2};
return {VertexDataType::kUInt, 2};
case VertexFormat::kUint32x3:
return {BaseType::kU32, 3};
return {VertexDataType::kUInt, 3};
case VertexFormat::kUint8x4:
case VertexFormat::kUint16x4:
case VertexFormat::kUint32x4:
return {BaseType::kU32, 4};
return {VertexDataType::kUInt, 4};
case VertexFormat::kSint32:
return {BaseType::kI32, 1};
return {VertexDataType::kSInt, 1};
case VertexFormat::kSint8x2:
case VertexFormat::kSint16x2:
case VertexFormat::kSint32x2:
return {BaseType::kI32, 2};
return {VertexDataType::kSInt, 2};
case VertexFormat::kSint32x3:
return {BaseType::kI32, 3};
return {VertexDataType::kSInt, 3};
case VertexFormat::kSint8x4:
case VertexFormat::kSint16x4:
case VertexFormat::kSint32x4:
return {BaseType::kI32, 4};
return {VertexDataType::kSInt, 4};
case VertexFormat::kFloat32:
return {BaseType::kF32, 1};
return {VertexDataType::kFloat, 1};
case VertexFormat::kUnorm8x2:
case VertexFormat::kSnorm8x2:
case VertexFormat::kUnorm16x2:
case VertexFormat::kSnorm16x2:
case VertexFormat::kFloat16x2:
case VertexFormat::kFloat32x2:
return {BaseType::kF32, 2};
return {VertexDataType::kFloat, 2};
case VertexFormat::kFloat32x3:
return {BaseType::kF32, 3};
return {VertexDataType::kFloat, 3};
case VertexFormat::kUnorm8x4:
case VertexFormat::kSnorm8x4:
case VertexFormat::kUnorm16x4:
case VertexFormat::kSnorm16x4:
case VertexFormat::kFloat16x4:
case VertexFormat::kFloat32x4:
return {BaseType::kF32, 4};
return {VertexDataType::kFloat, 4};
}
return {BaseType::kInvalid, 0};
return {VertexDataType::kInvalid, 0};
}
} // namespace
@ -350,12 +361,12 @@ struct VertexPulling::State {
auto& var = it->second;
// Data type of the target WGSL variable
auto var_dt = DataTypeOf(var.type);
auto var_dt = WGSLTypeOf(var.type);
// Data type of the vertex stream attribute
auto fmt_dt = DataTypeOf(attribute_desc.format);
auto fmt_dt = VertexFormatTypeOf(attribute_desc.format);
// Base types must match between the vertex stream and the WGSL variable
if (var_dt.base_type != fmt_dt.base_type) {
if (!IsTypeCompatible(var_dt, fmt_dt)) {
std::stringstream err;
err << "VertexAttributeDescriptor for location "
<< std::to_string(attribute_desc.shader_location) << " has format "
@ -365,13 +376,14 @@ struct VertexPulling::State {
return nullptr;
}
// Load the attribute value
// Load the attribute value according to vertex format and convert the element type
// of result to match target WGSL variable. The result of `Fetch` should be of WGSL
// types `f32`, `i32`, `u32`, and their vectors.
auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx,
attribute_desc.format);
// The attribute value may not be of the desired vector width. If it is
// not, we'll need to either reduce the width with a swizzle, or append
// 0's and / or a 1.
// The attribute value may not be of the desired vector width. If it is not, we'll
// need to either reduce the width with a swizzle, or append 0's and / or a 1.
auto* value = fetch;
if (var_dt.width < fmt_dt.width) {
// WGSL variable vector width is smaller than the loaded vector width
@ -390,33 +402,26 @@ struct VertexPulling::State {
return nullptr;
}
} else if (var_dt.width > fmt_dt.width) {
// WGSL variable vector width is wider than the loaded vector width
const ast::Type* ty = nullptr;
// WGSL variable vector width is wider than the loaded vector width, do padding.
// The components of result vector variable, initialized with type-converted
// loaded data vector.
utils::Vector<const ast::Expression*, 8> values{fetch};
switch (var_dt.base_type) {
case BaseType::kI32:
ty = b.ty.i32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(b.Expr((i == 3) ? 1_i : 0_i));
}
break;
case BaseType::kU32:
ty = b.ty.u32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(b.Expr((i == 3) ? 1_u : 0_u));
}
break;
case BaseType::kF32:
ty = b.ty.f32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(b.Expr((i == 3) ? 1_f : 0_f));
}
break;
default:
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type;
return nullptr;
// Add padding elements. The result must be of vector types of signed/unsigned
// integer or float, so use the abstract integer or abstract float value to do
// padding.
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
if (var_dt.base_type == BaseWGSLType::kI32 ||
var_dt.base_type == BaseWGSLType::kU32) {
values.Push(b.Expr((i == 3) ? 1_a : 0_a));
} else {
values.Push(b.Expr((i == 3) ? 1.0_a : 0.0_a));
}
}
value = b.Construct(b.ty.vec(ty, var_dt.width), values);
const ast::Type* target_ty = CreateASTTypeFor(ctx, var.type);
value = b.Construct(target_ty, values);
}
// Assign the value to the WGSL variable
@ -431,12 +436,14 @@ struct VertexPulling::State {
return b.Block(std::move(stmts));
}
/// Generates an expression reading from a buffer a specific format.
/// Generates an expression reading a specific vertex format from a buffer. Any vertex format of
/// signed normailized, unsigned normailized, or float will result in `f32` or `vecN<f32>` WGSL
/// type.
/// @param array_base the symbol of the variable holding the base array offset
/// of the vertex array (each index is 4-bytes).
/// @param offset the byte offset of the data from `buffer_base`
/// @param buffer the index of the vertex buffer
/// @param format the format to read
/// @param format the vertex format to read
const ast::Expression* Fetch(Symbol array_base,
uint32_t offset,
uint32_t buffer,

File diff suppressed because it is too large Load Diff