diff --git a/src/ast/struct.cc b/src/ast/struct.cc index f0b2a9c1e2..01344b6c26 100644 --- a/src/ast/struct.cc +++ b/src/ast/struct.cc @@ -31,6 +31,15 @@ Struct::Struct(Struct&&) = default; Struct::~Struct() = default; +StructMember* Struct::get_member(const std::string& name) const { + for (auto& mem : members_) { + if (mem->name() == name) { + return mem.get(); + } + } + return nullptr; +} + bool Struct::IsValid() const { for (const auto& mem : members_) { if (mem == nullptr || !mem->IsValid()) { diff --git a/src/ast/struct.h b/src/ast/struct.h index 9705d0f7a8..e12cfd5881 100644 --- a/src/ast/struct.h +++ b/src/ast/struct.h @@ -59,6 +59,11 @@ class Struct : public Node { /// @returns the members const StructMemberList& members() const { return members_; } + /// Returns the struct member with the given name or nullptr if non exists. + /// @param name the name of the member + /// @returns the struct member or nullptr if not found + StructMember* get_member(const std::string& name) const; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc index f0ce746b82..447fed9ffe 100644 --- a/src/ast/struct_member.cc +++ b/src/ast/struct_member.cc @@ -14,6 +14,8 @@ #include "src/ast/struct_member.h" +#include "src/ast/struct_member_offset_decoration.h" + namespace tint { namespace ast { @@ -37,6 +39,24 @@ StructMember::StructMember(StructMember&&) = default; StructMember::~StructMember() = default; +bool StructMember::has_offset_decoration() const { + for (const auto& deco : decorations_) { + if (deco->IsOffset()) { + return true; + } + } + return false; +} + +uint32_t StructMember::offset() const { + for (const auto& deco : decorations_) { + if (deco->IsOffset()) { + return deco->AsOffset()->offset(); + } + } + return 0; +} + bool StructMember::IsValid() const { if (name_.empty() || type_ == nullptr) { return false; diff --git a/src/ast/struct_member.h b/src/ast/struct_member.h index fb21ea80ee..91a06217d3 100644 --- a/src/ast/struct_member.h +++ b/src/ast/struct_member.h @@ -72,6 +72,11 @@ class StructMember : public Node { /// @returns the decorations const StructMemberDecorationList& decorations() const { return decorations_; } + /// @returns true if the struct member has an offset decoration + bool has_offset_decoration() const; + /// @returns the offset decoration value. + uint32_t offset() const; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 6cf1791578..6af10e0787 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -424,7 +424,9 @@ bool TypeDeterminer::DetermineArrayAccessor( ret = ctx_.type_mgr().Get( std::make_unique(m->type(), m->rows())); } else { - set_error(expr->source(), "invalid parent type in array accessor"); + set_error(expr->source(), "invalid parent type (" + + parent_type->type_name() + + ") in array accessor"); return false; } diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 0ae0e9491d..cb9b824bde 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -65,6 +65,37 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { return stmts->last()->IsBreak() || stmts->last()->IsFallthrough(); } +std::string get_buffer_name(ast::Expression* expr) { + for (;;) { + if (expr->IsIdentifier()) { + return expr->AsIdentifier()->name(); + } else if (expr->IsMemberAccessor()) { + expr = expr->AsMemberAccessor()->structure(); + } else if (expr->IsArrayAccessor()) { + expr = expr->AsArrayAccessor()->array(); + } else { + break; + } + } + return ""; +} + +uint32_t convert_swizzle_to_index(const std::string& swizzle) { + if (swizzle == "r" || swizzle == "x") { + return 0; + } + if (swizzle == "g" || swizzle == "y") { + return 1; + } + if (swizzle == "b" || swizzle == "z") { + return 2; + } + if (swizzle == "a" || swizzle == "w") { + return 3; + } + return 0; +} + } // namespace GeneratorImpl::GeneratorImpl(ast::Module* module) : module_(module) {} @@ -73,7 +104,7 @@ GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate() { for (const auto& global : module_->global_variables()) { - global_variables_.set(global->name(), global.get()); + register_global(global.get()); } for (auto* const alias : module_->alias_types()) { @@ -114,6 +145,10 @@ bool GeneratorImpl::Generate() { return true; } +void GeneratorImpl::register_global(ast::Variable* global) { + global_variables_.set(global->name(), global); +} + std::string GeneratorImpl::generate_name(const std::string& prefix) { std::string name = prefix; uint32_t i = 0; @@ -166,6 +201,11 @@ bool GeneratorImpl::EmitAliasType(const ast::type::AliasType* alias) { } bool GeneratorImpl::EmitArrayAccessor(ast::ArrayAccessorExpression* expr) { + // Handle writing into a storage buffer array + if (is_storage_buffer_access(expr)) { + return EmitStorageBufferAccessor(expr, nullptr); + } + if (!EmitExpression(expr->array())) { return false; } @@ -201,6 +241,28 @@ bool GeneratorImpl::EmitAs(ast::AsExpression* expr) { bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) { make_indent(); + // If the LHS is an accessor into a storage buffer then we have to + // emit a Store operation instead of an ='s. + if (stmt->lhs()->IsMemberAccessor()) { + auto* mem = stmt->lhs()->AsMemberAccessor(); + if (is_storage_buffer_access(mem)) { + if (!EmitStorageBufferAccessor(mem, stmt->rhs())) { + return false; + } + out_ << ";" << std::endl; + return true; + } + } else if (stmt->lhs()->IsArrayAccessor()) { + auto* ary = stmt->lhs()->AsArrayAccessor(); + if (is_storage_buffer_access(ary)) { + if (!EmitStorageBufferAccessor(ary, stmt->rhs())) { + return false; + } + out_ << ";" << std::endl; + return true; + } + } + if (!EmitExpression(stmt->lhs())) { return false; } @@ -1108,6 +1170,19 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { out_ << std::endl; } + bool emitted_storagebuffer = false; + for (auto data : func->referenced_storagebuffer_variables()) { + auto* var = data.first; + auto* binding = data.second.binding; + + out_ << "RWByteAddressBuffer " << var->name() << " : register(u" + << binding->value() << ");" << std::endl; + emitted_storagebuffer = true; + } + if (emitted_storagebuffer) { + out_ << std::endl; + } + auto ep_name = ep->name(); if (ep_name.empty()) { ep_name = ep->function_name(); @@ -1396,7 +1471,188 @@ bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) { return true; } +// TODO(dsinclair): This currently only handles loading of 4, 8, 12 or 16 byte +// members. If we need to support larger we'll need to do the loading into +// chunks. +// +// TODO(dsinclair): Need to support loading through a pointer. The pointer is +// just a memory address in the storage buffer, so need to do the correct +// calculation. +bool GeneratorImpl::EmitStorageBufferAccessor(ast::Expression* expr, + ast::Expression* rhs) { + auto* result_type = expr->result_type()->UnwrapAliasPtrAlias(); + std::string access_method = rhs != nullptr ? "Store" : "Load"; + if (result_type->IsVector()) { + access_method += std::to_string(result_type->AsVector()->size()); + } + + // If we aren't storing then we need to put in the outer cast. + if (rhs == nullptr) { + if (result_type->is_float_scalar_or_vector()) { + out_ << "asfloat("; + } else if (result_type->is_signed_scalar_or_vector()) { + out_ << "asint("; + } else if (result_type->is_unsigned_scalar_or_vector()) { + out_ << "asuint("; + } + } + + auto buffer_name = get_buffer_name(expr); + if (buffer_name.empty()) { + error_ = "error emitting storage buffer access"; + return false; + } + out_ << buffer_name << "." << access_method << "("; + + auto* ptr = expr; + bool first = true; + for (;;) { + if (ptr->IsIdentifier()) { + break; + } + + if (!first) { + out_ << " + "; + } + first = false; + if (ptr->IsMemberAccessor()) { + auto* mem = ptr->AsMemberAccessor(); + auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias(); + + if (res_type->IsStruct()) { + auto* str_type = res_type->AsStruct()->impl(); + auto* str_member = str_type->get_member(mem->member()->name()); + + if (!str_member->has_offset_decoration()) { + error_ = "missing offset decoration for struct member"; + return false; + } + out_ << str_member->offset(); + } else if (res_type->IsVector()) { + // This must be a single element swizzle if we've got a vector at this + // point. + if (mem->member()->name().size() != 1) { + error_ = + "Encountered multi-element swizzle when should have only one " + "level"; + return false; + } + + // TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32) + // so this is assuming 4. This will need to be fixed when we get f16 or + // f64 types. + out_ << "(4 * " << convert_swizzle_to_index(mem->member()->name()) + << ")"; + } else { + error_ = + "Invalid result type for member accessor: " + res_type->type_name(); + return false; + } + + ptr = mem->structure(); + } else if (ptr->IsArrayAccessor()) { + auto* ary = ptr->AsArrayAccessor(); + auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias(); + + out_ << "("; + // TODO(dsinclair): Handle matrix case and struct case. + if (ary_type->IsArray()) { + out_ << ary_type->AsArray()->array_stride(); + } else if (ary_type->IsVector()) { + // TODO(dsinclair): This is a hack. Our vectors can only be f32, i32 + // or u32 which are all 4 bytes. When we get f16 or other types we'll + // have to ask the type for the byte size. + out_ << "4"; + } else { + error_ = "Invalid array type in storage buffer access"; + return false; + } + out_ << " * "; + if (!EmitExpression(ary->idx_expr())) { + return false; + } + out_ << ")"; + + ptr = ary->array(); + } else { + error_ = "error emitting storage buffer access"; + return false; + } + } + + if (rhs != nullptr) { + out_ << ", asuint("; + if (!EmitExpression(rhs)) { + return false; + } + out_ << ")"; + } + + out_ << ")"; + + // Close the outer cast. + if (rhs == nullptr) { + out_ << ")"; + } + + return true; +} + +bool GeneratorImpl::is_storage_buffer_access( + ast::ArrayAccessorExpression* expr) { + // We only care about array so we can get to the next part of the expression. + // If it isn't an array or a member accessor we can stop looking as it won't + // be a storage buffer. + auto* ary = expr->array(); + if (ary->IsMemberAccessor()) { + return is_storage_buffer_access(ary->AsMemberAccessor()); + } else if (ary->IsArrayAccessor()) { + return is_storage_buffer_access(ary->AsArrayAccessor()); + } + return false; +} + +bool GeneratorImpl::is_storage_buffer_access( + ast::MemberAccessorExpression* expr) { + auto* structure = expr->structure(); + auto* data_type = structure->result_type()->UnwrapAliasPtrAlias(); + // If the data is a multi-element swizzle then we will not load the swizzle + // portion through the Load command. + if (data_type->IsVector() && expr->member()->name().size() > 1) { + return false; + } + + // Check if this is a storage buffer variable + if (structure->IsIdentifier()) { + auto* ident = expr->structure()->AsIdentifier(); + if (ident->has_path()) { + return false; + } + + ast::Variable* var = nullptr; + if (!global_variables_.get(ident->name(), &var)) { + return false; + } + return var->storage_class() == ast::StorageClass::kStorageBuffer; + } else if (structure->IsMemberAccessor()) { + return is_storage_buffer_access(structure->AsMemberAccessor()); + } else if (structure->IsArrayAccessor()) { + return is_storage_buffer_access(structure->AsArrayAccessor()); + } + + // Technically I don't think this is possible, but if we don't have a struct + // or array accessor then we can't have a storage buffer I believe. + return false; +} + bool GeneratorImpl::EmitMemberAccessor(ast::MemberAccessorExpression* expr) { + // Look for storage buffer accesses as we have to convert them into Load + // expressions. Stores will be identified in the assignment emission and a + // member accessor store of a storage buffer will not get here. + if (is_storage_buffer_access(expr)) { + return EmitStorageBufferAccessor(expr, nullptr); + } + if (!EmitExpression(expr->structure())) { return false; } diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 4b5159a8c7..c9eae280c7 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -159,6 +159,11 @@ class GeneratorImpl : public TextGenerator { /// @param expr the member accessor expression /// @returns true if the member accessor was emitted bool EmitMemberAccessor(ast::MemberAccessorExpression* expr); + /// Handles a storage buffer accessor expression + /// @param expr the storage buffer accessor expression + /// @param rhs the right side of a store expression. Set to nullptr for a load + /// @returns true if the storage buffer accessor was emitted + bool EmitStorageBufferAccessor(ast::Expression* expr, ast::Expression* rhs); /// Handles return statements /// @param stmt the statement to emit /// @returns true if the statement was successfully emitted @@ -193,6 +198,18 @@ class GeneratorImpl : public TextGenerator { /// @returns true if the variable was emitted bool EmitProgramConstVariable(const ast::Variable* var); + /// Returns true if the accessor is accessing a storage buffer. + /// @param expr the expression to check + /// @returns true if the accessor is accessing a storage buffer for which + /// we need to execute a Load instruction. + bool is_storage_buffer_access(ast::MemberAccessorExpression* expr); + /// Returns true if the accessor is accessing a storage buffer. + /// @param expr the expression to check + /// @returns true if the accessor is accessing a storage buffer + bool is_storage_buffer_access(ast::ArrayAccessorExpression* expr); + /// Registers the given global with the generator + /// @param global the global to register + void register_global(ast::Variable* global); /// Checks if the global variable is in an input or output struct /// @param var the variable to check /// @returns true if the global is in an input or output struct diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 47a79d953d..9170fc7b47 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -30,6 +30,7 @@ #include "src/ast/set_decoration.h" #include "src/ast/sint_literal.h" #include "src/ast/struct.h" +#include "src/ast/struct_member_offset_decoration.h" #include "src/ast/type/alias_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" @@ -417,14 +418,104 @@ void frag_main() { } TEST_F(HlslGeneratorImplTest, - DISABLED_Emit_Function_EntryPoint_With_StorageBuffer) { + Emit_Function_EntryPoint_With_StorageBuffer_Read) { ast::type::VoidType void_type; ast::type::F32Type f32; - ast::type::VectorType vec4(&f32, 4); + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); auto coord_var = std::make_unique(std::make_unique( - "coord", ast::StorageClass::kStorageBuffer, &vec4)); + "coord", ast::StorageClass::kStorageBuffer, &s)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("b"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kFragment, "", + "frag_main"); + mod.AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(RWByteAddressBuffer coord : register(u0); + +void frag_main() { + float v = asfloat(coord.Load(4)); + return; +} + +)"); +} + +TEST_F(HlslGeneratorImplTest, + Emit_Function_EntryPoint_With_StorageBuffer_Store) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &s)); ast::VariableDecorationList decos; decos.push_back(std::make_unique(0)); @@ -442,14 +533,15 @@ TEST_F(HlslGeneratorImplTest, auto func = std::make_unique("frag_main", std::move(params), &void_type); - auto var = - std::make_unique("v", ast::StorageClass::kFunction, &f32); - var->set_constructor(std::make_unique( - std::make_unique("coord"), - std::make_unique("x"))); + auto assign = std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("b")), + std::make_unique( + std::make_unique(&f32, 2.0f))); auto body = std::make_unique(); - body->append(std::make_unique(std::move(var))); + body->append(std::move(assign)); body->append(std::make_unique()); func->set_body(std::move(body)); @@ -463,7 +555,14 @@ TEST_F(HlslGeneratorImplTest, GeneratorImpl g(&mod); ASSERT_TRUE(g.Generate()) << g.error(); - EXPECT_EQ(g.result(), R"( ... )"); + EXPECT_EQ(g.result(), R"(RWByteAddressBuffer coord : register(u0); + +void frag_main() { + coord.Store(4, asuint(2.00000000f)); + return; +} + +)"); } TEST_F(HlslGeneratorImplTest, diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc index a4df1d8280..88e3660cbf 100644 --- a/src/writer/hlsl/generator_impl_member_accessor_test.cc +++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc @@ -15,9 +15,27 @@ #include #include "gtest/gtest.h" +#include "src/ast/array_accessor_expression.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/binary_expression.h" +#include "src/ast/decorated_variable.h" +#include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" +#include "src/ast/scalar_constructor_expression.h" +#include "src/ast/sint_literal.h" +#include "src/ast/struct.h" +#include "src/ast/struct_member.h" +#include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/type/array_type.h" +#include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" +#include "src/ast/type/struct_type.h" +#include "src/ast/type/vector_type.h" +#include "src/ast/type_constructor_expression.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/hlsl/generator_impl.h" namespace tint { @@ -28,17 +46,1061 @@ namespace { using HlslGeneratorImplTest = testing::Test; TEST_F(HlslGeneratorImplTest, EmitExpression_MemberAccessor) { + ast::type::F32Type f32; + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("mem", &f32, std::move(deco))); + + auto strct = std::make_unique(); + strct->set_members(std::move(members)); + + ast::type::StructType s(std::move(strct)); + s.set_name("Str"); + + auto str_var = std::make_unique( + std::make_unique("str", ast::StorageClass::kPrivate, &s)); + auto str = std::make_unique("str"); auto mem = std::make_unique("mem"); ast::MemberAccessorExpression expr(std::move(str), std::move(mem)); - ast::Module m; - GeneratorImpl g(&m); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(str_var.get()); + g.register_global(str_var.get()); + mod.AddGlobalVariable(std::move(str_var)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); EXPECT_EQ(g.result(), "str.mem"); } +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load) { + // struct Data { + // [[offset 0]] a : i32; + // [[offset 4]] b : f32; + // }; + // var data : Data; + // data.b; + // + // -> asfloat(data.Load(4)); + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::MemberAccessorExpression expr( + std::make_unique("data"), + std::make_unique("b")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(td.DetermineResultType(&expr)); + + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asfloat(data.Load(4))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load_Int) { + // struct Data { + // [[offset 0]] a : i32; + // [[offset 4]] b : f32; + // }; + // var data : Data; + // data.a; + // + // -> asint(data.Load(0)); + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::MemberAccessorExpression expr( + std::make_unique("data"), + std::make_unique("a")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(td.DetermineResultType(&expr)); + + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asint(data.Load(0))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) { + // struct Data { + // [[offset 0]] a : [[stride 4]] array; + // }; + // var data : Data; + // data.a[2]; + // + // -> asint(data.Load((2 * 4)); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32, 5); + ary.set_array_stride(4); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ary, std::move(a_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::ArrayAccessorExpression expr( + std::make_unique( + std::make_unique("data"), + std::make_unique("a")), + std::make_unique( + std::make_unique(&i32, 2))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asint(data.Load((4 * 2) + 0))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray_ExprIdx) { + // struct Data { + // [[offset 0]] a : [[stride 4]] array; + // }; + // var data : Data; + // data.a[(2 + 4) - 3]; + // + // -> asint(data.Load((4 * ((2 + 4) - 3))); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32, 5); + ary.set_array_stride(4); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ary, std::move(a_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::ArrayAccessorExpression expr( + std::make_unique( + std::make_unique("data"), + std::make_unique("a")), + std::make_unique( + ast::BinaryOp::kSubtract, + std::make_unique( + ast::BinaryOp::kAdd, + std::make_unique( + std::make_unique(&i32, 2)), + std::make_unique( + std::make_unique(&i32, 4))), + std::make_unique( + std::make_unique(&i32, 3)))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asint(data.Load((4 * ((2 + 4) - 3)) + 0))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Store) { + // struct Data { + // [[offset 0]] a : i32; + // [[offset 4]] b : f32; + // }; + // var data : Data; + // data.b = 2.3f; + // + // -> data.Store(0, asuint(2.0f)); + + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + auto lhs = std::make_unique( + std::make_unique("data"), + std::make_unique("b")); + auto rhs = std::make_unique( + std::make_unique(&f32, 2.0f)); + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&assign)); + ASSERT_TRUE(g.EmitStatement(&assign)) << g.error(); + EXPECT_EQ(g.result(), R"(data.Store(4, asuint(2.00000000f)); +)"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Store_ToArray) { + // struct Data { + // [[offset 0]] a : [[stride 4]] array; + // }; + // var data : Data; + // data.a[2] = 2; + // + // -> data.Store((2 * 4), asuint(2.3f)); + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32, 5); + ary.set_array_stride(4); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ary, std::move(a_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + auto lhs = std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("a")), + std::make_unique( + std::make_unique(&i32, 2))); + auto rhs = std::make_unique( + std::make_unique(&i32, 2)); + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error(); + ASSERT_TRUE(g.EmitStatement(&assign)) << g.error(); + EXPECT_EQ(g.result(), R"(data.Store((4 * 2) + 0, asuint(2)); +)"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Store_Int) { + // struct Data { + // [[offset 0]] a : i32; + // [[offset 4]] b : f32; + // }; + // var data : Data; + // data.a = 2; + // + // -> data.Store(0, asuint(2)); + + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + auto lhs = std::make_unique( + std::make_unique("data"), + std::make_unique("a")); + auto rhs = std::make_unique( + std::make_unique(&i32, 2)); + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&assign)); + ASSERT_TRUE(g.EmitStatement(&assign)) << g.error(); + EXPECT_EQ(g.result(), R"(data.Store(0, asuint(2)); +)"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load_Vec3) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // var data : Data; + // data.b; + // + // -> asfloat(data.Load(16)); + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + ast::MemberAccessorExpression expr( + std::make_unique("data"), + std::make_unique("b")); + + ASSERT_TRUE(td.DetermineResultType(&expr)); + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asfloat(data.Load3(16))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Store_Vec3) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // var data : Data; + // data.b = vec3(2.3f, 1.2f, 0.2f); + // + // -> data.Store(16, asuint(vector(2.3f, 1.2f, 0.2f))); + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + auto lit1 = std::make_unique(&f32, 1.f); + auto lit2 = std::make_unique(&f32, 2.f); + auto lit3 = std::make_unique(&f32, 3.f); + ast::ExpressionList values; + values.push_back( + std::make_unique(std::move(lit1))); + values.push_back( + std::make_unique(std::move(lit2))); + values.push_back( + std::make_unique(std::move(lit3))); + + auto lhs = std::make_unique( + std::make_unique("data"), + std::make_unique("b")); + auto rhs = std::make_unique( + &fvec3, std::move(values)); + + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&assign)); + ASSERT_TRUE(g.EmitStatement(&assign)) << g.error(); + EXPECT_EQ( + g.result(), + R"(data.Store3(16, asuint(vector(1.00000000f, 2.00000000f, 3.00000000f))); +)"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // struct Pre { + // var c : [[stride 32]] array; + // }; + // + // var data : Pre; + // data.c[2].b + // + // -> asfloat(data.Load3(16 + (2 * 32))) + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(deco))); + + deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(deco))); + + auto data_str = std::make_unique(); + data_str->set_members(std::move(members)); + + ast::type::StructType data(std::move(data_str)); + data.set_name("Data"); + + ast::type::ArrayType ary(&data, 4); + ary.set_array_stride(32); + + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("c", &ary, std::move(deco))); + + auto pre_str = std::make_unique(); + pre_str->set_members(std::move(members)); + + ast::type::StructType pre(std::move(pre_str)); + pre.set_name("Pre"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &pre)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + ast::MemberAccessorExpression expr( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("c")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("b")); + + ASSERT_TRUE(td.DetermineResultType(&expr)); + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asfloat(data.Load3(16 + (32 * 2) + 0))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Swizzle) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // struct Pre { + // var c : [[stride 32]] array; + // }; + // + // var data : Pre; + // data.c[2].b.xy + // + // -> asfloat(data.Load3(16 + (2 * 32))).xy + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(deco))); + + deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(deco))); + + auto data_str = std::make_unique(); + data_str->set_members(std::move(members)); + + ast::type::StructType data(std::move(data_str)); + data.set_name("Data"); + + ast::type::ArrayType ary(&data, 4); + ary.set_array_stride(32); + + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("c", &ary, std::move(deco))); + + auto pre_str = std::make_unique(); + pre_str->set_members(std::move(members)); + + ast::type::StructType pre(std::move(pre_str)); + pre.set_name("Pre"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &pre)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + ast::MemberAccessorExpression expr( + std::make_unique( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("c")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("b")), + std::make_unique("xy")); + + ASSERT_TRUE(td.DetermineResultType(&expr)); + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asfloat(data.Load3(16 + (32 * 2) + 0)).xy"); +} + +TEST_F( + HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Swizzle_SingleLetter) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // struct Pre { + // var c : [[stride 32]] array; + // }; + // + // var data : Pre; + // data.c[2].b.g + // + // -> asfloat(data.Load((4 * 1) + 16 + (2 * 32) + 0)) + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(deco))); + + deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(deco))); + + auto data_str = std::make_unique(); + data_str->set_members(std::move(members)); + + ast::type::StructType data(std::move(data_str)); + data.set_name("Data"); + + ast::type::ArrayType ary(&data, 4); + ary.set_array_stride(32); + + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("c", &ary, std::move(deco))); + + auto pre_str = std::make_unique(); + pre_str->set_members(std::move(members)); + + ast::type::StructType pre(std::move(pre_str)); + pre.set_name("Pre"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &pre)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + ast::MemberAccessorExpression expr( + std::make_unique( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("c")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("b")), + std::make_unique("g")); + + ASSERT_TRUE(td.DetermineResultType(&expr)); + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asfloat(data.Load((4 * 1) + 16 + (32 * 2) + 0))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Load_MultiLevel_Index) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // struct Pre { + // var c : [[stride 32]] array; + // }; + // + // var data : Pre; + // data.c[2].b[1] + // + // -> asfloat(data.Load(4 + 16 + (2 * 32))) + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(deco))); + + deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(deco))); + + auto data_str = std::make_unique(); + data_str->set_members(std::move(members)); + + ast::type::StructType data(std::move(data_str)); + data.set_name("Data"); + + ast::type::ArrayType ary(&data, 4); + ary.set_array_stride(32); + + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("c", &ary, std::move(deco))); + + auto pre_str = std::make_unique(); + pre_str->set_members(std::move(members)); + + ast::type::StructType pre(std::move(pre_str)); + pre.set_name("Pre"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &pre)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + ast::ArrayAccessorExpression expr( + std::make_unique( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("c")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("b")), + std::make_unique( + std::make_unique(&i32, 1))); + + ASSERT_TRUE(td.DetermineResultType(&expr)); + ASSERT_TRUE(g.EmitExpression(&expr)) << g.error(); + EXPECT_EQ(g.result(), "asfloat(data.Load((4 * 1) + 16 + (32 * 2) + 0))"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Store_MultiLevel) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // struct Pre { + // var c : [[stride 32]] array; + // }; + // + // var data : Pre; + // data.c[2].b = vec3(1.f, 2.f, 3.f); + // + // -> data.Store3(16 + (2 * 32), asuint(vector(1.0f, 2.0f, 3.0f))); + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(deco))); + + deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(deco))); + + auto data_str = std::make_unique(); + data_str->set_members(std::move(members)); + + ast::type::StructType data(std::move(data_str)); + data.set_name("Data"); + + ast::type::ArrayType ary(&data, 4); + ary.set_array_stride(32); + + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("c", &ary, std::move(deco))); + + auto pre_str = std::make_unique(); + pre_str->set_members(std::move(members)); + + ast::type::StructType pre(std::move(pre_str)); + pre.set_name("Pre"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &pre)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + auto lhs = std::make_unique( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("c")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("b")); + + auto lit1 = std::make_unique(&f32, 1.f); + auto lit2 = std::make_unique(&f32, 2.f); + auto lit3 = std::make_unique(&f32, 3.f); + ast::ExpressionList values; + values.push_back( + std::make_unique(std::move(lit1))); + values.push_back( + std::make_unique(std::move(lit2))); + values.push_back( + std::make_unique(std::move(lit3))); + + auto rhs = std::make_unique( + &fvec3, std::move(values)); + + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&assign)); + ASSERT_TRUE(g.EmitStatement(&assign)) << g.error(); + EXPECT_EQ( + g.result(), + R"(data.Store3(16 + (32 * 2) + 0, asuint(vector(1.00000000f, 2.00000000f, 3.00000000f))); +)"); +} + +TEST_F(HlslGeneratorImplTest, + EmitExpression_MemberAccessor_StorageBuffer_Store_Swizzle_SingleLetter) { + // struct Data { + // [[offset 0]] a : vec3; + // [[offset 16]] b : vec3; + // }; + // struct Pre { + // var c : [[stride 32]] array; + // }; + // + // var data : Pre; + // data.c[2].b.y = 1.f; + // + // -> data.Store((4 * 1) + 16 + (2 * 32) + 0, asuint(1.0f)); + + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::VectorType fvec3(&f32, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &ivec3, std::move(deco))); + + deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("b", &fvec3, std::move(deco))); + + auto data_str = std::make_unique(); + data_str->set_members(std::move(members)); + + ast::type::StructType data(std::move(data_str)); + data.set_name("Data"); + + ast::type::ArrayType ary(&data, 4); + ary.set_array_stride(32); + + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("c", &ary, std::move(deco))); + + auto pre_str = std::make_unique(); + pre_str->set_members(std::move(members)); + + ast::type::StructType pre(std::move(pre_str)); + pre.set_name("Pre"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &pre)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + GeneratorImpl g(&mod); + td.RegisterVariableForTesting(coord_var.get()); + g.register_global(coord_var.get()); + mod.AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + auto lhs = std::make_unique( + std::make_unique( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("c")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("b")), + std::make_unique("y")); + + auto rhs = std::make_unique( + std::make_unique(&i32, 1.f)); + + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&assign)); + ASSERT_TRUE(g.EmitStatement(&assign)) << g.error(); + EXPECT_EQ(g.result(), + R"(data.Store((4 * 1) + 16 + (32 * 2) + 0, asuint(1.00000000f)); +)"); +} + } // namespace } // namespace hlsl } // namespace writer