Tint: Refactor transform VertexPulling and its unit tests

This CL prepare transform VertexPulling and its unit tests for
implementing f16 pipeline io. This CL distinguishes vertex format type
and WGSL variable type of a vertex shader attribute (location input) in
VertexPuilling transform as both `f32` and `f16` WGSL types would be
mapepd to float vertex format. This CL splits VertexPulling unit tests
by base veretx format (SInt, UInt and Float), make it easier to add
`f16` tests.

Bugs: tint:1473, tint:1502
Change-Id: I649deb61e8eb8dac6ebd653bf77ef96475334a56
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112520
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Zhaoming Jiang 2022-12-02 18:11:57 +00:00 committed by Dawn LUCI CQ
parent 3c33cf15e3
commit f42d92a1d5
2 changed files with 1045 additions and 427 deletions

View File

@ -34,32 +34,23 @@ namespace tint::transform {
namespace { namespace {
/// The base type of a component. /// The base WGSL type of a component.
/// The format type is either this type or a vector of this type. /// The format type is either this type or a vector of this type.
enum class BaseType { enum class BaseWGSLType {
kInvalid, kInvalid,
kU32, kU32,
kI32, kI32,
kF32, kF32,
}; };
/// Writes the BaseType to the std::ostream. /// The data type of a vertex format.
/// @param out the std::ostream to write to /// The format type is either this type or a vector of this type.
/// @param format the BaseType to write enum class VertexDataType {
/// @returns out so calls can be chained kInvalid,
std::ostream& operator<<(std::ostream& out, BaseType format) { kUInt, // unsigned int
switch (format) { kSInt, // signed int
case BaseType::kInvalid: kFloat, // unsigned normalized, signed normalized, and float
return out << "invalid"; };
case BaseType::kU32:
return out << "u32";
case BaseType::kI32:
return out << "i32";
case BaseType::kF32:
return out << "f32";
}
return out << "<unknown>";
}
/// Writes the VertexFormat to the std::ostream. /// Writes the VertexFormat to the std::ostream.
/// @param out the std::ostream to write to /// @param out the std::ostream to write to
@ -131,74 +122,94 @@ std::ostream& operator<<(std::ostream& out, VertexFormat format) {
return out << "<unknown>"; return out << "<unknown>";
} }
/// A vertex attribute data format. /// Type information of a vertex input attribute.
struct DataType { struct AttributeWGSLType {
BaseType base_type; BaseWGSLType base_type;
uint32_t width; // 1 for scalar, 2+ for a vector uint32_t width; // 1 for scalar, 2+ for a vector
}; };
DataType DataTypeOf(const sem::Type* ty) { /// Type information of a vertex format.
if (ty->Is<sem::I32>()) { struct VertexFormatType {
return {BaseType::kI32, 1}; VertexDataType base_type;
uint32_t width; // 1 for scalar, 2+ for a vector
};
// Check if base types match between the WGSL variable and the vertex format
bool IsTypeCompatible(AttributeWGSLType wgslType, VertexFormatType vertexFormatType) {
switch (wgslType.base_type) {
case BaseWGSLType::kF32:
return (vertexFormatType.base_type == VertexDataType::kFloat);
case BaseWGSLType::kU32:
return (vertexFormatType.base_type == VertexDataType::kUInt);
case BaseWGSLType::kI32:
return (vertexFormatType.base_type == VertexDataType::kSInt);
default:
return false;
} }
if (ty->Is<sem::U32>()) {
return {BaseType::kU32, 1};
}
if (ty->Is<sem::F32>()) {
return {BaseType::kF32, 1};
}
if (auto* vec = ty->As<sem::Vector>()) {
return {DataTypeOf(vec->type()).base_type, vec->Width()};
}
return {BaseType::kInvalid, 0};
} }
DataType DataTypeOf(VertexFormat format) { AttributeWGSLType WGSLTypeOf(const sem::Type* ty) {
if (ty->Is<sem::I32>()) {
return {BaseWGSLType::kI32, 1};
}
if (ty->Is<sem::U32>()) {
return {BaseWGSLType::kU32, 1};
}
if (ty->Is<sem::F32>()) {
return {BaseWGSLType::kF32, 1};
}
if (auto* vec = ty->As<sem::Vector>()) {
return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
}
return {BaseWGSLType::kInvalid, 0};
}
VertexFormatType VertexFormatTypeOf(VertexFormat format) {
switch (format) { switch (format) {
case VertexFormat::kUint32: case VertexFormat::kUint32:
return {BaseType::kU32, 1}; return {VertexDataType::kUInt, 1};
case VertexFormat::kUint8x2: case VertexFormat::kUint8x2:
case VertexFormat::kUint16x2: case VertexFormat::kUint16x2:
case VertexFormat::kUint32x2: case VertexFormat::kUint32x2:
return {BaseType::kU32, 2}; return {VertexDataType::kUInt, 2};
case VertexFormat::kUint32x3: case VertexFormat::kUint32x3:
return {BaseType::kU32, 3}; return {VertexDataType::kUInt, 3};
case VertexFormat::kUint8x4: case VertexFormat::kUint8x4:
case VertexFormat::kUint16x4: case VertexFormat::kUint16x4:
case VertexFormat::kUint32x4: case VertexFormat::kUint32x4:
return {BaseType::kU32, 4}; return {VertexDataType::kUInt, 4};
case VertexFormat::kSint32: case VertexFormat::kSint32:
return {BaseType::kI32, 1}; return {VertexDataType::kSInt, 1};
case VertexFormat::kSint8x2: case VertexFormat::kSint8x2:
case VertexFormat::kSint16x2: case VertexFormat::kSint16x2:
case VertexFormat::kSint32x2: case VertexFormat::kSint32x2:
return {BaseType::kI32, 2}; return {VertexDataType::kSInt, 2};
case VertexFormat::kSint32x3: case VertexFormat::kSint32x3:
return {BaseType::kI32, 3}; return {VertexDataType::kSInt, 3};
case VertexFormat::kSint8x4: case VertexFormat::kSint8x4:
case VertexFormat::kSint16x4: case VertexFormat::kSint16x4:
case VertexFormat::kSint32x4: case VertexFormat::kSint32x4:
return {BaseType::kI32, 4}; return {VertexDataType::kSInt, 4};
case VertexFormat::kFloat32: case VertexFormat::kFloat32:
return {BaseType::kF32, 1}; return {VertexDataType::kFloat, 1};
case VertexFormat::kUnorm8x2: case VertexFormat::kUnorm8x2:
case VertexFormat::kSnorm8x2: case VertexFormat::kSnorm8x2:
case VertexFormat::kUnorm16x2: case VertexFormat::kUnorm16x2:
case VertexFormat::kSnorm16x2: case VertexFormat::kSnorm16x2:
case VertexFormat::kFloat16x2: case VertexFormat::kFloat16x2:
case VertexFormat::kFloat32x2: case VertexFormat::kFloat32x2:
return {BaseType::kF32, 2}; return {VertexDataType::kFloat, 2};
case VertexFormat::kFloat32x3: case VertexFormat::kFloat32x3:
return {BaseType::kF32, 3}; return {VertexDataType::kFloat, 3};
case VertexFormat::kUnorm8x4: case VertexFormat::kUnorm8x4:
case VertexFormat::kSnorm8x4: case VertexFormat::kSnorm8x4:
case VertexFormat::kUnorm16x4: case VertexFormat::kUnorm16x4:
case VertexFormat::kSnorm16x4: case VertexFormat::kSnorm16x4:
case VertexFormat::kFloat16x4: case VertexFormat::kFloat16x4:
case VertexFormat::kFloat32x4: case VertexFormat::kFloat32x4:
return {BaseType::kF32, 4}; return {VertexDataType::kFloat, 4};
} }
return {BaseType::kInvalid, 0}; return {VertexDataType::kInvalid, 0};
} }
} // namespace } // namespace
@ -350,12 +361,12 @@ struct VertexPulling::State {
auto& var = it->second; auto& var = it->second;
// Data type of the target WGSL variable // Data type of the target WGSL variable
auto var_dt = DataTypeOf(var.type); auto var_dt = WGSLTypeOf(var.type);
// Data type of the vertex stream attribute // Data type of the vertex stream attribute
auto fmt_dt = DataTypeOf(attribute_desc.format); auto fmt_dt = VertexFormatTypeOf(attribute_desc.format);
// Base types must match between the vertex stream and the WGSL variable // Base types must match between the vertex stream and the WGSL variable
if (var_dt.base_type != fmt_dt.base_type) { if (!IsTypeCompatible(var_dt, fmt_dt)) {
std::stringstream err; std::stringstream err;
err << "VertexAttributeDescriptor for location " err << "VertexAttributeDescriptor for location "
<< std::to_string(attribute_desc.shader_location) << " has format " << std::to_string(attribute_desc.shader_location) << " has format "
@ -365,13 +376,14 @@ struct VertexPulling::State {
return nullptr; return nullptr;
} }
// Load the attribute value // Load the attribute value according to vertex format and convert the element type
// of result to match target WGSL variable. The result of `Fetch` should be of WGSL
// types `f32`, `i32`, `u32`, and their vectors.
auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx, auto* fetch = Fetch(buffer_array_base, attribute_desc.offset, buffer_idx,
attribute_desc.format); attribute_desc.format);
// The attribute value may not be of the desired vector width. If it is // The attribute value may not be of the desired vector width. If it is not, we'll
// not, we'll need to either reduce the width with a swizzle, or append // need to either reduce the width with a swizzle, or append 0's and / or a 1.
// 0's and / or a 1.
auto* value = fetch; auto* value = fetch;
if (var_dt.width < fmt_dt.width) { if (var_dt.width < fmt_dt.width) {
// WGSL variable vector width is smaller than the loaded vector width // WGSL variable vector width is smaller than the loaded vector width
@ -390,33 +402,26 @@ struct VertexPulling::State {
return nullptr; return nullptr;
} }
} else if (var_dt.width > fmt_dt.width) { } else if (var_dt.width > fmt_dt.width) {
// WGSL variable vector width is wider than the loaded vector width // WGSL variable vector width is wider than the loaded vector width, do padding.
const ast::Type* ty = nullptr;
// The components of result vector variable, initialized with type-converted
// loaded data vector.
utils::Vector<const ast::Expression*, 8> values{fetch}; utils::Vector<const ast::Expression*, 8> values{fetch};
switch (var_dt.base_type) {
case BaseType::kI32: // Add padding elements. The result must be of vector types of signed/unsigned
ty = b.ty.i32(); // integer or float, so use the abstract integer or abstract float value to do
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { // padding.
values.Push(b.Expr((i == 3) ? 1_i : 0_i)); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
} if (var_dt.base_type == BaseWGSLType::kI32 ||
break; var_dt.base_type == BaseWGSLType::kU32) {
case BaseType::kU32: values.Push(b.Expr((i == 3) ? 1_a : 0_a));
ty = b.ty.u32(); } else {
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { values.Push(b.Expr((i == 3) ? 1.0_a : 0.0_a));
values.Push(b.Expr((i == 3) ? 1_u : 0_u)); }
}
break;
case BaseType::kF32:
ty = b.ty.f32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(b.Expr((i == 3) ? 1_f : 0_f));
}
break;
default:
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type;
return nullptr;
} }
value = b.Construct(b.ty.vec(ty, var_dt.width), values);
const ast::Type* target_ty = CreateASTTypeFor(ctx, var.type);
value = b.Construct(target_ty, values);
} }
// Assign the value to the WGSL variable // Assign the value to the WGSL variable
@ -431,12 +436,14 @@ struct VertexPulling::State {
return b.Block(std::move(stmts)); return b.Block(std::move(stmts));
} }
/// Generates an expression reading from a buffer a specific format. /// Generates an expression reading a specific vertex format from a buffer. Any vertex format of
/// signed normailized, unsigned normailized, or float will result in `f32` or `vecN<f32>` WGSL
/// type.
/// @param array_base the symbol of the variable holding the base array offset /// @param array_base the symbol of the variable holding the base array offset
/// of the vertex array (each index is 4-bytes). /// of the vertex array (each index is 4-bytes).
/// @param offset the byte offset of the data from `buffer_base` /// @param offset the byte offset of the data from `buffer_base`
/// @param buffer the index of the vertex buffer /// @param buffer the index of the vertex buffer
/// @param format the format to read /// @param format the vertex format to read
const ast::Expression* Fetch(Symbol array_base, const ast::Expression* Fetch(Symbol array_base,
uint32_t offset, uint32_t offset,
uint32_t buffer, uint32_t buffer,

File diff suppressed because it is too large Load Diff