diff --git a/src/dawn/tests/end2end/ComputeLayoutMemoryBufferTests.cpp b/src/dawn/tests/end2end/ComputeLayoutMemoryBufferTests.cpp index 21e11dd0b3..68085740a9 100644 --- a/src/dawn/tests/end2end/ComputeLayoutMemoryBufferTests.cpp +++ b/src/dawn/tests/end2end/ComputeLayoutMemoryBufferTests.cpp @@ -34,77 +34,6 @@ std::string ReplaceAll(std::string str, const std::string& substr, const std::st return str; } -// DataMatcherCallback is the callback function by DataMatcher. -// It is called for each contiguous sequence of bytes that should be checked -// for equality. -// offset and size are in units of bytes. -using DataMatcherCallback = std::function; - -// DataMatcher is a function pointer to a data matching function. -// size is the total number of bytes being considered for matching. -// The callback may be called once or multiple times, and may only consider -// part of the interval [0, size) -using DataMatcher = void (*)(uint32_t size, DataMatcherCallback); - -// FullDataMatcher is a DataMatcher that calls callback with the interval -// [0, size) -void FullDataMatcher(uint32_t size, DataMatcherCallback callback) { - callback(0, size); -} - -// StridedDataMatcher is a DataMatcher that calls callback with the strided -// intervals of length BYTES_TO_MATCH, skipping BYTES_TO_SKIP. -// For example: StridedDataMatcher<2, 4>(18, callback) will call callback -// with the intervals: [0, 2), [6, 8), [12, 14) -template -void StridedDataMatcher(uint32_t size, DataMatcherCallback callback) { - uint32_t offset = 0; - while (offset < size) { - callback(offset, BYTES_TO_MATCH); - offset += BYTES_TO_MATCH + BYTES_TO_SKIP; - } -} - -// Align returns the WGSL decoration for an explicit structure field alignment -std::string AlignDeco(uint32_t value) { - return "@align(" + std::to_string(value) + ") "; -} - -} // namespace - -// Field holds test parameters for ComputeLayoutMemoryBufferTests.Fields -struct Field { - const char* type; // Type of the field - uint32_t align; // Alignment of the type in bytes - uint32_t size; // Natural size of the type in bytes - - uint32_t padded_size = 0; // Decorated (extended) size of the type in bytes - DataMatcher matcher = &FullDataMatcher; // The matching method - bool storage_buffer_only = false; // This should only be used for storage buffer tests - - // Sets the padded_size to value. - // Returns this Field so calls can be chained. - Field& PaddedSize(uint32_t value) { - padded_size = value; - return *this; - } - - // Sets the matcher to a StridedDataMatcher. - // Returns this Field so calls can be chained. - template - Field& Strided() { - matcher = &StridedDataMatcher; - return *this; - } - - // Marks that this should only be used for storage buffer tests. - // Returns this Field so calls can be chained. - Field& StorageBufferOnly() { - storage_buffer_only = true; - return *this; - } -}; - // StorageClass is an enumerator of storage classes used by ComputeLayoutMemoryBufferTests.Fields enum class StorageClass { Uniform, @@ -123,12 +52,395 @@ std::ostream& operator<<(std::ostream& o, StorageClass storageClass) { return o; } +// Host-sharable scalar types +enum class ScalarType { + f32, + i32, + u32, + f16, +}; + +std::string ScalarTypeName(ScalarType scalarType) { + switch (scalarType) { + case ScalarType::f32: + return "f32"; + case ScalarType::i32: + return "i32"; + case ScalarType::u32: + return "u32"; + case ScalarType::f16: + return "f16"; + } + UNREACHABLE(); + return ""; +} + +size_t ScalarTypeSize(ScalarType scalarType) { + switch (scalarType) { + case ScalarType::f32: + case ScalarType::i32: + case ScalarType::u32: + return 4; + case ScalarType::f16: + return 2; + } + UNREACHABLE(); + return 0; +} + +// MemoryDataBuilder records and performs operations of following types on a memory buffer `buf`: +// 1. "Align": Align to a alignment `alignment`, which will ensure +// `buf.size() % alignment == 0` by adding padding bytes into the buffer +// if necessary; +// 2. "Data": Add `size` bytes of data bytes into buffer; +// 3. "Padding": Add `size` bytes of padding bytes into buffer; +// 4. "FillingFixed": Fill all `size` given (fixed) bytes into the memory buffer. +// Note that data bytes and padding bytes are generated seperatedly and designed to +// be distinguishable, i.e. data bytes have MSB set to 0 while padding bytes 1. +class MemoryDataBuilder { + public: + // Record a "Align" operation + MemoryDataBuilder& AlignTo(uint32_t alignment) { + mOperations.push_back({OperationType::Align, alignment, {}}); + return *this; + } + + // Record a "Data" operation + MemoryDataBuilder& AddData(size_t size) { + mOperations.push_back({OperationType::Data, size, {}}); + return *this; + } + + // Record a "Padding" operation + MemoryDataBuilder& AddPadding(size_t size) { + mOperations.push_back({OperationType::Padding, size, {}}); + return *this; + } + + // Record a "FillingFixed" operation + MemoryDataBuilder& AddFixedBytes(std::vector& bytes) { + mOperations.push_back({OperationType::FillingFixed, bytes.size(), bytes}); + return *this; + } + + // A helper function to record a "FillingFixed" operation with all four bytes of a given U32 + MemoryDataBuilder& AddFixedU32(uint32_t u32) { + std::vector bytes; + bytes.emplace_back((u32 >> 0) & 0xff); + bytes.emplace_back((u32 >> 8) & 0xff); + bytes.emplace_back((u32 >> 16) & 0xff); + bytes.emplace_back((u32 >> 24) & 0xff); + return AddFixedBytes(bytes); + } + + // Record all operations that `builder` recorded + MemoryDataBuilder& AddSubBuilder(MemoryDataBuilder builder) { + mOperations.insert(mOperations.end(), builder.mOperations.begin(), + builder.mOperations.end()); + return *this; + } + + // Apply all recorded operations, one by one, on a given memory buffer. + // dataXorKey and paddingXorKey controls the generated data and padding bytes seperatedly, make + // it possible to, for example, generate two buffers that have different data bytes but + // identical padding bytes, thus can be used as initializer and expectation bytes of the copy + // destination buffer, expecting data bytes are changed while padding bytes are left unchanged. + void ApplyOperationsToBuffer(std::vector& buffer, + uint8_t dataXorKey, + uint8_t paddingXorKey) { + uint8_t dataByte = 0x0u; + uint8_t paddingByte = 0x2u; + // Get a data byte with MSB set to 0. + auto NextDataByte = [&]() { + dataByte += 0x11u; + return static_cast((dataByte ^ dataXorKey) & 0x7fu); + }; + // Get a padding byte with MSB set to 1, distinguished from data bytes. + auto NextPaddingByte = [&]() { + paddingByte += 0x13u; + return static_cast((paddingByte ^ paddingXorKey) | 0x80u); + }; + for (auto& operation : mOperations) { + switch (operation.mType) { + case OperationType::FillingFixed: { + ASSERT(operation.mOperand == operation.mFixedFillingData.size()); + buffer.insert(buffer.end(), operation.mFixedFillingData.begin(), + operation.mFixedFillingData.end()); + break; + } + case OperationType::Align: { + size_t targetSize = Align(buffer.size(), operation.mOperand); + size_t paddingSize = targetSize - buffer.size(); + for (size_t i = 0; i < paddingSize; i++) { + buffer.push_back(NextPaddingByte()); + } + break; + } + case OperationType::Data: { + for (size_t i = 0; i < operation.mOperand; i++) { + buffer.push_back(NextDataByte()); + } + break; + } + case OperationType::Padding: { + for (size_t i = 0; i < operation.mOperand; i++) { + buffer.push_back(NextPaddingByte()); + } + break; + } + } + } + } + + // Create a empty memory buffer and apply all recorded operations one by one on it. + std::vector CreateBufferAndApplyOperations(uint8_t dataXorKey = 0u, + uint8_t paddingXorKey = 0u) { + std::vector buffer; + ApplyOperationsToBuffer(buffer, dataXorKey, paddingXorKey); + return buffer; + } + + protected: + enum class OperationType { + Align, + Data, + Padding, + FillingFixed, + }; + struct Operation { + OperationType mType; + // mOperand is `alignment` for Align operation, and `size` for Data, Padding, and + // FillingFixed. + size_t mOperand; + // The data that will be filled into buffer if the segment type is FillingFixed. Otherwise + // for Padding and Data segment, the filling bytes are byte-wise generated based on xor + // keys. + std::vector mFixedFillingData; + }; + + std::vector mOperations; +}; + +// DataMatcherCallback is the callback function by DataMatcher. +// It is called for each contiguous sequence of bytes that should be checked +// for equality. +// offset and size are in units of bytes. +using DataMatcherCallback = std::function; + +// Field describe a type that has contiguous data bytes, e.g. `i32`, `vec2`, `mat4x4` or +// `array`, or have a fixed data stride, e.g. `mat3x3` or `array, 4>`. +// `@size` and `@align` attributes, when used as a struct member, can also described by this struct. +class Field { + public: + // Constructor with WGSL type name, natural alignment and natural size. Set mStrideDataBytes to + // natural size and mStridePaddingBytes to 0 by default to indicate continious data part. + Field(std::string wgslType, size_t align, size_t size) + : mWGSLType(wgslType), + mAlign(align), + mSize(size), + mStrideDataBytes(size), + mStridePaddingBytes(0) {} + + const std::string& GetWGSLType() const { return mWGSLType; } + size_t GetAlign() const { return mAlign; } + // The natural size of this field type, i.e. the size without @size attribute + size_t GetUnpaddedSize() const { return mSize; } + // The padded size determined by @size attribute if existed, otherwise the natural size + size_t GetPaddedSize() const { return mHasSizeAttribute ? mPaddedSize : mSize; } + + // Applies a @size attribute, sets the mPaddedSize to value. + // Returns this Field so calls can be chained. + Field& SizeAttribute(size_t value) { + ASSERT(value >= mSize); + mHasSizeAttribute = true; + mPaddedSize = value; + return *this; + } + + bool HasSizeAttribute() const { return mHasSizeAttribute; } + + // Applies a @align attribute, sets the align to value. + // Returns this Field so calls can be chained. + Field& AlignAttribute(size_t value) { + ASSERT(value >= mAlign); + ASSERT(IsPowerOfTwo(value)); + mAlign = value; + mHasAlignAttribute = true; + return *this; + } + + bool HasAlignAttribute() const { return mHasAlignAttribute; } + + // Mark that the data part of this field is strided, and record given mStrideDataBytes and + // mStridePaddingBytes. Returns this Field so calls can be chained. + Field& Strided(size_t bytesData, size_t bytesPadding) { + // Check that stride pattern cover the whole data part, i.e. the data part contains N x + // whole data bytes and N or (N-1) x whole padding bytes. + ASSERT((mSize % (bytesData + bytesPadding) == 0) || + ((mSize + bytesPadding) % (bytesData + bytesPadding) == 0)); + mStrideDataBytes = bytesData; + mStridePaddingBytes = bytesPadding; + return *this; + } + + // Marks that this should only be used for storage buffer tests. + // Returns this Field so calls can be chained. + Field& StorageBufferOnly() { + mStorageBufferOnly = true; + return *this; + } + + bool IsStorageBufferOnly() const { return mStorageBufferOnly; } + + // Call the DataMatcherCallback `callback` for continious or strided data bytes, based on the + // strided information of this field. The callback may be called once or multiple times. Note + // that padding bytes introduced by @size attribute are not tested. + void CheckData(DataMatcherCallback callback) const { + // Calls `callback` with the strided intervals of length mStrideDataBytes, skipping + // mStridePaddingBytes. For example, for a field of mSize = 18, mStrideDataBytes = 2, + // and mStridePaddingBytes = 4, calls `callback` with the intervals: [0, 2), [6, 8), + // [12, 14). If the data is continious, i.e. mStrideDataBytes = 18 and + // mStridePaddingBytes = 0, `callback` would be called only once with the whole interval + // [0, 18). + size_t offset = 0; + while (offset < mSize) { + callback(offset, mStrideDataBytes); + offset += mStrideDataBytes + mStridePaddingBytes; + } + } + + // Get a MemoryDataBuilder that do alignment, place data bytes and padding bytes, according to + // field's alignment, size, padding, and stride information. This MemoryDataBuilder can be used + // by other MemoryDataBuilder as needed. + MemoryDataBuilder GetDataBuilder() const { + MemoryDataBuilder builder; + builder.AlignTo(mAlign); + // Check that stride pattern cover the whole data part, i.e. the data part contains N x + // whole data bytes and N or (N-1) x whole padding bytes. Note that this also handle + // continious data, i.e. mStrideDataBytes == mSize and mStridePaddingBytes == 0, correctly. + ASSERT((mSize % (mStrideDataBytes + mStridePaddingBytes) == 0) || + ((mSize + mStridePaddingBytes) % (mStrideDataBytes + mStridePaddingBytes) == 0)); + size_t offset = 0; + while (offset < mSize) { + builder.AddData(mStrideDataBytes); + offset += mStrideDataBytes; + if (offset < mSize) { + builder.AddPadding(mStridePaddingBytes); + offset += mStridePaddingBytes; + } + } + if (mHasSizeAttribute) { + builder.AddPadding(mPaddedSize - mSize); + } + return builder; + } + + // Helper function to build a Field describing a scalar type. + static Field Scalar(ScalarType type) { + return Field(ScalarTypeName(type), ScalarTypeSize(type), ScalarTypeSize(type)); + } + + // Helper function to build a Field describing a vector type. + static Field Vector(uint32_t n, ScalarType type) { + ASSERT(2 <= n && n <= 4); + size_t elementSize = ScalarTypeSize(type); + size_t vectorSize = n * elementSize; + size_t vectorAlignment = (n == 3 ? 4 : n) * elementSize; + return Field{"vec" + std::to_string(n) + "<" + ScalarTypeName(type) + ">", vectorAlignment, + vectorSize}; + } + + // Helper function to build a Field describing a matrix type. + static Field Matrix(uint32_t col, uint32_t row, ScalarType type) { + ASSERT(2 <= col && col <= 4); + ASSERT(2 <= row && row <= 4); + ASSERT(type == ScalarType::f32 || type == ScalarType::f16); + size_t elementSize = ScalarTypeSize(type); + size_t colVectorSize = row * elementSize; + size_t colVectorAlignment = (row == 3 ? 4 : row) * elementSize; + Field field = Field{"mat" + std::to_string(col) + "x" + std::to_string(row) + "<" + + ScalarTypeName(type) + ">", + colVectorAlignment, col * colVectorAlignment}; + if (colVectorSize != colVectorAlignment) { + field.Strided(colVectorSize, colVectorAlignment - colVectorSize); + } + return field; + } + + private: + const std::string mWGSLType; // Friendly WGSL name of the type of the field + size_t mAlign; // Alignment of the type in bytes, can be change by @align attribute + const size_t mSize; // Natural size of the type in bytes + + bool mHasAlignAttribute = false; + bool mHasSizeAttribute = false; + // Decorated size of the type in bytes indicated by @size attribute, if existed + size_t mPaddedSize = 0; + // Whether this type doesn't meet the layout constraints for uniform buffer and thus should only + // be used for storage buffer tests + bool mStorageBufferOnly = false; + + // Describe the striding pattern of data part (i.e. the "natural size" part). Note that + // continious types are described as mStrideDataBytes == mSize and mStridePaddingBytes == 0. + size_t mStrideDataBytes; + size_t mStridePaddingBytes; +}; + std::ostream& operator<<(std::ostream& o, Field field) { - o << "@align(" << field.align << ") @size(" - << (field.padded_size > 0 ? field.padded_size : field.size) << ") " << field.type; + o << "@align(" << field.GetAlign() << ") @size(" << field.GetPaddedSize() << ") " + << field.GetWGSLType(); return o; } +// Create a compute pipeline with all buffer in bufferList binded in order starting from slot 0, and +// run the given shader. +void RunComputeShaderWithBuffers(const wgpu::Device& device, + const wgpu::Queue& queue, + const std::string& shader, + std::initializer_list bufferList) { + // Set up shader and pipeline + auto module = utils::CreateShaderModule(device, shader.c_str()); + + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = module; + csDesc.compute.entryPoint = "main"; + + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); + + // Set up bind group and issue dispatch + std::vector entries; + uint32_t bufferSlot = 0; + for (const wgpu::Buffer& buffer : bufferList) { + wgpu::BindGroupEntry entry; + entry.binding = bufferSlot++; + entry.buffer = buffer; + entry.offset = 0; + entry.size = wgpu::kWholeSize; + entries.push_back(entry); + } + + wgpu::BindGroupDescriptor descriptor; + descriptor.layout = pipeline.GetBindGroupLayout(0); + descriptor.entryCount = static_cast(entries.size()); + descriptor.entries = entries.data(); + + wgpu::BindGroup bindGroup = device.CreateBindGroup(&descriptor); + + wgpu::CommandBuffer commands; + { + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline); + pass.SetBindGroup(0, bindGroup); + pass.DispatchWorkgroups(1); + pass.End(); + + commands = encoder.Finish(); + } + + queue.Submit(1, &commands); +} + DAWN_TEST_PARAM_STRUCT(ComputeLayoutMemoryBufferTestParams, StorageClass, Field); class ComputeLayoutMemoryBufferTests @@ -136,7 +448,13 @@ class ComputeLayoutMemoryBufferTests void SetUp() override { DawnTestBase::SetUp(); } }; -TEST_P(ComputeLayoutMemoryBufferTests, Fields) { +// Align returns the WGSL decoration for an explicit structure field alignment +std::string AlignDeco(uint32_t value) { + return "@align(" + std::to_string(value) + ") "; +} + +// Test different types used as a struct member +TEST_P(ComputeLayoutMemoryBufferTests, StructMember) { // Sentinel value markers codes used to check that the start and end of // structures are correctly aligned. Each of these codes are distinct and // are not likely to be confused with data. @@ -145,15 +463,6 @@ TEST_P(ComputeLayoutMemoryBufferTests, Fields) { constexpr uint32_t kInputHeaderCode = 0x91827364u; constexpr uint32_t kInputFooterCode = 0x19283764u; - // Byte codes used for field padding. The MSB is set for each of these. - // The field data has the MSB 0. - constexpr uint8_t kDataAlignPaddingCode = 0xfeu; - constexpr uint8_t kFieldAlignPaddingCode = 0xfdu; - constexpr uint8_t kFieldSizePaddingCode = 0xdcu; - constexpr uint8_t kDataSizePaddingCode = 0xdbu; - constexpr uint8_t kInputFooterAlignPaddingCode = 0xdau; - constexpr uint8_t kInputTailPaddingCode = 0xd9u; - // Status codes returned by the shader. constexpr uint32_t kStatusBadInputHeader = 100u; constexpr uint32_t kStatusBadInputFooter = 101u; @@ -210,7 +519,7 @@ fn main() { // Structure size: roundUp(AlignOf(S), OffsetOf(S, L) + SizeOf(S, L)) // https://www.w3.org/TR/WGSL/#storage-class-constraints // RequiredAlignOf(S, uniform): roundUp(16, max(AlignOf(T0), ..., AlignOf(TN))) - uint32_t dataAlign = isUniform ? std::max(16u, field.align) : field.align; + uint32_t dataAlign = isUniform ? std::max(size_t(16u), field.GetAlign()) : field.GetAlign(); // https://www.w3.org/TR/WGSL/#structure-layout-rules // Note: When underlying the target is a Vulkan device, we assume the device does not support @@ -219,11 +528,10 @@ fn main() { uint32_t footerAlign = isUniform ? 16 : 4; shader = ReplaceAll(shader, "{data_align}", isUniform ? AlignDeco(dataAlign) : ""); - shader = ReplaceAll(shader, "{field_align}", std::to_string(field.align)); + shader = ReplaceAll(shader, "{field_align}", std::to_string(field.GetAlign())); shader = ReplaceAll(shader, "{footer_align}", isUniform ? AlignDeco(footerAlign) : ""); - shader = ReplaceAll(shader, "{field_size}", - std::to_string(field.padded_size > 0 ? field.padded_size : field.size)); - shader = ReplaceAll(shader, "{field_type}", field.type); + shader = ReplaceAll(shader, "{field_size}", std::to_string(field.GetPaddedSize())); + shader = ReplaceAll(shader, "{field_type}", field.GetWGSLType()); shader = ReplaceAll(shader, "{input_header_code}", std::to_string(kInputHeaderCode)); shader = ReplaceAll(shader, "{input_footer_code}", std::to_string(kInputFooterCode)); shader = ReplaceAll(shader, "{data_header_code}", std::to_string(kDataHeaderCode)); @@ -237,55 +545,40 @@ fn main() { isUniform ? "uniform" // : "storage, read_write"); - // Set up shader and pipeline - auto module = utils::CreateShaderModule(device, shader.c_str()); - - wgpu::ComputePipelineDescriptor csDesc; - csDesc.compute.module = module; - csDesc.compute.entryPoint = "main"; - - wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); - // Build the input and expected data. - std::vector inputData; // The whole SSBO data - std::vector expectedData; // The expected data to be copied by the shader + MemoryDataBuilder inputDataBuilder; // The whole SSBO data { - auto PushU32 = [&inputData](uint32_t u32) { - inputData.emplace_back((u32 >> 0) & 0xff); - inputData.emplace_back((u32 >> 8) & 0xff); - inputData.emplace_back((u32 >> 16) & 0xff); - inputData.emplace_back((u32 >> 24) & 0xff); - }; - auto AlignTo = [&inputData](uint32_t alignment, uint8_t code) { - uint32_t target = Align(inputData.size(), alignment); - uint32_t bytes = target - inputData.size(); - for (uint32_t i = 0; i < bytes; i++) { - inputData.emplace_back(code); - } - }; - PushU32(kInputHeaderCode); // Input.header - AlignTo(dataAlign, kDataAlignPaddingCode); // Input.data + inputDataBuilder.AddFixedU32(kInputHeaderCode); // Input.header + inputDataBuilder.AlignTo(dataAlign); // Input.data { - PushU32(kDataHeaderCode); // Input.data.header - AlignTo(field.align, kFieldAlignPaddingCode); // Input.data.field - for (uint32_t i = 0; i < field.size; i++) { - // The data has the MSB cleared to distinguish it from the - // padding codes. - uint8_t code = i & 0x7f; - inputData.emplace_back(code); // Input.data.field - expectedData.emplace_back(code); - } - for (uint32_t i = field.size; i < field.padded_size; i++) { - inputData.emplace_back(kFieldSizePaddingCode); // Input.data.field padding - } - PushU32(kDataFooterCode); // Input.data.footer - AlignTo(field.align, kDataSizePaddingCode); // Input.data padding + inputDataBuilder.AddFixedU32(kDataHeaderCode); // Input.data.header + inputDataBuilder.AddSubBuilder(field.GetDataBuilder()); // Input.data.field + inputDataBuilder.AddFixedU32(kDataFooterCode); // Input.data.footer + inputDataBuilder.AlignTo(field.GetAlign()); // Input.data padding } - AlignTo(footerAlign, kInputFooterAlignPaddingCode); // Input.footer @align - PushU32(kInputFooterCode); // Input.footer - AlignTo(256, kInputTailPaddingCode); // Input padding + inputDataBuilder.AlignTo(footerAlign); // Input.footer @align + inputDataBuilder.AddFixedU32(kInputFooterCode); // Input.footer + inputDataBuilder.AlignTo(256); // Input padding } + MemoryDataBuilder expectedDataBuilder; // The expected data to be copied by the shader + expectedDataBuilder.AddSubBuilder(field.GetDataBuilder()); + + // Expectation and input buffer have identical data bytes but different padding bytes. + // Initializes the dst buffer with data bytes different from input and expectation, and padding + // bytes identical to expectation but different from input. + constexpr uint8_t dataKeyForInputAndExpectation = 0x00u; + constexpr uint8_t dataKeyForDstInit = 0xffu; + constexpr uint8_t paddingKeyForInput = 0x3fu; + constexpr uint8_t paddingKeyForDstInitAndExpectation = 0x77u; + + std::vector inputData = inputDataBuilder.CreateBufferAndApplyOperations( + dataKeyForInputAndExpectation, paddingKeyForInput); + std::vector expectedData = expectedDataBuilder.CreateBufferAndApplyOperations( + dataKeyForInputAndExpectation, paddingKeyForDstInitAndExpectation); + std::vector initData = expectedDataBuilder.CreateBufferAndApplyOperations( + dataKeyForDstInit, paddingKeyForDstInitAndExpectation); + // Set up input storage buffer wgpu::Buffer inputBuf = utils::CreateBufferFromData( device, inputData.data(), inputData.size(), @@ -293,11 +586,9 @@ fn main() { (isUniform ? wgpu::BufferUsage::Uniform : wgpu::BufferUsage::Storage)); // Set up output storage buffer - wgpu::BufferDescriptor outputDesc; - outputDesc.size = field.size; - outputDesc.usage = - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst; - wgpu::Buffer outputBuf = device.CreateBuffer(&outputDesc); + wgpu::Buffer outputBuf = utils::CreateBufferFromData( + device, initData.data(), initData.size(), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); // Set up status storage buffer wgpu::BufferDescriptor statusDesc; @@ -306,40 +597,84 @@ fn main() { wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst; wgpu::Buffer statusBuf = device.CreateBuffer(&statusDesc); - // Set up bind group and issue dispatch - wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), - { - {0, inputBuf}, - {1, outputBuf}, - {2, statusBuf}, - }); - - wgpu::CommandBuffer commands; - { - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipeline); - pass.SetBindGroup(0, bindGroup); - pass.DispatchWorkgroups(1); - pass.End(); - - commands = encoder.Finish(); - } - - queue.Submit(1, &commands); + RunComputeShaderWithBuffers(device, queue, shader, {inputBuf, outputBuf, statusBuf}); // Check the status EXPECT_BUFFER_U32_EQ(kStatusOk, statusBuf, 0) << "status code error" << std::endl << "Shader: " << shader; // Check the data - field.matcher(field.size, [&](uint32_t offset, uint32_t size) { + field.CheckData([&](uint32_t offset, uint32_t size) { EXPECT_BUFFER_U8_RANGE_EQ(expectedData.data() + offset, outputBuf, offset, size) << "offset: " << offset; }); } -namespace { +// Test different types that used directly as buffer type +TEST_P(ComputeLayoutMemoryBufferTests, NonStructMember) { + auto params = GetParam(); + Field& field = params.mField; + // @size and @align attribute only apply to struct members, skip them + if (field.HasSizeAttribute() || field.HasAlignAttribute()) { + return; + } + + const bool isUniform = GetParam().mStorageClass == StorageClass::Uniform; + + std::string shader = R"( +@group(0) @binding(0) var<{input_qualifiers}> input : {field_type}; +@group(0) @binding(1) var output : {field_type}; + +@compute @workgroup_size(1,1,1) +fn main() { + output = input; +})"; + + shader = ReplaceAll(shader, "{field_type}", field.GetWGSLType()); + shader = ReplaceAll(shader, "{input_qualifiers}", + isUniform ? "uniform" // + : "storage, read_write"); + + // Build the input and expected data. + MemoryDataBuilder dataBuilder; + dataBuilder.AddSubBuilder(field.GetDataBuilder()); + + // Expectation and input buffer have identical data bytes but different padding bytes. + // Initializes the dst buffer with data bytes different from input and expectation, and padding + // bytes identical to expectation but different from input. + constexpr uint8_t dataKeyForInputAndExpectation = 0x00u; + constexpr uint8_t dataKeyForDstInit = 0xffu; + constexpr uint8_t paddingKeyForInput = 0x3fu; + constexpr uint8_t paddingKeyForDstInitAndExpectation = 0x77u; + + std::vector inputData = dataBuilder.CreateBufferAndApplyOperations( + dataKeyForInputAndExpectation, paddingKeyForInput); + std::vector expectedData = dataBuilder.CreateBufferAndApplyOperations( + dataKeyForInputAndExpectation, paddingKeyForDstInitAndExpectation); + std::vector initData = dataBuilder.CreateBufferAndApplyOperations( + dataKeyForDstInit, paddingKeyForDstInitAndExpectation); + + // Set up input storage buffer + wgpu::Buffer inputBuf = utils::CreateBufferFromData( + device, inputData.data(), inputData.size(), + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | + (isUniform ? wgpu::BufferUsage::Uniform : wgpu::BufferUsage::Storage)); + EXPECT_BUFFER_U8_RANGE_EQ(inputData.data(), inputBuf, 0, inputData.size()); + + // Set up output storage buffer + wgpu::Buffer outputBuf = utils::CreateBufferFromData( + device, initData.data(), initData.size(), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + EXPECT_BUFFER_U8_RANGE_EQ(initData.data(), outputBuf, 0, initData.size()); + + RunComputeShaderWithBuffers(device, queue, shader, {inputBuf, outputBuf}); + + // Check the data + field.CheckData([&](uint32_t offset, uint32_t size) { + EXPECT_BUFFER_U8_RANGE_EQ(expectedData.data() + offset, outputBuf, offset, size) + << "offset: " << offset; + }); +} auto GenerateParams() { auto params = MakeParamGenerator( @@ -354,127 +689,169 @@ auto GenerateParams() { { // See https://www.w3.org/TR/WGSL/#alignment-and-size // Scalar types with no custom alignment or size - Field{"i32", /* align */ 4, /* size */ 4}, - Field{"u32", /* align */ 4, /* size */ 4}, - Field{"f32", /* align */ 4, /* size */ 4}, + Field::Scalar(ScalarType::f32), + Field::Scalar(ScalarType::i32), + Field::Scalar(ScalarType::u32), // Scalar types with custom alignment - Field{"i32", /* align */ 16, /* size */ 4}, - Field{"u32", /* align */ 16, /* size */ 4}, - Field{"f32", /* align */ 16, /* size */ 4}, + Field::Scalar(ScalarType::f32).AlignAttribute(16), + Field::Scalar(ScalarType::i32).AlignAttribute(16), + Field::Scalar(ScalarType::u32).AlignAttribute(16), // Scalar types with custom size - Field{"i32", /* align */ 4, /* size */ 4}.PaddedSize(24), - Field{"u32", /* align */ 4, /* size */ 4}.PaddedSize(24), - Field{"f32", /* align */ 4, /* size */ 4}.PaddedSize(24), + Field::Scalar(ScalarType::f32).SizeAttribute(24), + Field::Scalar(ScalarType::i32).SizeAttribute(24), + Field::Scalar(ScalarType::u32).SizeAttribute(24), // Vector types with no custom alignment or size - Field{"vec2", /* align */ 8, /* size */ 8}, - Field{"vec2", /* align */ 8, /* size */ 8}, - Field{"vec2", /* align */ 8, /* size */ 8}, - Field{"vec3", /* align */ 16, /* size */ 12}, - Field{"vec3", /* align */ 16, /* size */ 12}, - Field{"vec3", /* align */ 16, /* size */ 12}, - Field{"vec4", /* align */ 16, /* size */ 16}, - Field{"vec4", /* align */ 16, /* size */ 16}, - Field{"vec4", /* align */ 16, /* size */ 16}, + Field::Vector(2, ScalarType::f32), + Field::Vector(3, ScalarType::f32), + Field::Vector(4, ScalarType::f32), + Field::Vector(2, ScalarType::i32), + Field::Vector(3, ScalarType::i32), + Field::Vector(4, ScalarType::i32), + Field::Vector(2, ScalarType::u32), + Field::Vector(3, ScalarType::u32), + Field::Vector(4, ScalarType::u32), // Vector types with custom alignment - Field{"vec2", /* align */ 32, /* size */ 8}, - Field{"vec2", /* align */ 32, /* size */ 8}, - Field{"vec2", /* align */ 32, /* size */ 8}, - Field{"vec3", /* align */ 32, /* size */ 12}, - Field{"vec3", /* align */ 32, /* size */ 12}, - Field{"vec3", /* align */ 32, /* size */ 12}, - Field{"vec4", /* align */ 32, /* size */ 16}, - Field{"vec4", /* align */ 32, /* size */ 16}, - Field{"vec4", /* align */ 32, /* size */ 16}, + Field::Vector(2, ScalarType::f32).AlignAttribute(32), + Field::Vector(3, ScalarType::f32).AlignAttribute(32), + Field::Vector(4, ScalarType::f32).AlignAttribute(32), + Field::Vector(2, ScalarType::i32).AlignAttribute(32), + Field::Vector(3, ScalarType::i32).AlignAttribute(32), + Field::Vector(4, ScalarType::i32).AlignAttribute(32), + Field::Vector(2, ScalarType::u32).AlignAttribute(32), + Field::Vector(3, ScalarType::u32).AlignAttribute(32), + Field::Vector(4, ScalarType::u32).AlignAttribute(32), // Vector types with custom size - Field{"vec2", /* align */ 8, /* size */ 8}.PaddedSize(24), - Field{"vec2", /* align */ 8, /* size */ 8}.PaddedSize(24), - Field{"vec2", /* align */ 8, /* size */ 8}.PaddedSize(24), - Field{"vec3", /* align */ 16, /* size */ 12}.PaddedSize(24), - Field{"vec3", /* align */ 16, /* size */ 12}.PaddedSize(24), - Field{"vec3", /* align */ 16, /* size */ 12}.PaddedSize(24), - Field{"vec4", /* align */ 16, /* size */ 16}.PaddedSize(24), - Field{"vec4", /* align */ 16, /* size */ 16}.PaddedSize(24), - Field{"vec4", /* align */ 16, /* size */ 16}.PaddedSize(24), + Field::Vector(2, ScalarType::f32).SizeAttribute(24), + Field::Vector(3, ScalarType::f32).SizeAttribute(24), + Field::Vector(4, ScalarType::f32).SizeAttribute(24), + Field::Vector(2, ScalarType::i32).SizeAttribute(24), + Field::Vector(3, ScalarType::i32).SizeAttribute(24), + Field::Vector(4, ScalarType::i32).SizeAttribute(24), + Field::Vector(2, ScalarType::u32).SizeAttribute(24), + Field::Vector(3, ScalarType::u32).SizeAttribute(24), + Field::Vector(4, ScalarType::u32).SizeAttribute(24), // Matrix types with no custom alignment or size - Field{"mat2x2", /* align */ 8, /* size */ 16}, - Field{"mat3x2", /* align */ 8, /* size */ 24}, - Field{"mat4x2", /* align */ 8, /* size */ 32}, - Field{"mat2x3", /* align */ 16, /* size */ 32}.Strided<12, 4>(), - Field{"mat3x3", /* align */ 16, /* size */ 48}.Strided<12, 4>(), - Field{"mat4x3", /* align */ 16, /* size */ 64}.Strided<12, 4>(), - Field{"mat2x4", /* align */ 16, /* size */ 32}, - Field{"mat3x4", /* align */ 16, /* size */ 48}, - Field{"mat4x4", /* align */ 16, /* size */ 64}, + Field::Matrix(2, 2, ScalarType::f32), + Field::Matrix(3, 2, ScalarType::f32), + Field::Matrix(4, 2, ScalarType::f32), + Field::Matrix(2, 3, ScalarType::f32), + Field::Matrix(3, 3, ScalarType::f32), + Field::Matrix(4, 3, ScalarType::f32), + Field::Matrix(2, 4, ScalarType::f32), + Field::Matrix(3, 4, ScalarType::f32), + Field::Matrix(4, 4, ScalarType::f32), // Matrix types with custom alignment - Field{"mat2x2", /* align */ 32, /* size */ 16}, - Field{"mat3x2", /* align */ 32, /* size */ 24}, - Field{"mat4x2", /* align */ 32, /* size */ 32}, - Field{"mat2x3", /* align */ 32, /* size */ 32}.Strided<12, 4>(), - Field{"mat3x3", /* align */ 32, /* size */ 48}.Strided<12, 4>(), - Field{"mat4x3", /* align */ 32, /* size */ 64}.Strided<12, 4>(), - Field{"mat2x4", /* align */ 32, /* size */ 32}, - Field{"mat3x4", /* align */ 32, /* size */ 48}, - Field{"mat4x4", /* align */ 32, /* size */ 64}, + Field::Matrix(2, 2, ScalarType::f32).AlignAttribute(32), + Field::Matrix(3, 2, ScalarType::f32).AlignAttribute(32), + Field::Matrix(4, 2, ScalarType::f32).AlignAttribute(32), + Field::Matrix(2, 3, ScalarType::f32).AlignAttribute(32), + Field::Matrix(3, 3, ScalarType::f32).AlignAttribute(32), + Field::Matrix(4, 3, ScalarType::f32).AlignAttribute(32), + Field::Matrix(2, 4, ScalarType::f32).AlignAttribute(32), + Field::Matrix(3, 4, ScalarType::f32).AlignAttribute(32), + Field::Matrix(4, 4, ScalarType::f32).AlignAttribute(32), // Matrix types with custom size - Field{"mat2x2", /* align */ 8, /* size */ 16}.PaddedSize(128), - Field{"mat3x2", /* align */ 8, /* size */ 24}.PaddedSize(128), - Field{"mat4x2", /* align */ 8, /* size */ 32}.PaddedSize(128), - Field{"mat2x3", /* align */ 16, /* size */ 32}.PaddedSize(128).Strided<12, 4>(), - Field{"mat3x3", /* align */ 16, /* size */ 48}.PaddedSize(128).Strided<12, 4>(), - Field{"mat4x3", /* align */ 16, /* size */ 64}.PaddedSize(128).Strided<12, 4>(), - Field{"mat2x4", /* align */ 16, /* size */ 32}.PaddedSize(128), - Field{"mat3x4", /* align */ 16, /* size */ 48}.PaddedSize(128), - Field{"mat4x4", /* align */ 16, /* size */ 64}.PaddedSize(128), + Field::Matrix(2, 2, ScalarType::f32).SizeAttribute(128), + Field::Matrix(3, 2, ScalarType::f32).SizeAttribute(128), + Field::Matrix(4, 2, ScalarType::f32).SizeAttribute(128), + Field::Matrix(2, 3, ScalarType::f32).SizeAttribute(128), + Field::Matrix(3, 3, ScalarType::f32).SizeAttribute(128), + Field::Matrix(4, 3, ScalarType::f32).SizeAttribute(128), + Field::Matrix(2, 4, ScalarType::f32).SizeAttribute(128), + Field::Matrix(3, 4, ScalarType::f32).SizeAttribute(128), + Field::Matrix(4, 4, ScalarType::f32).SizeAttribute(128), // Array types with no custom alignment or size. // Note: The use of StorageBufferOnly() is due to UBOs requiring 16 byte alignment // of array elements. See https://www.w3.org/TR/WGSL/#storage-class-constraints - Field{"array", /* align */ 4, /* size */ 4}.StorageBufferOnly(), - Field{"array", /* align */ 4, /* size */ 8}.StorageBufferOnly(), - Field{"array", /* align */ 4, /* size */ 12}.StorageBufferOnly(), - Field{"array", /* align */ 4, /* size */ 16}.StorageBufferOnly(), - Field{"array, 1>", /* align */ 16, /* size */ 16}, - Field{"array, 2>", /* align */ 16, /* size */ 32}, - Field{"array, 3>", /* align */ 16, /* size */ 48}, - Field{"array, 4>", /* align */ 16, /* size */ 64}, - Field{"array, 4>", /* align */ 16, /* size */ 64}.Strided<12, 4>(), + Field("array", /* align */ 4, /* size */ 4).StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 8).StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 12).StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 16).StorageBufferOnly(), + Field("array, 1>", /* align */ 8, /* size */ 8).StorageBufferOnly(), + Field("array, 2>", /* align */ 8, /* size */ 16).StorageBufferOnly(), + Field("array, 3>", /* align */ 8, /* size */ 24).StorageBufferOnly(), + Field("array, 4>", /* align */ 8, /* size */ 32).StorageBufferOnly(), + Field("array, 1>", /* align */ 16, /* size */ 16).Strided(12, 4), + Field("array, 2>", /* align */ 16, /* size */ 32).Strided(12, 4), + Field("array, 3>", /* align */ 16, /* size */ 48).Strided(12, 4), + Field("array, 4>", /* align */ 16, /* size */ 64).Strided(12, 4), + Field("array, 1>", /* align */ 16, /* size */ 16), + Field("array, 2>", /* align */ 16, /* size */ 32), + Field("array, 3>", /* align */ 16, /* size */ 48), + Field("array, 4>", /* align */ 16, /* size */ 64), // Array types with custom alignment - Field{"array", /* align */ 32, /* size */ 4}.StorageBufferOnly(), - Field{"array", /* align */ 32, /* size */ 8}.StorageBufferOnly(), - Field{"array", /* align */ 32, /* size */ 12}.StorageBufferOnly(), - Field{"array", /* align */ 32, /* size */ 16}.StorageBufferOnly(), - Field{"array, 1>", /* align */ 32, /* size */ 16}, - Field{"array, 2>", /* align */ 32, /* size */ 32}, - Field{"array, 3>", /* align */ 32, /* size */ 48}, - Field{"array, 4>", /* align */ 32, /* size */ 64}, - Field{"array, 4>", /* align */ 32, /* size */ 64}.Strided<12, 4>(), + Field("array", /* align */ 4, /* size */ 4) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 8) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 12) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 16) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array, 1>", /* align */ 8, /* size */ 8) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array, 2>", /* align */ 8, /* size */ 16) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array, 3>", /* align */ 8, /* size */ 24) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array, 4>", /* align */ 8, /* size */ 32) + .AlignAttribute(32) + .StorageBufferOnly(), + Field("array, 1>", /* align */ 16, /* size */ 16) + .AlignAttribute(32) + .Strided(12, 4), + Field("array, 2>", /* align */ 16, /* size */ 32) + .AlignAttribute(32) + .Strided(12, 4), + Field("array, 3>", /* align */ 16, /* size */ 48) + .AlignAttribute(32) + .Strided(12, 4), + Field("array, 4>", /* align */ 16, /* size */ 64) + .AlignAttribute(32) + .Strided(12, 4), + Field("array, 1>", /* align */ 16, /* size */ 16).AlignAttribute(32), + Field("array, 2>", /* align */ 16, /* size */ 32).AlignAttribute(32), + Field("array, 3>", /* align */ 16, /* size */ 48).AlignAttribute(32), + Field("array, 4>", /* align */ 16, /* size */ 64).AlignAttribute(32), // Array types with custom size - Field{"array", /* align */ 4, /* size */ 4}.PaddedSize(128).StorageBufferOnly(), - Field{"array", /* align */ 4, /* size */ 8}.PaddedSize(128).StorageBufferOnly(), - Field{"array", /* align */ 4, /* size */ 12} - .PaddedSize(128) + Field("array", /* align */ 4, /* size */ 4) + .SizeAttribute(128) .StorageBufferOnly(), - Field{"array", /* align */ 4, /* size */ 16} - .PaddedSize(128) + Field("array", /* align */ 4, /* size */ 8) + .SizeAttribute(128) .StorageBufferOnly(), - Field{"array, 4>", /* align */ 16, /* size */ 64} - .PaddedSize(128) - .Strided<12, 4>(), + Field("array", /* align */ 4, /* size */ 12) + .SizeAttribute(128) + .StorageBufferOnly(), + Field("array", /* align */ 4, /* size */ 16) + .SizeAttribute(128) + .StorageBufferOnly(), + Field("array, 4>", /* align */ 16, /* size */ 64) + .SizeAttribute(128) + .Strided(12, 4), }); std::vector filtered; for (auto param : params) { - if (param.mStorageClass != StorageClass::Storage && param.mField.storage_buffer_only) { + if (param.mStorageClass != StorageClass::Storage && param.mField.IsStorageBufferOnly()) { continue; } filtered.emplace_back(param);