From e2c775f4bb7cebaae9ddf04f54b85063d2e749ea Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 9 Apr 2021 13:20:28 +0000 Subject: [PATCH] writer/hlsl: Fix storage buffers Use the new CalculateArrayLength and DecomposeStorageAccess transforms to simplify storage buffer patterns before running the HLSL writer. The HLSL writer now handles the InternalDecorations for the internal load, store, and buffer-length intrinsics. GeneratorImpl::EmitStorageBufferAccessor() has now been entirely removed, as all this primitive load / store decomposition performed by DecomposeStorageAccess. TODOs around runtime arrays have been removed, as this is now handled by CalculateArrayLength. Bug: tint:185 Bug: tint:683 Change-Id: Ie25a527e7a22da52778c4477cfc22501de558a41 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46878 Kokoro: Kokoro Reviewed-by: James Price Commit-Queue: Ben Clayton --- src/transform/hlsl.cc | 19 +- src/transform/manager.h | 8 + src/writer/hlsl/generator_impl.cc | 511 +++---- src/writer/hlsl/generator_impl.h | 34 +- .../hlsl/generator_impl_function_test.cc | 94 +- .../generator_impl_member_accessor_test.cc | 1181 +++++++++-------- .../hlsl/generator_impl_sanitizer_test.cc | 50 + src/writer/hlsl/generator_impl_type_test.cc | 69 +- 8 files changed, 973 insertions(+), 993 deletions(-) diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc index 9db68e0745..0fdddc9931 100644 --- a/src/transform/hlsl.cc +++ b/src/transform/hlsl.cc @@ -22,6 +22,9 @@ #include "src/semantic/expression.h" #include "src/semantic/statement.h" #include "src/semantic/variable.h" +#include "src/transform/calculate_array_length.h" +#include "src/transform/decompose_storage_access.h" +#include "src/transform/manager.h" namespace tint { namespace transform { @@ -29,13 +32,21 @@ namespace transform { Hlsl::Hlsl() = default; Hlsl::~Hlsl() = default; -Transform::Output Hlsl::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); +Transform::Output Hlsl::Run(const Program* in, const DataMap& data) { + Manager manager; + manager.Add(); + manager.Add(); + auto out = manager.Run(in, data); + if (!out.program.IsValid()) { + return out; + } + + ProgramBuilder builder; + CloneContext ctx(&builder, &out.program); PromoteInitializersToConstVar(ctx); AddEmptyEntryPoint(ctx); ctx.Clone(); - return Output{Program(std::move(out))}; + return Output{Program(std::move(builder))}; } void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const { diff --git a/src/transform/manager.h b/src/transform/manager.h index 8e860052de..345afb7d17 100644 --- a/src/transform/manager.h +++ b/src/transform/manager.h @@ -40,6 +40,14 @@ class Manager : public Transform { transforms_.push_back(std::move(transform)); } + /// Add pass to the manager of type `T`, constructed with the provided + /// arguments. + /// @param args the arguments to forward to the `T` constructor + template + void Add(ARGS&&... args) { + transforms_.emplace_back(std::make_unique(std::forward(args)...)); + } + /// Runs the transforms on `program`, returning the transformation result. /// @param program the source program to transform /// @param data optional extra transform-specific input data diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index ed508a75ea..0fb6a1242b 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -28,6 +28,8 @@ #include "src/semantic/member_accessor_expression.h" #include "src/semantic/struct.h" #include "src/semantic/variable.h" +#include "src/transform/calculate_array_length.h" +#include "src/transform/decompose_storage_access.h" #include "src/type/access_control_type.h" #include "src/type/depth_texture_type.h" #include "src/type/multisampled_texture_type.h" @@ -220,7 +222,7 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out, return true; } out << "typedef "; - if (!EmitType(out, alias->type(), "")) { + if (!EmitType(out, alias->type(), ast::StorageClass::kNone, "")) { return false; } out << " " << namer_.NameFor(builder_.Symbols().NameFor(alias->symbol())) @@ -240,11 +242,6 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out, bool GeneratorImpl::EmitArrayAccessor(std::ostream& pre, std::ostream& out, ast::ArrayAccessorExpression* expr) { - // Handle writing into a storage buffer array - if (is_storage_buffer_access(expr)) { - return EmitStorageBufferAccessor(pre, out, expr, nullptr); - } - if (!EmitExpression(pre, out, expr->array())) { return false; } @@ -268,7 +265,7 @@ bool GeneratorImpl::EmitBitcast(std::ostream& pre, } out << "as"; - if (!EmitType(out, expr->type(), "")) { + if (!EmitType(out, expr->type(), ast::StorageClass::kNone, "")) { return false; } out << "("; @@ -285,30 +282,6 @@ bool GeneratorImpl::EmitAssign(std::ostream& out, std::ostringstream pre; - // If the LHS is an accessor into a storage buffer then we have to - // emit a Store operation instead of an ='s. - if (auto* mem = stmt->lhs()->As()) { - if (is_storage_buffer_access(mem)) { - std::ostringstream accessor_out; - if (!EmitStorageBufferAccessor(pre, accessor_out, mem, stmt->rhs())) { - return false; - } - out << pre.str(); - out << accessor_out.str() << ";" << std::endl; - return true; - } - } else if (auto* ary = stmt->lhs()->As()) { - if (is_storage_buffer_access(ary)) { - std::ostringstream accessor_out; - if (!EmitStorageBufferAccessor(pre, accessor_out, ary, stmt->rhs())) { - return false; - } - out << pre.str(); - out << accessor_out.str() << ";" << std::endl; - return true; - } - } - std::ostringstream lhs_out; if (!EmitExpression(pre, lhs_out, stmt->lhs())) { return false; @@ -516,12 +489,130 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, return 0; } + const auto& params = expr->params(); auto* call = builder_.Sem().Get(expr); + auto* target = call->Target(); + + if (auto* func = target->As()) { + if (ast::HasDecoration< + transform::CalculateArrayLength::BufferSizeIntrinsic>( + func->Declaration()->decorations())) { + // Special function generated by the CalculateArrayLength transform for + // calling X.GetDimensions(Y) + if (!EmitExpression(pre, out, params[0])) { + return false; + } + out << ".GetDimensions("; + if (!EmitExpression(pre, out, params[1])) { + return false; + } + out << ")"; + return true; + } + + if (auto* intrinsic = + ast::GetDecoration( + func->Declaration()->decorations())) { + auto load = [&](const char* cast, int n) { + if (cast) { + out << cast << "("; + } + if (!EmitExpression(pre, out, params[0])) { // buffer + return false; + } + out << ".Load"; + if (n > 1) { + out << n; + } + ScopedParen sp(out); + if (!EmitExpression(pre, out, params[1])) { // offset + return false; + } + if (cast) { + out << ")"; + } + return true; + }; + auto store = [&](int n) { + if (!EmitExpression(pre, out, params[0])) { // buffer + return false; + } + out << ".Store"; + if (n > 1) { + out << n; + } + ScopedParen sp1(out); + if (!EmitExpression(pre, out, params[1])) { // offset + return false; + } + out << ", asuint"; + ScopedParen sp2(out); + if (!EmitExpression(pre, out, params[2])) { // value + return false; + } + return true; + }; + + switch (intrinsic->type) { + case transform::DecomposeStorageAccess::Intrinsic::kLoadU32: + return load(nullptr, 1); + case transform::DecomposeStorageAccess::Intrinsic::kLoadF32: + return load("asfloat", 1); + case transform::DecomposeStorageAccess::Intrinsic::kLoadI32: + return load("asint", 1); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec2U32: + return load(nullptr, 2); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec2F32: + return load("asfloat", 2); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec2I32: + return load("asint", 2); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec3U32: + return load(nullptr, 3); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec3F32: + return load("asfloat", 3); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec3I32: + return load("asint", 3); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec4U32: + return load(nullptr, 4); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec4F32: + return load("asfloat", 4); + case transform::DecomposeStorageAccess::Intrinsic::kLoadVec4I32: + return load("asint", 4); + case transform::DecomposeStorageAccess::Intrinsic::kStoreU32: + return store(1); + case transform::DecomposeStorageAccess::Intrinsic::kStoreF32: + return store(1); + case transform::DecomposeStorageAccess::Intrinsic::kStoreI32: + return store(1); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec2U32: + return store(2); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec2F32: + return store(2); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec2I32: + return store(2); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec3U32: + return store(3); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec3F32: + return store(3); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec3I32: + return store(3); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec4U32: + return store(4); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec4F32: + return store(4); + case transform::DecomposeStorageAccess::Intrinsic::kStoreVec4I32: + return store(4); + } + + TINT_UNIMPLEMENTED(diagnostics_) << static_cast(intrinsic->type); + return false; + } + } + if (auto* intrinsic = call->Target()->As()) { if (intrinsic->IsTexture()) { return EmitTextureCall(pre, out, expr, intrinsic); } - const auto& params = expr->params(); if (intrinsic->Type() == semantic::IntrinsicType::kSelect) { diagnostics_.add_error("select not supported in HLSL backend yet"); return false; @@ -597,7 +688,6 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, } } - const auto& params = expr->params(); for (auto* param : params) { if (!first) { out << ", "; @@ -1241,7 +1331,7 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre, if (brackets) { out << "{"; } else { - if (!EmitType(out, expr->type(), "")) { + if (!EmitType(out, expr->type(), ast::StorageClass::kNone, "")) { return false; } out << "("; @@ -1499,7 +1589,7 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, Symbol ep_sym) { auto name = func->symbol().to_str(); - if (!EmitType(out, func->return_type(), "")) { + if (!EmitType(out, func->return_type(), ast::StorageClass::kNone, "")) { return false; } @@ -1551,9 +1641,16 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, } first = false; - auto* type = builder_.Sem().Get(v)->Type(); + auto* sem = builder_.Sem().Get(v); + auto* type = sem->Type(); - if (!EmitType(out, type, builder_.Symbols().NameFor(v->symbol()))) { + // Note: WGSL only allows for StorageClass::kNone on parameters, however the + // sanitizer transforms generates load / store functions for storage + // buffers. These functions have a storage buffer parameter with + // StorageClass::kStorage. This is required to correctly translate the + // parameter to [RW]ByteAddressBuffer. + if (!EmitType(out, type, sem->StorageClass(), + builder_.Symbols().NameFor(v->symbol()))) { return false; } // Array name is output as part of the type @@ -1638,7 +1735,7 @@ bool GeneratorImpl::EmitEntryPointData( increment_indent(); make_indent(out); - if (!EmitType(out, type, "")) { + if (!EmitType(out, type, var->StorageClass(), "")) { return false; } out << " " << builder_.Symbols().NameFor(decl->symbol()) << ";" @@ -1663,18 +1760,19 @@ bool GeneratorImpl::EmitEntryPointData( continue; // Global already emitted } - auto* ac = var->Type()->As(); - if (ac == nullptr) { + auto* access = var->Type()->As(); + if (access == nullptr) { diagnostics_.add_error("access control type required for storage buffer"); return false; } - if (!ac->IsReadOnly()) { - out << "RW"; + if (!EmitType(out, var->Type(), ast::StorageClass::kStorage, "")) { + return false; } - out << "ByteAddressBuffer " << builder_.Symbols().NameFor(decl->symbol()) - << RegisterAndSpace(ac->IsReadOnly() ? 't' : 'u', binding_point) << ";" - << std::endl; + + out << " " << builder_.Symbols().NameFor(decl->symbol()) + << RegisterAndSpace(access->IsReadOnly() ? 't' : 'u', binding_point) + << ";" << std::endl; emitted_storagebuffer = true; } if (emitted_storagebuffer) { @@ -1696,10 +1794,12 @@ bool GeneratorImpl::EmitEntryPointData( for (auto& data : in_variables) { auto* var = data.first; auto* deco = data.second; - auto* type = builder_.Sem().Get(var)->Type(); + auto* sem = builder_.Sem().Get(var); + auto* type = sem->Type(); make_indent(out); - if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, sem->StorageClass(), + builder_.Symbols().NameFor(var->symbol()))) { return false; } @@ -1745,10 +1845,12 @@ bool GeneratorImpl::EmitEntryPointData( for (auto& data : outvariables) { auto* var = data.first; auto* deco = data.second; - auto* type = builder_.Sem().Get(var)->Type(); + auto* sem = builder_.Sem().Get(var); + auto* type = sem->Type(); make_indent(out); - if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, sem->StorageClass(), + builder_.Symbols().NameFor(var->symbol()))) { return false; } @@ -1800,7 +1902,7 @@ bool GeneratorImpl::EmitEntryPointData( continue; // Not interested in this type } - if (!EmitType(out, var->Type(), "")) { + if (!EmitType(out, var->Type(), var->StorageClass(), "")) { return false; } out << " " << namer_.NameFor(builder_.Symbols().NameFor(decl->symbol())); @@ -1914,7 +2016,8 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, // Emit entry point parameters. for (auto* var : func->params()) { - auto* type = builder_.Sem().Get(var)->Type(); + auto* sem = builder_.Sem().Get(var); + auto* type = sem->Type(); if (!type->Is()) { TINT_ICE(diagnostics_) << "Unsupported non-struct entry point parameter"; } @@ -1924,7 +2027,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, } first = false; - if (!EmitType(out, type, "")) { + if (!EmitType(out, type, sem->StorageClass(), "")) { return false; } @@ -1992,7 +2095,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) { } else if (type->Is()) { out << "0u"; } else if (auto* vec = type->As()) { - if (!EmitType(out, type, "")) { + if (!EmitType(out, type, ast::StorageClass::kNone, "")) { return false; } ScopedParen sp(out); @@ -2005,7 +2108,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) { } } } else if (auto* mat = type->As()) { - if (!EmitType(out, type, "")) { + if (!EmitType(out, type, ast::StorageClass::kNone, "")) { return false; } ScopedParen sp(out); @@ -2134,263 +2237,9 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) { return true; } -std::string GeneratorImpl::generate_storage_buffer_index_expression( - std::ostream& pre, - ast::Expression* expr) { - std::ostringstream out; - bool first = true; - for (;;) { - if (expr->Is()) { - break; - } - - if (!first) { - out << " + "; - } - first = false; - if (auto* mem = expr->As()) { - auto* res_type = TypeOf(mem->structure())->UnwrapAll(); - if (auto* str = res_type->As()) { - auto* str_type = str->impl(); - auto* str_member = str_type->get_member(mem->member()->symbol()); - - auto* sem_mem = builder_.Sem().Get(str_member); - if (!sem_mem) { - TINT_ICE(diagnostics_) << "struct member missing semantic info"; - return ""; - } - - out << sem_mem->Offset(); - - } else if (res_type->Is()) { - auto swizzle = builder_.Sem().Get(mem)->Swizzle(); - - // TODO(dsinclair): Swizzle stuff - // - // This must be a single element swizzle if we've got a vector at this - // point. - if (swizzle.size() != 1) { - TINT_ICE(diagnostics_) - << "Encountered multi-element swizzle when should have only one " - "level"; - return ""; - } - - // TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32) - // so this is assuming 4. This will need to be fixed when we get f16 or - // f64 types. - out << "(4 * " << swizzle[0] << ")"; - } else { - TINT_ICE(diagnostics_) << "Invalid result type for member accessor: " - << res_type->type_name(); - return ""; - } - - expr = mem->structure(); - } else if (auto* ary = expr->As()) { - auto* ary_type = TypeOf(ary->array())->UnwrapAll(); - - out << "("; - if (auto* arr = ary_type->As()) { - auto* sem_arr = builder_.Sem().Get(arr); - if (!sem_arr) { - TINT_ICE(diagnostics_) << "array type missing semantic info"; - return ""; - } - out << sem_arr->Stride(); - } else if (ary_type->Is()) { - // TODO(dsinclair): This is a hack. Our vectors can only be f32, i32 - // or u32 which are all 4 bytes. When we get f16 or other types we'll - // have to ask the type for the byte size. - out << "4"; - } else if (auto* mat = ary_type->As()) { - if (mat->columns() == 2) { - out << "8"; - } else { - out << "16"; - } - } else { - diagnostics_.add_error("Invalid array type in storage buffer access"); - return ""; - } - out << " * "; - if (!EmitExpression(pre, out, ary->idx_expr())) { - return ""; - } - out << ")"; - - expr = ary->array(); - } else { - diagnostics_.add_error("error emitting storage buffer access"); - return ""; - } - } - - return out.str(); -} - -// TODO(dsinclair): This currently only handles loading of 4, 8, 12 or 16 byte -// members. If we need to support larger we'll need to do the loading into -// chunks. -// -// TODO(dsinclair): Need to support loading through a pointer. The pointer is -// just a memory address in the storage buffer, so need to do the correct -// calculation. -bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre, - std::ostream& out, - ast::Expression* expr, - ast::Expression* rhs) { - auto* result_type = TypeOf(expr)->UnwrapAll(); - bool is_store = rhs != nullptr; - - std::string access_method = is_store ? "Store" : "Load"; - if (auto* vec = result_type->As()) { - access_method += std::to_string(vec->size()); - } else if (auto* mat = result_type->As()) { - access_method += std::to_string(mat->rows()); - } - - // If we aren't storing then we need to put in the outer cast. - if (!is_store) { - if (result_type->is_float_scalar_or_vector() || - result_type->Is()) { - out << "asfloat("; - } else if (result_type->is_signed_scalar_or_vector()) { - out << "asint("; - } else if (result_type->is_unsigned_scalar_or_vector()) { - out << "asuint("; - } else { - TINT_UNIMPLEMENTED(diagnostics_) - << result_type->FriendlyName(builder_.Symbols()); - return false; - } - } - - auto buffer_name = get_buffer_name(expr); - if (buffer_name.empty()) { - diagnostics_.add_error("error emitting storage buffer access"); - return false; - } - - auto idx = generate_storage_buffer_index_expression(pre, expr); - if (idx.empty()) { - return false; - } - - if (auto* mat = result_type->As()) { - // TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed - // if we get matrixes of f16 or f64. - uint32_t stride = mat->rows() == 2 ? 8 : 16; - - if (is_store) { - if (!EmitType(out, mat, "")) { - return false; - } - - auto name = generate_name(kTempNamePrefix); - out << " " << name << " = "; - if (!EmitExpression(pre, out, rhs)) { - return false; - } - out << ";" << std::endl; - - for (uint32_t i = 0; i < mat->columns(); i++) { - if (i > 0) { - out << ";" << std::endl; - } - - make_indent(out); - out << buffer_name << "." << access_method << "(" << idx << " + " - << (i * stride) << ", asuint(" << name << "[" << i << "]))"; - } - - return true; - } - - out << "uint" << mat->rows() << "x" << mat->columns(); - ScopedParen p(out); - for (uint32_t i = 0; i < mat->columns(); i++) { - if (i != 0) { - out << ", "; - } - - out << buffer_name << "." << access_method << "(" << idx << " + " - << (i * stride) << ")"; - } - } else { - out << buffer_name << "." << access_method; - ScopedParen p(out); - out << idx; - if (is_store) { - out << ", asuint"; - ScopedParen p2(out); - if (!EmitExpression(pre, out, rhs)) { - return false; - } - } - } - - if (!is_store) { - out << ")"; - } - return true; -} - -bool GeneratorImpl::is_storage_buffer_access( - ast::ArrayAccessorExpression* expr) { - // We only care about array so we can get to the next part of the expression. - // If it isn't an array or a member accessor we can stop looking as it won't - // be a storage buffer. - auto* ary = expr->array(); - if (auto* member = ary->As()) { - return is_storage_buffer_access(member); - } else if (auto* array = ary->As()) { - return is_storage_buffer_access(array); - } - return false; -} - -bool GeneratorImpl::is_storage_buffer_access( - ast::MemberAccessorExpression* expr) { - auto* structure = expr->structure(); - auto* data_type = TypeOf(structure)->UnwrapAll(); - // TODO(dsinclair): Swizzle - // - // If the data is a multi-element swizzle then we will not load the swizzle - // portion through the Load command. - if (data_type->Is() && - builder_.Symbols().NameFor(expr->member()->symbol()).size() > 1) { - return false; - } - - // Check if this is a storage buffer variable - if (auto* ident = expr->structure()->As()) { - const semantic::Variable* var = nullptr; - if (!global_variables_.get(ident->symbol(), &var)) { - return false; - } - return var->StorageClass() == ast::StorageClass::kStorage; - } else if (auto* member = structure->As()) { - return is_storage_buffer_access(member); - } else if (auto* array = structure->As()) { - return is_storage_buffer_access(array); - } - - // Technically I don't think this is possible, but if we don't have a struct - // or array accessor then we can't have a storage buffer I believe. - return false; -} - bool GeneratorImpl::EmitMemberAccessor(std::ostream& pre, std::ostream& out, ast::MemberAccessorExpression* expr) { - // Look for storage buffer accesses as we have to convert them into Load - // expressions. Stores will be identified in the assignment emission and a - // member accessor store of a storage buffer will not get here. - if (is_storage_buffer_access(expr)) { - return EmitStorageBufferAccessor(pre, out, expr, nullptr); - } - if (!EmitExpression(pre, out, expr->structure())) { return false; } @@ -2515,12 +2364,26 @@ bool GeneratorImpl::EmitSwitch(std::ostream& out, ast::SwitchStatement* stmt) { bool GeneratorImpl::EmitType(std::ostream& out, type::Type* type, + ast::StorageClass storage_class, const std::string& name) { auto* access = type->As(); if (access) { type = access->type(); } + if (storage_class == ast::StorageClass::kStorage) { + if (access == nullptr) { + diagnostics_.add_error("access control type required for storage buffer"); + return false; + } + + if (!access->IsReadOnly()) { + out << "RW"; + } + out << "ByteAddressBuffer"; + return true; + } + if (auto* alias = type->As()) { out << namer_.NameFor(builder_.Symbols().NameFor(alias->symbol())); } else if (auto* ary = type->As()) { @@ -2528,16 +2391,15 @@ bool GeneratorImpl::EmitType(std::ostream& out, std::vector sizes; while (auto* arr = base_type->As()) { if (arr->IsRuntimeArray()) { - // TODO(dsinclair): Support runtime arrays - // https://bugs.chromium.org/p/tint/issues/detail?id=185 - diagnostics_.add_error("runtime array not supported yet."); + TINT_ICE(diagnostics_) + << "Runtime arrays may only exist in storage buffers, which should " + "have been transformed into a ByteAddressBuffer"; return false; - } else { - sizes.push_back(arr->size()); } + sizes.push_back(arr->size()); base_type = arr->type(); } - if (!EmitType(out, base_type, "")) { + if (!EmitType(out, base_type, storage_class, "")) { return false; } if (!name.empty()) { @@ -2553,7 +2415,7 @@ bool GeneratorImpl::EmitType(std::ostream& out, } else if (type->Is()) { out << "int"; } else if (auto* mat = type->As()) { - if (!EmitType(out, mat->type(), "")) { + if (!EmitType(out, mat->type(), storage_class, "")) { return false; } // Note: HLSL's matrices are declared as NxM, where N is the number of @@ -2652,7 +2514,7 @@ bool GeneratorImpl::EmitType(std::ostream& out, out << "uint" << size; } else { out << "vector<"; - if (!EmitType(out, vec->type(), "")) { + if (!EmitType(out, vec->type(), storage_class, "")) { return false; } out << ", " << size << ">"; @@ -2689,7 +2551,7 @@ bool GeneratorImpl::EmitStructType(std::ostream& out, // TODO(dsinclair): Handle [[offset]] annotation on structs // https://bugs.chromium.org/p/tint/issues/detail?id=184 - if (!EmitType(out, mem->type(), + if (!EmitType(out, mem->type(), ast::StorageClass::kNone, builder_.Symbols().NameFor(mem->symbol()))) { return false; } @@ -2788,8 +2650,10 @@ bool GeneratorImpl::EmitVariable(std::ostream& out, if (var->is_const()) { out << "const "; } - auto* type = builder_.Sem().Get(var)->Type(); - if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { + auto* sem = builder_.Sem().Get(var); + auto* type = sem->Type(); + if (!EmitType(out, type, sem->StorageClass(), + builder_.Symbols().NameFor(var->symbol()))) { return false; } if (!type->Is()) { @@ -2824,7 +2688,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, out << pre.str(); } - auto* type = builder_.Sem().Get(var)->Type(); + auto* sem = builder_.Sem().Get(var); + auto* type = sem->Type(); if (ast::HasDecoration(var->decorations())) { auto const_id = var->constant_id(); @@ -2840,7 +2705,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, } out << "#endif" << std::endl; out << "static const "; - if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, sem->StorageClass(), + builder_.Symbols().NameFor(var->symbol()))) { return false; } out << " " << builder_.Symbols().NameFor(var->symbol()) @@ -2848,7 +2714,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl; } else { out << "static const "; - if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) { + if (!EmitType(out, type, sem->StorageClass(), + builder_.Symbols().NameFor(var->symbol()))) { return false; } if (!type->Is()) { diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 48ef7e3fc9..08d844f552 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -37,6 +37,9 @@ namespace tint { // Forward declarations +namespace type { +class AccessControl; +} // namespace type namespace semantic { class Call; class Intrinsic; @@ -266,16 +269,6 @@ class GeneratorImpl : public TextGenerator { bool EmitMemberAccessor(std::ostream& pre, std::ostream& out, ast::MemberAccessorExpression* expr); - /// Handles a storage buffer accessor expression - /// @param pre the preamble for the expression stream - /// @param out the output of the expression stream - /// @param expr the storage buffer accessor expression - /// @param rhs the right side of a store expression. Set to nullptr for a load - /// @returns true if the storage buffer accessor was emitted - bool EmitStorageBufferAccessor(std::ostream& pre, - std::ostream& out, - ast::Expression* expr, - ast::Expression* rhs); /// Handles return statements /// @param out the output stream /// @param stmt the statement to emit @@ -294,9 +287,13 @@ class GeneratorImpl : public TextGenerator { /// Handles generating type /// @param out the output stream /// @param type the type to generate + /// @param storage_class the storage class of the variable /// @param name the name of the variable, only used for array emission /// @returns true if the type is emitted - bool EmitType(std::ostream& out, type::Type* type, const std::string& name); + bool EmitType(std::ostream& out, + type::Type* type, + ast::StorageClass storage_class, + const std::string& name); /// Handles generating a structure declaration /// @param out the output stream /// @param ty the struct to generate @@ -332,15 +329,6 @@ class GeneratorImpl : public TextGenerator { /// @returns true if the variable was emitted bool EmitProgramConstVariable(std::ostream& out, const ast::Variable* var); - /// Returns true if the accessor is accessing a storage buffer. - /// @param expr the expression to check - /// @returns true if the accessor is accessing a storage buffer for which - /// we need to execute a Load instruction. - bool is_storage_buffer_access(ast::MemberAccessorExpression* expr); - /// Returns true if the accessor is accessing a storage buffer. - /// @param expr the expression to check - /// @returns true if the accessor is accessing a storage buffer - bool is_storage_buffer_access(ast::ArrayAccessorExpression* expr); /// Registers the given global with the generator /// @param global the global to register void register_global(ast::Variable* global); @@ -348,12 +336,6 @@ class GeneratorImpl : public TextGenerator { /// @param var the variable to check /// @returns true if the global is in an input or output struct bool global_is_in_struct(const semantic::Variable* var) const; - /// Creates a text string representing the index into a storage buffer - /// @param pre the pre stream - /// @param expr the expression to use as the index - /// @returns the index string, or blank if unable to generate - std::string generate_storage_buffer_index_expression(std::ostream& pre, - ast::Expression* expr); /// Handles generating a builtin method name /// @param intrinsic the semantic info for the intrinsic /// @returns the name or "" if not valid diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index fcf000adf1..ea23929875 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -280,30 +280,30 @@ TEST_F(HlslGeneratorImplTest_Function, EXPECT_EQ(result(), R"(struct VertexOutput { float4 pos; }; -struct tint_symbol_2 { +struct tint_symbol_6 { float4 pos : SV_Position; }; -struct tint_symbol_6 { +struct tint_symbol_9 { float4 pos : SV_Position; }; VertexOutput foo(float x) { - const VertexOutput tint_symbol_8 = {float4(x, x, x, 1.0f)}; - return tint_symbol_8; -} - -tint_symbol_2 vert_main1() { - const VertexOutput tint_symbol_4 = {foo(0.5f)}; - const tint_symbol_2 tint_symbol_1 = {tint_symbol_4.pos}; + const VertexOutput tint_symbol_1 = {float4(x, x, x, 1.0f)}; return tint_symbol_1; } -tint_symbol_6 vert_main2() { - const VertexOutput tint_symbol_7 = {foo(0.25f)}; +tint_symbol_6 vert_main1() { + const VertexOutput tint_symbol_7 = {foo(0.5f)}; const tint_symbol_6 tint_symbol_5 = {tint_symbol_7.pos}; return tint_symbol_5; } +tint_symbol_9 vert_main2() { + const VertexOutput tint_symbol_10 = {foo(0.25f)}; + const tint_symbol_9 tint_symbol_8 = {tint_symbol_10.pos}; + return tint_symbol_8; +} + )"); Validate(); @@ -415,16 +415,19 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_THAT(result(), - HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1); + EXPECT_EQ(result(), + R"( +RWByteAddressBuffer coord : register(u0, space1); void frag_main() { - float v = asfloat(coord.Load(4)); + float v = asfloat(coord.Load(4u)); return; -})")); +} + +)"); Validate(); } @@ -456,16 +459,19 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_THAT(result(), - HasSubstr(R"(ByteAddressBuffer coord : register(t0, space1); + EXPECT_EQ(result(), + R"( +ByteAddressBuffer coord : register(t0, space1); void frag_main() { - float v = asfloat(coord.Load(4)); + float v = asfloat(coord.Load(4u)); return; -})")); +} + +)"); Validate(); } @@ -495,16 +501,19 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_THAT(result(), - HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1); + EXPECT_EQ(result(), + R"( +RWByteAddressBuffer coord : register(u0, space1); void frag_main() { - coord.Store(4, asuint(2.0f)); + coord.Store(4u, asuint(2.0f)); return; -})")); +} + +)"); Validate(); } @@ -534,16 +543,19 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_THAT(result(), - HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1); + EXPECT_EQ(result(), + R"( +RWByteAddressBuffer coord : register(u0, space1); void frag_main() { - coord.Store(4, asuint(2.0f)); + coord.Store(4u, asuint(2.0f)); return; -})")); +} + +)"); Validate(); } @@ -792,20 +804,22 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_THAT(result(), - HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1); + EXPECT_EQ(result(), + R"(RWByteAddressBuffer coord : register(u0, space1); float sub_func(float param) { - return asfloat(coord.Load((4 * 0))); + return asfloat(coord.Load(0u)); } void frag_main() { float v = sub_func(1.0f); return; -})")); +} + +)"); Validate(); } @@ -946,11 +960,13 @@ TEST_F(HlslGeneratorImplTest_Function, // // [[stage(compute)]] // fn a() { + // var v = data.d; // return; // } // // [[stage(compute)]] // fn b() { + // var v = data.d; // return; // } @@ -994,7 +1010,7 @@ TEST_F(HlslGeneratorImplTest_Function, }); } - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); EXPECT_EQ(result(), R"( @@ -1002,13 +1018,13 @@ RWByteAddressBuffer data : register(u0, space0); [numthreads(1, 1, 1)] void a() { - float v = asfloat(data.Load(0)); + float v = asfloat(data.Load(0u)); return; } [numthreads(1, 1, 1)] void b() { - float v = asfloat(data.Load(0)); + float v = asfloat(data.Load(0u)); return; } diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc index 23e3fc253b..295af7ed81 100644 --- a/src/writer/hlsl/generator_impl_member_accessor_test.cc +++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "gmock/gmock.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/struct_block_decoration.h" +#include "src/type/access_control_type.h" #include "src/writer/hlsl/test_helper.h" namespace tint { @@ -19,276 +23,388 @@ namespace writer { namespace hlsl { namespace { -using HlslGeneratorImplTest_MemberAccessor = TestHelper; +using ::testing::HasSubstr; + +using create_type_func_ptr = + type::Type* (*)(const ProgramBuilder::TypesBuilder& ty); + +inline type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) { + return ty.i32(); +} +inline type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) { + return ty.u32(); +} +inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) { + return ty.f32(); +} +template +inline type::Type* ty_vec2(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec2(); +} +template +inline type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec3(); +} +template +inline type::Type* ty_vec4(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec4(); +} +template +inline type::Type* ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat2x2(); +} +template +inline type::Type* ty_mat2x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat2x3(); +} +template +inline type::Type* ty_mat2x4(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat2x4(); +} +template +inline type::Type* ty_mat3x2(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x2(); +} +template +inline type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x3(); +} +template +inline type::Type* ty_mat3x4(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x4(); +} +template +inline type::Type* ty_mat4x2(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat4x2(); +} +template +inline type::Type* ty_mat4x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat4x3(); +} +template +inline type::Type* ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat4x4(); +} + +using i32 = ProgramBuilder::i32; +using u32 = ProgramBuilder::u32; +using f32 = ProgramBuilder::f32; + +template +class HlslGeneratorImplTest_MemberAccessorBase : public BASE { + public: + void SetupStorageBuffer(ast::StructMemberList members) { + ProgramBuilder& b = *this; + + auto* s = + b.Structure("Data", members, {b.create()}); + + auto* ac_ty = + b.create(ast::AccessControl::kReadWrite, s); + + b.Global("data", ac_ty, ast::StorageClass::kStorage, nullptr, + ast::DecorationList{ + b.create(0), + b.create(1), + }); + } + + void SetupFunction(ast::StatementList statements) { + ProgramBuilder& b = *this; + b.Func("main", ast::VariableList{}, b.ty.void_(), statements, + ast::DecorationList{ + b.create(ast::PipelineStage::kVertex), + }); + } +}; + +using HlslGeneratorImplTest_MemberAccessor = + HlslGeneratorImplTest_MemberAccessorBase; + +template +using HlslGeneratorImplTest_MemberAccessorWithParam = + HlslGeneratorImplTest_MemberAccessorBase>; TEST_F(HlslGeneratorImplTest_MemberAccessor, EmitExpression_MemberAccessor) { auto* s = Structure("Data", {Member("mem", ty.f32())}); - auto* str_var = Global("str", s, ast::StorageClass::kPrivate); + Global("str", s, ast::StorageClass::kPrivate); auto* expr = MemberAccessor("str", "mem"); WrapInFunction(expr); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); - gen.register_global(str_var); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_EQ(result(), R"(struct Data { + float mem; +}; - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "str.mem"); +[numthreads(1, 1, 1)] +void test_function() { + float tint_symbol_5 = str.mem; + return; } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load) { - // struct Data { - // a : i32; - // b : f32; - // }; - // var data : Data; - // data.b; - // - // -> asfloat(data.Load(4)); - - auto* s = Structure("data", { - Member("a", ty.i32()), - Member("b", ty.f32()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* expr = MemberAccessor("data", "b"); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load(4))"); -} - -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_Int) { - // struct Data { - // a : i32; - // b : f32; - // }; - // var data : Data; - // data.a; - // - // -> asint(data.Load(0)); - - auto* s = Structure("data", { - Member("a", ty.i32()), - Member("b", ty.f32()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* expr = MemberAccessor("data", "a"); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asint(data.Load(0))"); -} - -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix) { - // struct Data { - // z : f32; - // a : mat2x3; - // }; - // var data : Data; - // mat2x3 b; - // data.a = b; - // - // -> float2x3 _tint_tmp = b; - // data.Store3(4 + 0, asuint(_tint_tmp[0])); - // data.Store3(4 + 16, asuint(_tint_tmp[1])); - - auto* s = - Structure("Data", {Member("z", ty.i32()), Member("a", ty.mat2x3())}); - - auto* b_var = Global("b", ty.mat2x3(), ast::StorageClass::kPrivate); - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* lhs = MemberAccessor("data", "a"); - auto* rhs = Expr("b"); - - auto* assign = create(lhs, rhs); - WrapInFunction(assign); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - gen.register_global(b_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), R"(float2x3 _tint_tmp = b; -data.Store3(16 + 0, asuint(_tint_tmp[0])); -data.Store3(16 + 16, asuint(_tint_tmp[1])); )"); } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix_Empty) { +struct TypeCase { + create_type_func_ptr member_type; + std::string expected; +}; +inline std::ostream& operator<<(std::ostream& out, TypeCase c) { + ProgramBuilder b; + auto* ty = c.member_type(b.ty); + out << ty->FriendlyName(b.Symbols()); + return out; +} + +using HlslGeneratorImplTest_MemberAccessor_StorageBufferLoad = + HlslGeneratorImplTest_MemberAccessorWithParam; +TEST_P(HlslGeneratorImplTest_MemberAccessor_StorageBufferLoad, Test) { + // struct Data { + // a : i32; + // b : ; + // }; + // var data : Data; + // data.b; + + auto p = GetParam(); + + SetupStorageBuffer({ + Member("a", ty.i32()), + Member("b", p.member_type(ty)), + }); + + SetupFunction({ + create(Var("x", nullptr, + ast::StorageClass::kFunction, + MemberAccessor("data", "b"))), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_THAT(result(), HasSubstr(p.expected)); + + Validate(); +} + +INSTANTIATE_TEST_SUITE_P( + HlslGeneratorImplTest_MemberAccessor, + HlslGeneratorImplTest_MemberAccessor_StorageBufferLoad, + testing::Values( + TypeCase{ty_u32, "data.Load(4u)"}, + TypeCase{ty_f32, "asfloat(data.Load(4u))"}, + TypeCase{ty_i32, "asint(data.Load(4u))"}, + TypeCase{ty_vec2, "data.Load2(8u)"}, + TypeCase{ty_vec2, "asfloat(data.Load2(8u))"}, + TypeCase{ty_vec2, "asint(data.Load2(8u))"}, + TypeCase{ty_vec3, "data.Load3(16u)"}, + TypeCase{ty_vec3, "asfloat(data.Load3(16u))"}, + TypeCase{ty_vec3, "asint(data.Load3(16u))"}, + TypeCase{ty_vec4, "data.Load4(16u)"}, + TypeCase{ty_vec4, "asfloat(data.Load4(16u))"}, + TypeCase{ty_vec4, "asint(data.Load4(16u))"}, + TypeCase{ + ty_mat2x2, + R"(return uint2x2(buffer.Load2((offset + 0u)), buffer.Load2((offset + 8u)));)"}, + TypeCase{ + ty_mat2x3, + R"(return float2x3(asfloat(buffer.Load3((offset + 0u))), asfloat(buffer.Load3((offset + 16u))));)"}, + TypeCase{ + ty_mat2x4, + R"(return int2x4(asint(buffer.Load4((offset + 0u))), asint(buffer.Load4((offset + 16u))));)"}, + TypeCase{ + ty_mat3x2, + R"(return uint3x2(buffer.Load2((offset + 0u)), buffer.Load2((offset + 8u)), buffer.Load2((offset + 16u)));)"}, + TypeCase{ + ty_mat3x3, + R"(return float3x3(asfloat(buffer.Load3((offset + 0u))), asfloat(buffer.Load3((offset + 16u))), asfloat(buffer.Load3((offset + 32u))));)"}, + TypeCase{ + ty_mat3x4, + R"(return int3x4(asint(buffer.Load4((offset + 0u))), asint(buffer.Load4((offset + 16u))), asint(buffer.Load4((offset + 32u))));)"}, + TypeCase{ + ty_mat4x2, + R"(return uint4x2(buffer.Load2((offset + 0u)), buffer.Load2((offset + 8u)), buffer.Load2((offset + 16u)), buffer.Load2((offset + 24u)));)"}, + TypeCase{ + ty_mat4x3, + R"(return float4x3(asfloat(buffer.Load3((offset + 0u))), asfloat(buffer.Load3((offset + 16u))), asfloat(buffer.Load3((offset + 32u))), asfloat(buffer.Load3((offset + 48u))));)"}, + TypeCase{ + ty_mat4x4, + R"(return int4x4(asint(buffer.Load4((offset + 0u))), asint(buffer.Load4((offset + 16u))), asint(buffer.Load4((offset + 32u))), asint(buffer.Load4((offset + 48u))));)"})); + +using HlslGeneratorImplTest_MemberAccessor_StorageBufferStore = + HlslGeneratorImplTest_MemberAccessorWithParam; +TEST_P(HlslGeneratorImplTest_MemberAccessor_StorageBufferStore, Test) { + // struct Data { + // a : i32; + // b : ; + // }; + // var data : Data; + // data.b = (); + + auto p = GetParam(); + + auto* type = p.member_type(ty); + + SetupStorageBuffer({ + Member("a", ty.i32()), + Member("b", type), + }); + + SetupFunction({ + create( + Var("value", type, ast::StorageClass::kFunction, Construct(type))), + Assign(MemberAccessor("data", "b"), Expr("value")), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_THAT(result(), HasSubstr(p.expected)); + + Validate(); +} + +INSTANTIATE_TEST_SUITE_P( + HlslGeneratorImplTest_MemberAccessor, + HlslGeneratorImplTest_MemberAccessor_StorageBufferStore, + testing::Values(TypeCase{ty_u32, "data.Store(4u, asuint(value))"}, + TypeCase{ty_f32, "data.Store(4u, asuint(value))"}, + TypeCase{ty_i32, "data.Store(4u, asuint(value))"}, + TypeCase{ty_vec2, "data.Store2(8u, asuint(value))"}, + TypeCase{ty_vec2, "data.Store2(8u, asuint(value))"}, + TypeCase{ty_vec2, "data.Store2(8u, asuint(value))"}, + TypeCase{ty_vec3, "data.Store3(16u, asuint(value))"}, + TypeCase{ty_vec3, "data.Store3(16u, asuint(value))"}, + TypeCase{ty_vec3, "data.Store3(16u, asuint(value))"}, + TypeCase{ty_vec4, "data.Store4(16u, asuint(value))"}, + TypeCase{ty_vec4, "data.Store4(16u, asuint(value))"}, + TypeCase{ty_vec4, "data.Store4(16u, asuint(value))"}, + TypeCase{ty_mat2x2, R"({ + buffer.Store2((offset + 0u), asuint(value[0u])); + buffer.Store2((offset + 8u), asuint(value[1u])); +})"}, + TypeCase{ty_mat2x3, R"({ + buffer.Store3((offset + 0u), asuint(value[0u])); + buffer.Store3((offset + 16u), asuint(value[1u])); +})"}, + TypeCase{ty_mat2x4, R"({ + buffer.Store4((offset + 0u), asuint(value[0u])); + buffer.Store4((offset + 16u), asuint(value[1u])); +})"}, + TypeCase{ty_mat3x2, R"({ + buffer.Store2((offset + 0u), asuint(value[0u])); + buffer.Store2((offset + 8u), asuint(value[1u])); + buffer.Store2((offset + 16u), asuint(value[2u])); +})"}, + TypeCase{ty_mat3x3, R"({ + buffer.Store3((offset + 0u), asuint(value[0u])); + buffer.Store3((offset + 16u), asuint(value[1u])); + buffer.Store3((offset + 32u), asuint(value[2u])); +})"}, + TypeCase{ty_mat3x4, R"({ + buffer.Store4((offset + 0u), asuint(value[0u])); + buffer.Store4((offset + 16u), asuint(value[1u])); + buffer.Store4((offset + 32u), asuint(value[2u])); +})"}, + TypeCase{ty_mat4x2, R"({ + buffer.Store2((offset + 0u), asuint(value[0u])); + buffer.Store2((offset + 8u), asuint(value[1u])); + buffer.Store2((offset + 16u), asuint(value[2u])); + buffer.Store2((offset + 24u), asuint(value[3u])); +})"}, + TypeCase{ty_mat4x3, R"({ + buffer.Store3((offset + 0u), asuint(value[0u])); + buffer.Store3((offset + 16u), asuint(value[1u])); + buffer.Store3((offset + 32u), asuint(value[2u])); + buffer.Store3((offset + 48u), asuint(value[3u])); +})"}, + TypeCase{ty_mat4x4, R"({ + buffer.Store4((offset + 0u), asuint(value[0u])); + buffer.Store4((offset + 16u), asuint(value[1u])); + buffer.Store4((offset + 32u), asuint(value[2u])); + buffer.Store4((offset + 48u), asuint(value[3u])); +})"})); + +TEST_F(HlslGeneratorImplTest_MemberAccessor, StorageBuffer_Store_Matrix_Empty) { // struct Data { // z : f32; // a : mat2x3; // }; // var data : Data; // data.a = mat2x3(); - // - // -> float2x3 _tint_tmp = float2x3(0.0f, 0.0f, 0.0f, - // 0.0f, 0.0f, 0.0f); - // data.Store3(16 + 0, asuint(_tint_tmp[0]); - // data.Store3(16 + 16, asuint(_tint_tmp[1])); - auto* s = - Structure("Data", {Member("z", ty.i32()), Member("a", ty.mat2x3())}); + SetupStorageBuffer({ + Member("a", ty.i32()), + Member("b", ty.mat2x3()), + }); - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); + SetupFunction({ + Assign(MemberAccessor("data", "b"), + Construct(ty.mat2x3(), ast::ExpressionList{})), + }); - auto* lhs = MemberAccessor("data", "a"); - auto* rhs = Construct(ty.mat2x3(), ast::ExpressionList{}); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* assign = create(lhs, rhs); - WrapInFunction(assign); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - GeneratorImpl& gen = Build(); +void tint_symbol_8(RWByteAddressBuffer buffer, uint offset, float2x3 value) { + buffer.Store3((offset + 0u), asuint(value[0u])); + buffer.Store3((offset + 16u), asuint(value[1u])); +} - gen.register_global(coord_var); +void main() { + tint_symbol_8(data, 16u, float2x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)); + return; +} - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ( - result(), - R"(float2x3 _tint_tmp = float2x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); -data.Store3(16 + 0, asuint(_tint_tmp[0])); -data.Store3(16 + 16, asuint(_tint_tmp[1])); -)"); +)"; + EXPECT_EQ(result(), expected); + + Validate(); } TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix) { - // struct Data { - // z : f32; - // a : mat3x2; - // }; - // var data : Data; - // data.a; - // - // -> asfloat(uint2x3(data.Load2(4 + 0), data.Load2(4 + 8), - // data.Load2(4 + 16))); - - auto* s = - Structure("Data", {Member("z", ty.i32()), Member("a", ty.mat3x2())}); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* expr = MemberAccessor("data", "a"); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), - "asfloat(uint2x3(data.Load2(8 + 0), data.Load2(8 + 8), " - "data.Load2(8 + 16)))"); -} - -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Nested) { - // struct Data { - // z : f32; - // a : mat2x3 - // }; - // var data : Outer; - // data.b.a; - // - // -> asfloat(uint3x2(data.Load3(4 + 0), data.Load3(16 + 16))); - - auto* s = Structure("Data", { - Member("z", ty.i32()), - Member("a", ty.mat2x3()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* expr = MemberAccessor("data", "a"); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), - "asfloat(uint3x2(data.Load3(16 + 0), data.Load3(16 + 16)))"); -} - -TEST_F( - HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_By3_Is_16_Bytes) { - // struct Data { - // a : mat3x3 - // }; - // var data : Data; - // data.a; - // - // -> asfloat(uint3x3(data.Load3(0), data.Load3(16), - // data.Load3(32))); - - auto* s = Structure("Data", { - Member("a", ty.mat3x3()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* expr = MemberAccessor("data", "a"); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), - "asfloat(uint3x3(data.Load3(0 + 0), data.Load3(0 + 16), " - "data.Load3(0 + 32)))"); -} - -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Single_Element) { + StorageBuffer_Load_Matrix_Single_Element) { // struct Data { // z : f32; // a : mat4x3; // }; // var data : Data; // data.a[2][1]; - // - // -> asfloat(data.Load((2 * 16) + (1 * 4) + 16))) - auto* s = Structure("Data", { - Member("z", ty.i32()), - Member("a", ty.mat4x3()), - }); + SetupStorageBuffer({ + Member("z", ty.f32()), + Member("a", ty.mat4x3()), + }); - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); + SetupFunction({ + create( + Var("x", nullptr, ast::StorageClass::kFunction, + IndexAccessor(IndexAccessor(MemberAccessor("data", "a"), 2), 1))), + }); - auto* expr = IndexAccessor( - IndexAccessor(MemberAccessor("data", "a"), Expr(2)), Expr(1)); - WrapInFunction(expr); + GeneratorImpl& gen = SanitizeAndBuild(); - GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - gen.register_global(coord_var); +void main() { + float x = asfloat(data.Load(52u)); + return; +} - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + (16 * 2) + 16))"); +)"; + EXPECT_EQ(result(), expected); + + Validate(); } TEST_F(HlslGeneratorImplTest_MemberAccessor, @@ -298,26 +414,32 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor, // }; // var data : Data; // data.a[2]; - // - // -> asint(data.Load((2 * 4)); - type::Array ary(ty.i32(), 5, - ast::DecorationList{ - create(4), - }); - auto* s = Structure("Data", {Member("a", &ary)}); + SetupStorageBuffer({ + Member("z", ty.f32()), + Member("a", ty.array(4)), + }); - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); + SetupFunction({ + create( + Var("x", nullptr, ast::StorageClass::kFunction, + IndexAccessor(MemberAccessor("data", "a"), 2))), + }); - auto* expr = IndexAccessor(MemberAccessor("data", "a"), Expr(2)); - WrapInFunction(expr); + GeneratorImpl& gen = SanitizeAndBuild(); - GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - gen.register_global(coord_var); +void main() { + int x = asint(data.Load(12u)); + return; +} - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asint(data.Load((4 * 2) + 0))"); +)"; + EXPECT_EQ(result(), expected); } TEST_F(HlslGeneratorImplTest_MemberAccessor, @@ -327,473 +449,374 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor, // }; // var data : Data; // data.a[(2 + 4) - 3]; - // - // -> asint(data.Load((4 * ((2 + 4) - 3))); - type::Array ary(ty.i32(), 5, - ast::DecorationList{ - create(4), - }); - auto* s = Structure("Data", {Member("a", &ary)}); + SetupStorageBuffer({ + Member("z", ty.f32()), + Member("a", ty.array(4)), + }); - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); + SetupFunction({ + create( + Var("x", nullptr, ast::StorageClass::kFunction, + IndexAccessor(MemberAccessor("data", "a"), + Sub(Add(2, Expr(4)), Expr(3))))), + }); - auto* expr = IndexAccessor(MemberAccessor("data", "a"), - Sub(Add(Expr(2), Expr(4)), Expr(3))); - WrapInFunction(expr); + GeneratorImpl& gen = SanitizeAndBuild(); - GeneratorImpl& gen = Build(); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asint(data.Load((4 * ((2 + 4) - 3)) + 0))"); +void main() { + int x = asint(data.Load((4u + (4u * uint(((2 + 4) - 3)))))); + return; } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store) { - // struct Data { - // a : i32; - // b : f32; - // }; - // var data : Data; - // data.b = 2.3f; - // - // -> data.Store(0, asuint(2.0f)); - - auto* s = Structure("data", { - Member("a", ty.i32()), - Member("b", ty.f32()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* lhs = MemberAccessor("data", "b"); - auto* rhs = Expr(2.0f); - auto* assign = create(lhs, rhs); - WrapInFunction(assign); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), R"(data.Store(4, asuint(2.0f)); -)"); +)"; + EXPECT_EQ(result(), expected); } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_ToArray) { +TEST_F(HlslGeneratorImplTest_MemberAccessor, StorageBuffer_Store_ToArray) { // struct Data { // a : [[stride(4)]] array; // }; // var data : Data; // data.a[2] = 2; - // - // -> data.Store((2 * 4), asuint(2.3f)); - type::Array ary(ty.i32(), 5, - ast::DecorationList{ - create(4), - }); + SetupStorageBuffer({ + Member("z", ty.f32()), + Member("a", ty.array(4)), + }); - auto* s = Structure("Data", {Member("a", &ary)}); + SetupFunction({ + Assign(IndexAccessor(MemberAccessor("data", "a"), 2), 2), + }); - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* lhs = IndexAccessor(MemberAccessor("data", "a"), Expr(2)); - auto* rhs = Expr(2); - auto* assign = create(lhs, rhs); - WrapInFunction(assign); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), R"(data.Store((4 * 2) + 0, asuint(2)); -)"); +void main() { + data.Store(12u, asuint(2)); + return; } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_Int) { - // struct Data { - // a : i32; - // b : f32; - // }; - // var data : Data; - // data.a = 2; - // - // -> data.Store(0, asuint(2)); - - auto* s = Structure("data", { - Member("a", ty.i32()), - Member("b", ty.f32()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* lhs = MemberAccessor("data", "a"); - auto* rhs = Expr(2); - auto* assign = create(lhs, rhs); - WrapInFunction(assign); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), R"(data.Store(0, asuint(2)); -)"); +)"; + EXPECT_EQ(result(), expected); } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_Vec3) { - // struct Data { +TEST_F(HlslGeneratorImplTest_MemberAccessor, StorageBuffer_Load_MultiLevel) { + // struct Inner { // a : vec3; // b : vec3; // }; - // var data : Data; - // data.b; - // - // -> asfloat(data.Load(16)); - - auto* s = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* expr = MemberAccessor("data", "b"); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load3(16))"); -} - -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_Vec3) { // struct Data { - // a : vec3; - // b : vec3; - // }; - // var data : Data; - // data.b = vec3(2.3f, 1.2f, 0.2f); - // - // -> data.Store(16, asuint(float3(2.3f, 1.2f, 0.2f))); - - auto* s = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); - - auto* coord_var = Global("data", s, ast::StorageClass::kStorage); - - auto* lhs = MemberAccessor("data", "b"); - auto* rhs = vec3(1.f, 2.f, 3.f); - - auto* assign = create(lhs, rhs); - - WrapInFunction(assign); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), - R"(data.Store3(16, asuint(float3(1.0f, 2.0f, 3.0f))); -)"); -} - -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel) { - // struct Data { - // a : vec3; - // b : vec3; - // }; - // struct Pre { - // var c : [[stride(32)]] array; + // var c : [[stride(32)]] array; // }; // // var data : Pre; // data.c[2].b - // - // -> asfloat(data.Load3(16 + (2 * 32))) - auto* data = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); + auto* inner = Structure("Inner", { + Member("a", ty.vec3()), + Member("b", ty.vec3()), + }); - type::Array ary(data, 4, - ast::DecorationList{ - create(32), - }); + SetupStorageBuffer({ + Member("c", ty.array(inner, 4, 32)), + }); - auto* pre_struct = Structure("Pre", {Member("c", &ary)}); + SetupFunction({ + create(Var( + "x", nullptr, ast::StorageClass::kFunction, + MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), 2), "b"))), + }); - auto* coord_var = Global("data", pre_struct, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* expr = - MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), Expr(2)), "b"); - WrapInFunction(expr); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - GeneratorImpl& gen = Build(); +void main() { + float3 x = asfloat(data.Load3(80u)); + return; +} - gen.register_global(coord_var); +)"; + EXPECT_EQ(result(), expected); - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load3(16 + (32 * 2) + 0))"); + Validate(); } TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Swizzle) { - // struct Data { + StorageBuffer_Load_MultiLevel_Swizzle) { + // struct Inner { // a : vec3; // b : vec3; // }; - // struct Pre { - // var c : [[stride(32)]] array; + // struct Data { + // var c : [[stride(32)]] array; // }; // // var data : Pre; // data.c[2].b.xy - // - // -> asfloat(data.Load3(16 + (2 * 32))).xy - auto* data = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); + auto* inner = Structure("Inner", { + Member("a", ty.vec3()), + Member("b", ty.vec3()), + }); - type::Array ary(data, 4, - ast::DecorationList{create(32)}); + SetupStorageBuffer({ + Member("c", ty.array(inner, 4, 32)), + }); - auto* pre_struct = Structure("Pre", {Member("c", &ary)}); + SetupFunction({ + create( + Var("x", nullptr, ast::StorageClass::kFunction, + MemberAccessor( + MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), 2), + "b"), + "xy"))), + }); - auto* coord_var = Global("data", pre_struct, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* expr = MemberAccessor( - MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), Expr(2)), "b"), - "xy"); - WrapInFunction(expr); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load3(16 + (32 * 2) + 0)).xy"); +void main() { + float2 x = asfloat(data.Load3(80u)).xy; + return; } -TEST_F( - HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Swizzle_SingleLetter) { // NOLINT - // struct Data { +)"; + EXPECT_EQ(result(), expected); + + Validate(); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, + StorageBuffer_Load_MultiLevel_Swizzle_SingleLetter) { // NOLINT + // struct Inner { // a : vec3; // b : vec3; // }; - // struct Pre { - // var c : [[stride(32)]] array; + // struct Data { + // var c : [[stride(32)]] array; // }; // // var data : Pre; // data.c[2].b.g - // - // -> asfloat(data.Load((4 * 1) + 16 + (2 * 32) + 0)) - auto* data = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); + auto* inner = Structure("Inner", { + Member("a", ty.vec3()), + Member("b", ty.vec3()), + }); - type::Array ary(data, 4, - ast::DecorationList{ - create(32), - }); + SetupStorageBuffer({ + Member("c", ty.array(inner, 4, 32)), + }); - auto* pre_struct = Structure("Pre", {Member("c", &ary)}); + SetupFunction({ + create( + Var("x", nullptr, ast::StorageClass::kFunction, + MemberAccessor( + MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), 2), + "b"), + "g"))), + }); - auto* coord_var = Global("data", pre_struct, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* expr = MemberAccessor( - MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), Expr(2)), "b"), - "g"); - WrapInFunction(expr); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - GeneratorImpl& gen = Build(); +void main() { + float x = asfloat(data.Load(84u)); + return; +} - gen.register_global(coord_var); +)"; + EXPECT_EQ(result(), expected); - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + 16 + (32 * 2) + 0))"); + Validate(); } TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Index) { - // struct Data { + StorageBuffer_Load_MultiLevel_Index) { + // struct Inner { // a : vec3; // b : vec3; // }; - // struct Pre { - // var c : [[stride(32)]] array; + // struct Data { + // var c : [[stride(32)]] array; // }; // // var data : Pre; // data.c[2].b[1] - // - // -> asfloat(data.Load(4 + 16 + (2 * 32))) - auto* data = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); + auto* inner = Structure("Inner", { + Member("a", ty.vec3()), + Member("b", ty.vec3()), + }); - type::Array ary(data, 4, - ast::DecorationList{ - create(32), - }); + SetupStorageBuffer({ + Member("c", ty.array(inner, 4, 32)), + }); - auto* pre_struct = Structure("Pre", {Member("c", &ary)}); + SetupFunction({ + create(Var( + "x", nullptr, ast::StorageClass::kFunction, + IndexAccessor(MemberAccessor( + IndexAccessor(MemberAccessor("data", "c"), 2), "b"), + 1))), + }); - auto* coord_var = Global("data", pre_struct, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* expr = IndexAccessor( - MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), Expr(2)), "b"), - Expr(1)); - WrapInFunction(expr); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + 16 + (32 * 2) + 0))"); +void main() { + float x = asfloat(data.Load(84u)); + return; } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_MultiLevel) { - // struct Data { +)"; + EXPECT_EQ(result(), expected); + + Validate(); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, StorageBuffer_Store_MultiLevel) { + // struct Inner { // a : vec3; // b : vec3; // }; - // struct Pre { - // var c : [[stride(32)]] array; + // struct Data { + // var c : [[stride(32)]] array; // }; // // var data : Pre; // data.c[2].b = vec3(1.f, 2.f, 3.f); - // - // -> data.Store3(16 + (2 * 32), asuint(float3(1.0f, 2.0f, 3.0f))); - auto* data = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); + auto* inner = Structure("Inner", { + Member("a", ty.vec3()), + Member("b", ty.vec3()), + }); - type::Array ary(data, 4, - ast::DecorationList{ - create(32), - }); + SetupStorageBuffer({ + Member("c", ty.array(inner, 4, 32)), + }); - auto* pre_struct = Structure("Pre", {Member("c", &ary)}); + SetupFunction({ + Assign(MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), 2), "b"), + vec3(1.f, 2.f, 3.f)), + }); - auto* coord_var = Global("data", pre_struct, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* lhs = - MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), Expr(2)), "b"); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - auto* assign = - create(lhs, vec3(1.f, 2.f, 3.f)); +void main() { + data.Store3(80u, asuint(float3(1.0f, 2.0f, 3.0f))); + return; +} - WrapInFunction(assign); +)"; + EXPECT_EQ(result(), expected); - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), - R"(data.Store3(16 + (32 * 2) + 0, asuint(float3(1.0f, 2.0f, 3.0f))); -)"); + Validate(); } TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_StorageBuffer_Store_Swizzle_SingleLetter) { - // struct Data { + StorageBuffer_Store_Swizzle_SingleLetter) { + // struct Inner { // a : vec3; // b : vec3; // }; - // struct Pre { - // var c : [[stride(32)]] array; + // struct Data { + // var c : [[stride(32)]] array; // }; // // var data : Pre; // data.c[2].b.y = 1.f; - // - // -> data.Store((4 * 1) + 16 + (2 * 32) + 0, asuint(1.0f)); - auto* data = Structure("Data", { - Member("a", ty.vec3()), - Member("b", ty.vec3()), - }); + auto* inner = Structure("Inner", { + Member("a", ty.vec3()), + Member("b", ty.vec3()), + }); - type::Array ary(data, 4, - ast::DecorationList{ - create(32), - }); + SetupStorageBuffer({ + Member("c", ty.array(inner, 4, 32)), + }); - auto* pre_struct = Structure("Pre", {Member("c", &ary)}); + SetupFunction({ + Assign(MemberAccessor( + MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), 2), + "b"), + "y"), + Expr(1.f)), + }); - auto* coord_var = Global("data", pre_struct, ast::StorageClass::kStorage); + GeneratorImpl& gen = SanitizeAndBuild(); - auto* lhs = MemberAccessor( - MemberAccessor(IndexAccessor(MemberAccessor("data", "c"), Expr(2)), "b"), - "y"); - auto* rhs = Expr(1.f); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + auto* expected = + R"( +RWByteAddressBuffer data : register(u0, space1); - auto* assign = create(lhs, rhs); - - WrapInFunction(assign); - - GeneratorImpl& gen = Build(); - - gen.register_global(coord_var); - - ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error(); - EXPECT_EQ(result(), - R"(data.Store((4 * 1) + 16 + (32 * 2) + 0, asuint(1.0f)); -)"); +void main() { + data.Store(84u, asuint(1.0f)); + return; } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_Swizzle_xyz) { - Global("my_vec", ty.vec4(), ast::StorageClass::kPrivate); +)"; + EXPECT_EQ(result(), expected); + Validate(); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, Swizzle_xyz) { + auto* var = Var("my_vec", ty.vec4(), ast::StorageClass::kFunction, + vec4(1.f, 2.f, 3.f, 4.f)); auto* expr = MemberAccessor("my_vec", "xyz"); - WrapInFunction(expr); + WrapInFunction(var, expr); - GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "my_vec.xyz"); + GeneratorImpl& gen = SanitizeAndBuild(); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_THAT(result(), HasSubstr("my_vec.xyz")); + + Validate(); } -TEST_F(HlslGeneratorImplTest_MemberAccessor, - EmitExpression_MemberAccessor_Swizzle_gbr) { - Global("my_vec", ty.vec4(), ast::StorageClass::kPrivate); - +TEST_F(HlslGeneratorImplTest_MemberAccessor, Swizzle_gbr) { + auto* var = Var("my_vec", ty.vec4(), ast::StorageClass::kFunction, + vec4(1.f, 2.f, 3.f, 4.f)); auto* expr = MemberAccessor("my_vec", "gbr"); - WrapInFunction(expr); + WrapInFunction(var, expr); - GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error(); - EXPECT_EQ(result(), "my_vec.gbr"); + GeneratorImpl& gen = SanitizeAndBuild(); + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_THAT(result(), HasSubstr("my_vec.gbr")); + + Validate(); } } // namespace diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc index 8766b91181..229931d0fc 100644 --- a/src/writer/hlsl/generator_impl_sanitizer_test.cc +++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "src/ast/stage_decoration.h" +#include "src/ast/struct_block_decoration.h" #include "src/ast/variable_decl_statement.h" +#include "src/type/access_control_type.h" #include "src/writer/hlsl/test_helper.h" namespace tint { @@ -23,6 +25,54 @@ namespace { using HlslSanitizerTest = TestHelper; +TEST_F(HlslSanitizerTest, ArrayLength) { + auto* sb_ty = Structure("SB", + { + Member("x", ty.f32()), + Member("arr", ty.array(ty.vec4())), + }, + { + create(), + }); + auto* ac_ty = + create(ast::AccessControl::kReadOnly, sb_ty); + + Global("sb", ac_ty, ast::StorageClass::kStorage, nullptr, + ast::DecorationList{ + create(0), + create(1), + }); + + Func("main", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create( + Var("len", ty.u32(), ast::StorageClass::kFunction, + Call("arrayLength", MemberAccessor("sb", "arr")))), + }, + ast::DecorationList{ + create(ast::PipelineStage::kVertex), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + + auto got = result(); + auto* expect = R"( +ByteAddressBuffer sb : register(t0, space1); + +void main() { + uint tint_symbol_9 = 0u; + sb.GetDimensions(tint_symbol_9); + const uint tint_symbol_10 = ((tint_symbol_9 - 16u) / 16u); + uint len = tint_symbol_10; + return; +} + +)"; + EXPECT_EQ(expect, got); +} + TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) { auto* array_init = array(1, 2, 3, 4); auto* array_index = IndexAccessor(array_init, 3); diff --git a/src/writer/hlsl/generator_impl_type_test.cc b/src/writer/hlsl/generator_impl_type_test.cc index 1b116fa7f6..fb871e4c0c 100644 --- a/src/writer/hlsl/generator_impl_type_test.cc +++ b/src/writer/hlsl/generator_impl_type_test.cc @@ -37,7 +37,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Alias) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, alias, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, alias, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "alias"); } @@ -46,7 +47,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Alias_NameCollision) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, alias, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, alias, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "bool_tint_0"); } @@ -55,7 +57,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary")) + << gen.error(); EXPECT_EQ(result(), "bool ary[4]"); } @@ -64,7 +67,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_ArrayOfArray) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary")) + << gen.error(); EXPECT_EQ(result(), "bool ary[5][4]"); } @@ -75,7 +79,8 @@ TEST_F(HlslGeneratorImplTest_Type, GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary")) + << gen.error(); EXPECT_EQ(result(), "bool ary[5][4][1]"); } @@ -84,7 +89,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_ArrayOfArrayOfArray) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary")) + << gen.error(); EXPECT_EQ(result(), "bool ary[6][5][4]"); } @@ -93,7 +99,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array_NameCollision) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "bool")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "bool")) + << gen.error(); EXPECT_EQ(result(), "bool bool_tint_0[4]"); } @@ -102,7 +109,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array_WithoutName) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "bool[4]"); } @@ -111,7 +119,8 @@ TEST_F(HlslGeneratorImplTest_Type, DISABLED_EmitType_RuntimeArray) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary")) + << gen.error(); EXPECT_EQ(result(), "bool ary[]"); } @@ -121,7 +130,8 @@ TEST_F(HlslGeneratorImplTest_Type, GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, arr, "double")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "double")) + << gen.error(); EXPECT_EQ(result(), "bool double_tint_0[]"); } @@ -130,7 +140,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Bool) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, bool_, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, bool_, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "bool"); } @@ -139,7 +150,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_F32) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, f32, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, f32, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "float"); } @@ -148,7 +160,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_I32) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, i32, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, i32, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "int"); } @@ -157,7 +170,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Matrix) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, mat2x3, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "float2x3"); } @@ -167,7 +181,8 @@ TEST_F(HlslGeneratorImplTest_Type, DISABLED_EmitType_Pointer) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, &p, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, &p, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "float*"); } @@ -210,7 +225,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, s, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, s, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "S"); } @@ -227,7 +243,8 @@ TEST_F(HlslGeneratorImplTest_Type, DISABLED_EmitType_Struct_InjectPadding) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, s, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, s, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(gen.result(), R"(struct S { int a; int8_t pad_0[28]; @@ -280,7 +297,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_U32) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, u32, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, u32, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "uint"); } @@ -289,7 +307,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Vector) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, vec3, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, vec3, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "float3"); } @@ -298,7 +317,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Void) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, void_, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, void_, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "void"); } @@ -307,7 +327,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitSampler) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, &sampler, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, &sampler, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "SamplerState"); } @@ -316,7 +337,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitSamplerComparison) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, &sampler, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, &sampler, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "SamplerComparisonState"); } @@ -419,7 +441,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitMultisampledTexture) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitType(out, &s, "")) << gen.error(); + ASSERT_TRUE(gen.EmitType(out, &s, ast::StorageClass::kNone, "")) + << gen.error(); EXPECT_EQ(result(), "Texture2DMS"); }