[spirv-writer] Handle multi name swizzles.

This CL rebuilds the accessor code to allow generating multi item
swizzles. This requires being able to output the access chain in the
middle of the access chain and then work with the results of that access
chain.

Bug: tint:5
Change-Id: I0687509c9ddec6a2e13d9e3595f04a091ee9af7b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20623
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-05-01 16:17:03 +00:00 committed by David Neto
parent a1a4800861
commit 7cac245abc
6 changed files with 656 additions and 154 deletions

View File

@ -107,7 +107,8 @@ bool ends_with(const std::string& input, const std::string& suffix) {
const auto input_len = input.size(); const auto input_len = input.size();
const auto suffix_len = suffix.size(); const auto suffix_len = suffix.size();
// Avoid integer overflow. // Avoid integer overflow.
return (input_len >= suffix_len) && (input_len - suffix_len == input.rfind(suffix)); return (input_len >= suffix_len) &&
(input_len - suffix_len == input.rfind(suffix));
} }
/// @param filename the filename to inspect /// @param filename the filename to inspect
@ -227,7 +228,9 @@ bool ReadFile(const std::string& input_file, std::vector<T>* buffer) {
/// and std::vector do. /// and std::vector do.
/// @returns true on success /// @returns true on success
template <typename ContainerT> template <typename ContainerT>
bool WriteFile(const std::string& output_file, const std::string mode, const ContainerT& buffer) { bool WriteFile(const std::string& output_file,
const std::string mode,
const ContainerT& buffer) {
const bool use_stdout = output_file.empty() || output_file == "-"; const bool use_stdout = output_file.empty() || output_file == "-";
FILE* file = stdout; FILE* file = stdout;
@ -244,7 +247,9 @@ bool WriteFile(const std::string& output_file, const std::string mode, const Con
} }
} }
size_t written = fwrite(buffer.data(), sizeof(typename ContainerT::value_type), buffer.size(), file); size_t written =
fwrite(buffer.data(), sizeof(typename ContainerT::value_type),
buffer.size(), file);
if (buffer.size() != written) { if (buffer.size() != written) {
if (use_stdout) { if (use_stdout) {
std::cerr << "Could not write all output to standard output" << std::endl; std::cerr << "Could not write all output to standard output" << std::endl;

View File

@ -275,6 +275,9 @@ bool TypeDeterminer::DetermineArrayAccessor(
if (!DetermineResultType(expr->array())) { if (!DetermineResultType(expr->array())) {
return false; return false;
} }
if (!DetermineResultType(expr->idx_expr())) {
return false;
}
auto* res = expr->array()->result_type(); auto* res = expr->array()->result_type();
auto* parent_type = res->UnwrapPtrIfNeeded(); auto* parent_type = res->UnwrapPtrIfNeeded();
@ -435,6 +438,12 @@ bool TypeDeterminer::DetermineMemberAccessor(
set_error(expr->source(), "struct member " + name + " not found"); set_error(expr->source(), "struct member " + name + " not found");
return false; return false;
} }
// If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
} else if (data_type->IsVector()) { } else if (data_type->IsVector()) {
auto* vec = data_type->AsVector(); auto* vec = data_type->AsVector();
@ -442,6 +451,11 @@ bool TypeDeterminer::DetermineMemberAccessor(
if (size == 1) { if (size == 1) {
// A single element swizzle is just the type of the vector. // A single element swizzle is just the type of the vector.
ret = vec->type(); ret = vec->type();
// If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
} else { } else {
// The vector will have a number of components equal to the length of the // The vector will have a number of components equal to the length of the
// swizzle. This assumes the validator will check that the swizzle // swizzle. This assumes the validator will check that the swizzle
@ -455,11 +469,6 @@ bool TypeDeterminer::DetermineMemberAccessor(
return false; return false;
} }
// If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class()));
}
expr->set_result_type(ret); expr->set_result_type(ret);
return true; return true;

View File

@ -860,12 +860,9 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle)); ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr); ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsPointer()); ASSERT_TRUE(mem.result_type()->IsVector());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
auto* ptr = mem.result_type()->AsPointer(); EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
} }
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) { TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@ -968,12 +965,9 @@ TEST_F(TypeDeterminerTest, Expr_Accessor_MultiLevel) {
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr); ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsPointer()); ASSERT_TRUE(mem.result_type()->IsVector());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
auto* ptr = mem.result_type()->AsPointer(); EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u);
ASSERT_TRUE(ptr->type()->IsVector());
EXPECT_TRUE(ptr->type()->AsVector()->type()->IsF32());
EXPECT_EQ(ptr->type()->AsVector()->size(), 2u);
} }
using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>; using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>;

View File

@ -203,6 +203,12 @@ void Builder::iterate(std::function<void(const Instruction&)> cb) const {
} }
} }
uint32_t Builder::GenerateU32Literal(uint32_t val) {
ast::type::U32Type u32;
ast::IntLiteral lit(&u32, val);
return GenerateLiteralIfNeeded(&lit);
}
bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) { bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
auto lhs_id = GenerateExpression(assign->lhs()); auto lhs_id = GenerateExpression(assign->lhs());
if (lhs_id == 0) { if (lhs_id == 0) {
@ -479,105 +485,229 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
return true; return true;
} }
uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
assert(expr->IsArrayAccessor() || expr->IsMemberAccessor()); AccessorInfo* info) {
auto idx_id = GenerateExpression(expr->idx_expr());
if (idx_id == 0) {
return 0;
}
// If the source is a pointer we access chain into it.
if (info->source_type->IsPointer()) {
info->access_chain_indices.push_back(idx_id);
return true;
}
auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
if (result_type_id == 0) {
return false;
}
// We don't have a pointer, so we have to extract value from the vector
auto extract = result_op();
auto extract_id = extract.to_i();
push_function_inst(spv::Op::OpVectorExtractDynamic,
{Operand::Int(result_type_id), extract,
Operand::Int(info->source_id), Operand::Int(idx_id)});
info->source_id = extract_id;
info->source_type = expr->result_type();
return true;
}
bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info) {
auto* data_type = expr->structure()->result_type()->UnwrapPtrIfNeeded();
while (data_type->IsAlias()) {
data_type = data_type->AsAlias()->type();
}
// If the data_type is a structure we're accessing a member, if it's a
// vector we're accessing a swizzle.
if (data_type->IsStruct()) {
if (!info->source_type->IsPointer()) {
error_ =
"Attempting to access a struct member on a non-pointer. Something is "
"wrong";
return false;
}
auto* strct = data_type->AsStruct()->impl();
auto name = expr->member()->name();
uint32_t i = 0;
for (; i < strct->members().size(); ++i) {
const auto& member = strct->members()[i];
if (member->name() == name) {
break;
}
}
auto idx_id = GenerateU32Literal(i);
if (idx_id == 0) {
return 0;
}
info->access_chain_indices.push_back(idx_id);
info->source_type = expr->result_type();
return true;
}
if (!data_type->IsVector()) {
error_ = "Member accessor without a struct or vector. Something is wrong";
return false;
}
auto swiz = expr->member()->name();
// Single element swizzle is either an access chain or a composite extract
if (swiz.size() == 1) {
auto val = IndexFromName(swiz[0]);
if (val == std::numeric_limits<uint32_t>::max()) {
error_ = "invalid swizzle name: " + swiz;
return false;
}
if (info->source_type->IsPointer()) {
auto idx_id = GenerateU32Literal(val);
if (idx_id == 0) {
return 0;
}
info->access_chain_indices.push_back(idx_id);
} else {
auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
if (result_type_id == 0) {
return 0;
}
auto extract = result_op();
auto extract_id = extract.to_i();
push_function_inst(spv::Op::OpCompositeExtract,
{Operand::Int(result_type_id), extract,
Operand::Int(info->source_id), Operand::Int(val)});
info->source_id = extract_id;
info->source_type = expr->result_type();
}
return true;
}
// Multi-item extract is a VectorShuffle. We have to emit any existing access
// chain data, then load the access chain and shuffle that.
if (!info->access_chain_indices.empty()) {
auto result_type_id = GenerateTypeIfNeeded(info->source_type);
if (result_type_id == 0) {
return 0;
}
auto extract = result_op();
auto extract_id = extract.to_i();
std::vector<Operand> ops = {Operand::Int(result_type_id), extract,
Operand::Int(info->source_id)};
for (auto id : info->access_chain_indices) {
ops.push_back(Operand::Int(id));
}
push_function_inst(spv::Op::OpAccessChain, ops);
info->source_id = GenerateLoadIfNeeded(expr->result_type(), extract_id);
info->source_type = expr->result_type()->UnwrapPtrIfNeeded();
info->access_chain_indices.clear();
}
auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
if (result_type_id == 0) {
return false;
}
auto vec_id = GenerateLoadIfNeeded(info->source_type, info->source_id);
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
std::vector<Operand> idx_list; std::vector<Operand> ops = {Operand::Int(result_type_id), result,
Operand::Int(vec_id), Operand::Int(vec_id)};
for (uint32_t i = 0; i < swiz.size(); ++i) {
auto val = IndexFromName(swiz[i]);
if (val == std::numeric_limits<uint32_t>::max()) {
error_ = "invalid swizzle name: " + swiz;
return false;
}
ops.push_back(Operand::Int(val));
}
push_function_inst(spv::Op::OpVectorShuffle, ops);
info->source_id = result_id;
info->source_type = expr->result_type();
return true;
}
uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
assert(expr->IsArrayAccessor() || expr->IsMemberAccessor());
// Gather a list of all the member and array accessors that are in this chain.
// The list is built in reverse order as that's the order we need to access
// the chain.
std::vector<ast::Expression*> accessors;
ast::Expression* source = expr; ast::Expression* source = expr;
while (true) { while (true) {
if (source->IsArrayAccessor()) { if (source->IsArrayAccessor()) {
auto* ary_accessor = source->AsArrayAccessor(); accessors.insert(accessors.begin(), source);
source = ary_accessor->array(); source = source->AsArrayAccessor()->array();
auto idx = GenerateExpression(ary_accessor->idx_expr());
if (idx == 0) {
return 0;
}
idx_list.insert(idx_list.begin(), Operand::Int(idx));
} else if (source->IsMemberAccessor()) { } else if (source->IsMemberAccessor()) {
auto* mem_accessor = source->AsMemberAccessor(); accessors.insert(accessors.begin(), source);
source = mem_accessor->structure(); source = source->AsMemberAccessor()->structure();
auto* data_type =
mem_accessor->structure()->result_type()->UnwrapPtrIfNeeded();
while (data_type->IsAlias()) {
data_type = data_type->AsAlias()->type();
}
if (data_type->IsStruct()) {
auto* strct = data_type->AsStruct()->impl();
auto name = mem_accessor->member()->name();
uint32_t i = 0;
for (; i < strct->members().size(); ++i) {
const auto& member = strct->members()[i];
if (member->name() == name) {
break;
}
}
ast::type::U32Type u32;
ast::IntLiteral idx(&u32, i);
auto idx_id = GenerateLiteralIfNeeded(&idx);
if (idx_id == 0) {
return false;
}
idx_list.insert(idx_list.begin(), Operand::Int(idx_id));
} else if (data_type->IsVector()) {
auto swiz = mem_accessor->member()->name();
if (swiz.size() == 1) {
// A single item swizzle is a simple access chain
auto val = IndexFromName(swiz[0]);
if (val == std::numeric_limits<uint32_t>::max()) {
error_ = "invalid swizzle name: " + swiz;
return false;
}
ast::type::U32Type u32;
ast::IntLiteral idx(&u32, val);
auto idx_id = GenerateLiteralIfNeeded(&idx);
if (idx_id == 0) {
return false;
}
idx_list.insert(idx_list.begin(), Operand::Int(idx_id));
} else {
// A multi-item swizzle means we need to generate the access chain
// to the current point and then pull values out of it
//
// TODO(dsinclair): Handle multi-item swizzle
}
} else {
error_ = "invalid type for member accessor: " + data_type->type_name();
return 0;
}
} else { } else {
break; break;
} }
} }
auto source_id = GenerateExpression(source); AccessorInfo info;
if (source_id == 0) { info.source_id = GenerateExpression(source);
if (info.source_id == 0) {
return 0; return 0;
} }
info.source_type = source->result_type();
auto type_id = GenerateTypeIfNeeded(expr->result_type()); std::vector<uint32_t> access_chain_indices;
if (type_id == 0) { for (auto* accessor : accessors) {
return 0; if (accessor->IsArrayAccessor()) {
if (!GenerateArrayAccessor(accessor->AsArrayAccessor(), &info)) {
return 0;
}
} else if (accessor->IsMemberAccessor()) {
if (!GenerateMemberAccessor(accessor->AsMemberAccessor(), &info)) {
return 0;
}
} else {
error_ = "invalid accessor in list: " + accessor->str();
return 0;
}
} }
idx_list.insert(idx_list.begin(), Operand::Int(source_id)); if (!info.access_chain_indices.empty()) {
idx_list.insert(idx_list.begin(), result); auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
idx_list.insert(idx_list.begin(), Operand::Int(type_id)); if (result_type_id == 0) {
push_function_inst(spv::Op::OpAccessChain, idx_list); return 0;
}
return result_id; auto result = result_op();
auto result_id = result.to_i();
std::vector<Operand> ops = {Operand::Int(result_type_id), result,
Operand::Int(info.source_id)};
for (auto id : info.access_chain_indices) {
ops.push_back(Operand::Int(id));
}
push_function_inst(spv::Op::OpAccessChain, ops);
info.source_id = result_id;
}
return info.source_id;
} }
uint32_t Builder::GenerateIdentifierExpression( uint32_t Builder::GenerateIdentifierExpression(
@ -1237,9 +1367,7 @@ bool Builder::GenerateArrayType(ast::type::ArrayType* ary,
if (ary->IsRuntimeArray()) { if (ary->IsRuntimeArray()) {
push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)}); push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)});
} else { } else {
ast::type::U32Type u32; auto len_id = GenerateU32Literal(ary->size());
ast::IntLiteral ary_size(&u32, ary->size());
auto len_id = GenerateLiteralIfNeeded(&ary_size);
if (len_id == 0) { if (len_id == 0) {
return false; return false;
} }

View File

@ -38,6 +38,22 @@ namespace spirv {
/// Builder class to create SPIR-V instructions from a module. /// Builder class to create SPIR-V instructions from a module.
class Builder { class Builder {
public: public:
/// Contains information for generating accessor chains
struct AccessorInfo {
/// The ID of the current chain source. The chain source may change as we
/// evaluate the access chain. The chain source always points to the ID
/// which we will use to evaluate the current set of accessors. This maybe
/// the original variable, or maybe an intermediary if we had to evaulate
/// the access chain early (in the case of a swizzle of an access chain).
uint32_t source_id;
/// The type of the current chain source. This type matches the deduced
/// result_type of the current source defined above.
ast::type::Type* source_type;
/// A list of access chain indices to emit. Note, we _only_ have access
/// chain indices if the source is pointer.
std::vector<uint32_t> access_chain_indices;
};
/// Constructor /// Constructor
/// @param mod the module to generate from /// @param mod the module to generate from
explicit Builder(ast::Module* mod); explicit Builder(ast::Module* mod);
@ -146,6 +162,10 @@ class Builder {
/// @returns the SPIR-V builtin or SpvBuiltInMax on error. /// @returns the SPIR-V builtin or SpvBuiltInMax on error.
SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const; SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const;
/// Generates a uint32_t literal.
/// @param val the value to generate
/// @returns the ID of the generated literal
uint32_t GenerateU32Literal(uint32_t val);
/// Generates an assignment statement /// Generates an assignment statement
/// @param assign the statement to generate /// @param assign the statement to generate
/// @returns true if the statement was successfully generated /// @returns true if the statement was successfully generated
@ -164,7 +184,7 @@ class Builder {
bool GenerateEntryPoint(ast::EntryPoint* ep); bool GenerateEntryPoint(ast::EntryPoint* ep);
/// Generates an expression /// Generates an expression
/// @param expr the expression to generate /// @param expr the expression to generate
/// @returns the resulting ID of the expression or 0 on error /// @returns the resulting ID of the exp = {};ression or 0 on error
uint32_t GenerateExpression(ast::Expression* expr); uint32_t GenerateExpression(ast::Expression* expr);
/// Generates the instructions for a function /// Generates the instructions for a function
/// @param func the function to generate /// @param func the function to generate
@ -182,10 +202,26 @@ class Builder {
/// @param var the variable to generate /// @param var the variable to generate
/// @returns true if the variable is emited. /// @returns true if the variable is emited.
bool GenerateGlobalVariable(ast::Variable* var); bool GenerateGlobalVariable(ast::Variable* var);
/// Generates an array accessor expression /// Generates an array accessor expression.
///
/// For more information on accessors see the "Pointer evaluation" section of
/// the WGSL specification.
///
/// @param expr the expresssion to generate /// @param expr the expresssion to generate
/// @returns the id of the expression or 0 on failure /// @returns the id of the expression or 0 on failure
uint32_t GenerateAccessorExpression(ast::Expression* expr); uint32_t GenerateAccessorExpression(ast::Expression* expr);
/// Generates an array accessor
/// @param expr the accessor to generate
/// @param info the current accessor information
/// @returns true if the accessor was generated successfully
bool GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
AccessorInfo* info);
/// Generates a member accessor
/// @param expr the accessor to generate
/// @param info the current accessor information
/// @returns true if the accessor was generated successfully
bool GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info);
/// Generates an identifier expression /// Generates an identifier expression
/// @param expr the expresssion to generate /// @param expr the expresssion to generate
/// @returns the id of the expression or 0 on failure /// @returns the id of the expression or 0 on failure

View File

@ -15,6 +15,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/ast/array_accessor_expression.h" #include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h" #include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/float_literal.h" #include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h" #include "src/ast/identifier_expression.h"
#include "src/ast/int_literal.h" #include "src/ast/int_literal.h"
@ -46,6 +47,9 @@ TEST_F(BuilderTest, ArrayAccessor) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3); ast::type::VectorType vec3(&f32, 3);
// vec3<f32> ary;
// ary[1] -> ptr<f32>
ast::Variable var("ary", ast::StorageClass::kFunction, &vec3); ast::Variable var("ary", ast::StorageClass::kFunction, &vec3);
auto ary = std::make_unique<ast::IdentifierExpression>("ary"); auto ary = std::make_unique<ast::IdentifierExpression>("ary");
@ -64,20 +68,69 @@ TEST_F(BuilderTest, ArrayAccessor) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 5u); EXPECT_EQ(b.GenerateAccessorExpression(&expr), 8u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3 %3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%6 = OpTypeInt 32 1 %5 = OpTypeInt 32 1
%7 = OpConstant %6 1 %6 = OpConstant %5 1
%8 = OpTypePointer Function %4 %7 = OpTypePointer Function %4
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%5 = OpAccessChain %8 %1 %7 R"(%8 = OpAccessChain %7 %1 %6
)");
}
TEST_F(BuilderTest, ArrayAccessor_Dynamic) {
ast::type::I32Type i32;
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
// vec3<f32> ary;
// ary[1 + 2] -> ptr<f32>
ast::Variable var("ary", ast::StorageClass::kFunction, &vec3);
auto ary = std::make_unique<ast::IdentifierExpression>("ary");
ast::ArrayAccessorExpression expr(
std::move(ary), std::make_unique<ast::BinaryExpression>(
ast::BinaryOp::kAdd,
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1)),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2))));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&var);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 10u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3
%5 = OpTypeInt 32 1
%6 = OpConstant %5 1
%7 = OpConstant %5 2
%9 = OpTypePointer Function %4
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%8 = OpIAdd %5 %6 %7
%10 = OpAccessChain %9 %1 %8
)"); )");
} }
@ -110,7 +163,7 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 8u); EXPECT_EQ(b.GenerateAccessorExpression(&expr), 12u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3 %4 = OpTypeVector %5 3
@ -118,16 +171,16 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
%7 = OpConstant %6 4 %7 = OpConstant %6 4
%3 = OpTypeArray %4 %7 %3 = OpTypeArray %4 %7
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%9 = OpTypeInt 32 1 %8 = OpTypeInt 32 1
%10 = OpConstant %9 2 %9 = OpConstant %8 3
%11 = OpConstant %9 3 %10 = OpConstant %8 2
%12 = OpTypePointer Function %5 %11 = OpTypePointer Function %5
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%8 = OpAccessChain %12 %1 %11 %10 R"(%12 = OpAccessChain %11 %1 %9 %10
)"); )");
} }
@ -169,20 +222,20 @@ TEST_F(BuilderTest, MemberAccessor) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 5u); EXPECT_EQ(b.GenerateAccessorExpression(&expr), 8u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeStruct %4 %4 %3 = OpTypeStruct %4 %4
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%6 = OpTypeInt 32 0 %5 = OpTypeInt 32 0
%7 = OpConstant %6 1 %6 = OpConstant %5 1
%8 = OpTypePointer Function %4 %7 = OpTypePointer Function %4
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%5 = OpAccessChain %8 %1 %7 R"(%8 = OpAccessChain %7 %1 %6
)"); )");
} }
@ -234,21 +287,21 @@ TEST_F(BuilderTest, MemberAccessor_Nested) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 6u); EXPECT_EQ(b.GenerateAccessorExpression(&expr), 9u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeStruct %5 %5 %4 = OpTypeStruct %5 %5
%3 = OpTypeStruct %4 %3 = OpTypeStruct %4
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%7 = OpTypeInt 32 0 %6 = OpTypeInt 32 0
%8 = OpConstant %7 0 %7 = OpConstant %6 0
%9 = OpTypePointer Function %5 %8 = OpTypePointer Function %5
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpAccessChain %9 %1 %8 %8 R"(%9 = OpAccessChain %8 %1 %7 %7
)"); )");
} }
@ -303,21 +356,21 @@ TEST_F(BuilderTest, MemberAccessor_Nested_WithAlias) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 6u); EXPECT_EQ(b.GenerateAccessorExpression(&expr), 9u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeStruct %5 %5 %4 = OpTypeStruct %5 %5
%3 = OpTypeStruct %4 %3 = OpTypeStruct %4
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%7 = OpTypeInt 32 0 %6 = OpTypeInt 32 0
%8 = OpConstant %7 0 %7 = OpConstant %6 0
%9 = OpTypePointer Function %5 %8 = OpTypePointer Function %5
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpAccessChain %9 %1 %8 %8 R"(%9 = OpAccessChain %8 %1 %7 %7
)"); )");
} }
@ -381,17 +434,17 @@ TEST_F(BuilderTest, MemberAccessor_Nested_Assignment_LHS) {
%4 = OpTypeStruct %5 %5 %4 = OpTypeStruct %5 %5
%3 = OpTypeStruct %4 %3 = OpTypeStruct %4
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%7 = OpTypeInt 32 0 %6 = OpTypeInt 32 0
%8 = OpConstant %7 0 %7 = OpConstant %6 0
%9 = OpTypePointer Function %5 %8 = OpTypePointer Function %5
%10 = OpConstant %5 2 %10 = OpConstant %5 2
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpAccessChain %9 %1 %8 %8 R"(%9 = OpAccessChain %8 %1 %7 %7
OpStore %6 %10 OpStore %9 %10
)"); )");
} }
@ -458,16 +511,16 @@ TEST_F(BuilderTest, MemberAccessor_Nested_Assignment_RHS) {
%3 = OpTypeStruct %4 %3 = OpTypeStruct %4
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%7 = OpTypePointer Function %5 %7 = OpTypePointer Function %5
%9 = OpTypeInt 32 0 %8 = OpTypeInt 32 0
%10 = OpConstant %9 0 %9 = OpConstant %8 0
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
%6 = OpVariable %7 Function %6 = OpVariable %7 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%8 = OpAccessChain %7 %1 %10 %10 R"(%10 = OpAccessChain %7 %1 %9 %9
%11 = OpLoad %5 %8 %11 = OpLoad %5 %10
OpStore %6 %11 OpStore %6 %11
)"); )");
} }
@ -476,6 +529,8 @@ TEST_F(BuilderTest, MemberAccessor_Swizzle_Single) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3); ast::type::VectorType vec3(&f32, 3);
// ident.y
ast::Variable var("ident", ast::StorageClass::kFunction, &vec3); ast::Variable var("ident", ast::StorageClass::kFunction, &vec3);
ast::MemberAccessorExpression expr( ast::MemberAccessorExpression expr(
@ -492,41 +547,316 @@ TEST_F(BuilderTest, MemberAccessor_Swizzle_Single) {
b.push_function(Function{}); b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 5u); EXPECT_EQ(b.GenerateAccessorExpression(&expr), 8u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3 %3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3 %2 = OpTypePointer Function %3
%6 = OpTypeInt 32 0 %5 = OpTypeInt 32 0
%7 = OpConstant %6 1 %6 = OpConstant %5 1
%8 = OpTypePointer Function %4 %7 = OpTypePointer Function %4
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function R"(%1 = OpVariable %2 Function
)"); )");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%5 = OpAccessChain %8 %1 %7 R"(%8 = OpAccessChain %7 %1 %6
)"); )");
} }
TEST_F(BuilderTest, DISABLED_MemberAccessor_Swizzle_MultipleNames) { TEST_F(BuilderTest, MemberAccessor_Swizzle_MultipleNames) {
// vec.yx ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
// ident.yx
ast::Variable var("ident", ast::StorageClass::kFunction, &vec3);
ast::MemberAccessorExpression expr(
std::make_unique<ast::IdentifierExpression>("ident"),
std::make_unique<ast::IdentifierExpression>("yx"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&var);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 7u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3
%5 = OpTypeVector %4 2
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpLoad %3 %1
%7 = OpVectorShuffle %5 %6 %6 1 0
)");
} }
TEST_F(BuilderTest, DISABLED_Accessor_Mixed_ArrayAndMember) { TEST_F(BuilderTest, MemberAccessor_Swizzle_of_Swizzle) {
// a[0].foo[2].bar.baz.yx ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
// ident.yxz.xz
ast::Variable var("ident", ast::StorageClass::kFunction, &vec3);
ast::MemberAccessorExpression expr(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("ident"),
std::make_unique<ast::IdentifierExpression>("yxz")),
std::make_unique<ast::IdentifierExpression>("xz"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&var);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 8u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3
%7 = OpTypeVector %4 2
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%5 = OpLoad %3 %1
%6 = OpVectorShuffle %3 %5 %5 1 0 2
%8 = OpVectorShuffle %7 %6 %6 0 2
)");
} }
TEST_F(BuilderTest, DISABLED_MemberAccessor_Swizzle_of_Swizzle) { TEST_F(BuilderTest, MemberAccessor_Member_of_Swizzle) {
// vec.yxz.xz ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
// ident.yxz.x
ast::Variable var("ident", ast::StorageClass::kFunction, &vec3);
ast::MemberAccessorExpression expr(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("ident"),
std::make_unique<ast::IdentifierExpression>("yxz")),
std::make_unique<ast::IdentifierExpression>("x"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&var);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 7u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%5 = OpLoad %3 %1
%6 = OpVectorShuffle %3 %5 %5 1 0 2
%7 = OpCompositeExtract %4 %6 0
)");
} }
TEST_F(BuilderTest, DISABLED_MemberAccessor_Member_of_Swizzle) { TEST_F(BuilderTest, MemberAccessor_Array_of_Swizzle) {
// vec.yxz.x ast::type::I32Type i32;
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
// index.yxz[1]
ast::Variable var("ident", ast::StorageClass::kFunction, &vec3);
ast::ArrayAccessorExpression expr(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("ident"),
std::make_unique<ast::IdentifierExpression>("yxz")),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1)));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&var);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 9u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
%3 = OpTypeVector %4 3
%2 = OpTypePointer Function %3
%7 = OpTypeInt 32 1
%8 = OpConstant %7 1
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%5 = OpLoad %3 %1
%6 = OpVectorShuffle %3 %5 %5 1 0 2
%9 = OpVectorExtractDynamic %4 %6 %8
)");
} }
TEST_F(BuilderTest, DISABLED_MemberAccessor_Array_of_Swizzle) { TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) {
// vec.yxz[1] ast::type::I32Type i32;
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
// type C = struct {
// baz : vec3<f32>
// }
// type B = struct {
// bar : C;
// }
// type A = struct {
// foo : array<B, 3>
// }
// var index : array<A, 2>
// index[0].foo[2].bar.baz.yx
ast::StructMemberDecorationList decos;
ast::StructMemberList members;
members.push_back(
std::make_unique<ast::StructMember>("baz", &vec3, std::move(decos)));
auto s = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(members));
ast::type::StructType c_type(std::move(s));
c_type.set_name("C");
members.push_back(
std::make_unique<ast::StructMember>("bar", &c_type, std::move(decos)));
s = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(members));
ast::type::StructType b_type(std::move(s));
b_type.set_name("B");
ast::type::ArrayType b_ary_type(&b_type, 3);
members.push_back(std::make_unique<ast::StructMember>("foo", &b_ary_type,
std::move(decos)));
s = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(members));
ast::type::StructType a_type(std::move(s));
a_type.set_name("A");
ast::type::ArrayType a_ary_type(&a_type, 2);
ast::Variable var("index", ast::StorageClass::kFunction, &a_ary_type);
ast::MemberAccessorExpression expr(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::ArrayAccessorExpression>(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::ArrayAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("index"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 0))),
std::make_unique<ast::IdentifierExpression>("foo")),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2))),
std::make_unique<ast::IdentifierExpression>("bar")),
std::make_unique<ast::IdentifierExpression>("baz")),
std::make_unique<ast::IdentifierExpression>("yx"));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&var);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b(&mod);
b.push_function(Function{});
ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(&expr), 18u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%9 = OpTypeFloat 32
%8 = OpTypeVector %9 3
%7 = OpTypeStruct %8
%6 = OpTypeStruct %7
%10 = OpTypeInt 32 0
%11 = OpConstant %10 3
%5 = OpTypeArray %6 %11
%4 = OpTypeStruct %5
%12 = OpConstant %10 2
%3 = OpTypeArray %4 %12
%2 = OpTypePointer Function %3
%13 = OpTypeInt 32 1
%14 = OpConstant %13 0
%15 = OpTypePointer Function %8
%17 = OpTypeVector %9 2
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%16 = OpAccessChain %15 %1 %14 %14 %12 %14 %14
%18 = OpVectorShuffle %17 %16 %16 1 0
)");
}
TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer) {
// const a : array<f32, 3>;
// a[2]
//
// This has to generate an OpConstantExtract and will need to read the 3 value
// out of the ScalarConstructor as extract requires integer indices.
}
TEST_F(BuilderTest, DISABLED_Accessor_Struct_NonPointer) {
// type A = struct {
// a : f32;
// b : f32;
// };
// const b : A;
// b.b
//
// This needs to do an OpCompositeExtract on the struct.
}
TEST_F(BuilderTest, DISABLED_Accessor_NonPointer_Multi) {
// type A = struct {
// a : f32;
// b : vec3<f32, 3>;
// };
// type B = struct {
// c : A;
// }
// const b : array<B, 3>;
// b[2].c.b.yx.x
//
// This needs to do an OpCompositeExtract similar to the AccessChain case
} }
} // namespace } // namespace