From 007dc42cbb433d894be4e9595a2ba19a7e61752d Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Thu, 8 Oct 2020 17:01:55 +0000 Subject: [PATCH] [type-determiner][spirv-writer] Add arrayLength support This CL adds support for retrieving the array length of a Runtime Array in the SPIR-V backend. Bug: tint:252 Change-Id: Ic13c4a99da5760738d57702c45f52c6a194a172d Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/29220 Commit-Queue: David Neto Reviewed-by: David Neto --- src/ast/intrinsic.cc | 3 + src/ast/intrinsic.h | 1 + src/type_determiner.cc | 8 + src/type_determiner_test.cc | 1 + src/writer/spirv/builder.cc | 31 ++++ src/writer/spirv/builder_intrinsic_test.cc | 181 +++++++++++++++++++++ 6 files changed, 225 insertions(+) diff --git a/src/ast/intrinsic.cc b/src/ast/intrinsic.cc index fd274ea8cf..ce3d692215 100644 --- a/src/ast/intrinsic.cc +++ b/src/ast/intrinsic.cc @@ -31,6 +31,9 @@ std::ostream& operator<<(std::ostream& out, Intrinsic i) { case Intrinsic::kAny: out << "any"; break; + case Intrinsic::kArrayLength: + out << "arrayLength"; + break; case Intrinsic::kAsin: out << "asin"; break; diff --git a/src/ast/intrinsic.h b/src/ast/intrinsic.h index c66f749db3..60fc4acae3 100644 --- a/src/ast/intrinsic.h +++ b/src/ast/intrinsic.h @@ -28,6 +28,7 @@ enum class Intrinsic { kAcos, kAll, kAny, + kArrayLength, kAsin, kAtan, kAtan2, diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 5a412a4d1d..2500468571 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -520,6 +520,11 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident, ctx_.type_mgr().Get(std::make_unique())); return true; } + if (ident->intrinsic() == ast::Intrinsic::kArrayLength) { + expr->func()->set_result_type( + ctx_.type_mgr().Get(std::make_unique())); + return true; + } if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) { if (expr->params().size() != 1) { set_error(expr->source(), @@ -638,6 +643,7 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident, } } if (data == nullptr) { + error_ = "unable to find intrinsic " + ident->name(); return false; } @@ -788,6 +794,8 @@ void TypeDeterminer::SetIntrinsicIfNeeded(ast::IdentifierExpression* ident) { ident->set_intrinsic(ast::Intrinsic::kAll); } else if (ident->name() == "any") { ident->set_intrinsic(ast::Intrinsic::kAny); + } else if (ident->name() == "arrayLength") { + ident->set_intrinsic(ast::Intrinsic::kArrayLength); } else if (ident->name() == "asin") { ident->set_intrinsic(ast::Intrinsic::kAsin); } else if (ident->name() == "atan") { diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index d60b376f2f..74adef954b 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -2494,6 +2494,7 @@ INSTANTIATE_TEST_SUITE_P( IntrinsicData{"acos", ast::Intrinsic::kAcos}, IntrinsicData{"all", ast::Intrinsic::kAll}, IntrinsicData{"any", ast::Intrinsic::kAny}, + IntrinsicData{"arrayLength", ast::Intrinsic::kArrayLength}, IntrinsicData{"asin", ast::Intrinsic::kAsin}, IntrinsicData{"atan", ast::Intrinsic::kAtan}, IntrinsicData{"atan2", ast::Intrinsic::kAtan2}, diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index b37fdd7cba..45b8e34ecd 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1670,6 +1670,37 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident, op = spv::Op::OpAny; } else if (intrinsic == ast::Intrinsic::kAll) { op = spv::Op::OpAll; + } else if (intrinsic == ast::Intrinsic::kArrayLength) { + if (call->params().empty()) { + error_ = "missing param for runtime array length"; + return 0; + } else if (!call->params()[0]->IsMemberAccessor()) { + if (call->params()[0]->result_type()->IsPointer()) { + error_ = "pointer accessors not supported yet"; + } else { + error_ = "invalid accessor for runtime array length"; + } + return 0; + } + auto* accessor = call->params()[0]->AsMemberAccessor(); + auto struct_id = GenerateExpression(accessor->structure()); + if (struct_id == 0) { + return 0; + } + params.push_back(Operand::Int(struct_id)); + + auto* type = accessor->structure()->result_type()->UnwrapAliasPtrAlias(); + if (!type->IsStruct()) { + error_ = + "invalid type (" + type->type_name() + ") for runtime array length"; + return 0; + } + // Runtime array must be the last member in the structure + params.push_back( + Operand::Int(uint32_t(type->AsStruct()->impl()->members().size() - 1))); + + push_function_inst(spv::Op::OpArrayLength, params); + return result_id; } else if (intrinsic == ast::Intrinsic::kCountOneBits) { op = spv::Op::OpBitCount; } else if (intrinsic == ast::Intrinsic::kDot) { diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc index f2d374b700..a16c21ea89 100644 --- a/src/writer/spirv/builder_intrinsic_test.cc +++ b/src/writer/spirv/builder_intrinsic_test.cc @@ -18,16 +18,22 @@ #include "src/ast/call_expression.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/sint_literal.h" +#include "src/ast/struct.h" +#include "src/ast/struct_member.h" +#include "src/ast/type/array_type.h" #include "src/ast/type/bool_type.h" #include "src/ast/type/depth_texture_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/matrix_type.h" #include "src/ast/type/multisampled_texture_type.h" +#include "src/ast/type/pointer_type.h" #include "src/ast/type/sampled_texture_type.h" #include "src/ast/type/sampler_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type/void_type.h" @@ -2716,6 +2722,181 @@ OpFunctionEnd )"); } +TEST_F(BuilderTest, Call_ArrayLength) { + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::type::ArrayType ary(&f32); + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &ary, std::move(decos))); + + auto s = std::make_unique(ast::StructDecoration::kNone, + std::move(members)); + ast::type::StructType s_type(std::move(s)); + s_type.set_name("my_struct"); + + auto var = std::make_unique("b", ast::StorageClass::kPrivate, + &s_type); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique("b"), + std::make_unique("a"))); + + ast::CallExpression expr( + std::make_unique("arrayLength"), + std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &void_type); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&expr), 11u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), + R"(%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%9 = OpTypeFloat 32 +%8 = OpTypeRuntimeArray %9 +%7 = OpTypeStruct %8 +%6 = OpTypePointer Private %7 +%10 = OpConstantNull %7 +%5 = OpVariable %6 Private %10 +%12 = OpTypeInt 32 0 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%11 = OpArrayLength %12 %5 0 +)"); +} + +TEST_F(BuilderTest, Call_ArrayLength_OtherMembersInStruct) { + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::type::ArrayType ary(&f32); + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("z", &f32, std::move(decos))); + members.push_back( + std::make_unique("a", &ary, std::move(decos))); + + auto s = std::make_unique(ast::StructDecoration::kNone, + std::move(members)); + ast::type::StructType s_type(std::move(s)); + s_type.set_name("my_struct"); + + auto var = std::make_unique("b", ast::StorageClass::kPrivate, + &s_type); + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique("b"), + std::make_unique("a"))); + + ast::CallExpression expr( + std::make_unique("arrayLength"), + std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &void_type); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&expr), 11u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), + R"(%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%8 = OpTypeFloat 32 +%9 = OpTypeRuntimeArray %8 +%7 = OpTypeStruct %8 %9 +%6 = OpTypePointer Private %7 +%10 = OpConstantNull %7 +%5 = OpVariable %6 Private %10 +%12 = OpTypeInt 32 0 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%11 = OpArrayLength %12 %5 1 +)"); +} + +// TODO(dsinclair): https://bugs.chromium.org/p/tint/issues/detail?id=266 +TEST_F(BuilderTest, DISABLED_Call_ArrayLength_Ptr) { + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::type::ArrayType ary(&f32); + ast::type::PointerType ptr(&ary, ast::StorageClass::kStorageBuffer); + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("z", &f32, std::move(decos))); + members.push_back( + std::make_unique("a", &ary, std::move(decos))); + + auto s = std::make_unique(ast::StructDecoration::kNone, + std::move(members)); + ast::type::StructType s_type(std::move(s)); + s_type.set_name("my_struct"); + + auto var = std::make_unique("b", ast::StorageClass::kPrivate, + &s_type); + + auto ptr_var = std::make_unique( + "ptr_var", ast::StorageClass::kPrivate, &ptr); + ptr_var->set_constructor(std::make_unique( + std::make_unique("b"), + std::make_unique("a"))); + + ast::ExpressionList params; + params.push_back(std::make_unique("ptr_var")); + + ast::CallExpression expr( + std::make_unique("arrayLength"), + std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + td.RegisterVariableForTesting(ptr_var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &void_type); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateExpression(&expr), 11u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"( ... )"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%11 = OpArrayLength %12 %5 1 +)"); +} + } // namespace } // namespace spirv } // namespace writer