transform::VarForDynamicIndex: Operate on matrices

Much like arrays, the SPIR-V writer cannot cope with dynamic indexing of matrices.

Fixed: tint:825
Change-Id: Ia111f15e0cf6fbd441861a4b3455a33b82b692ab
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51781
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-05-20 18:16:07 +00:00 committed by Tint LUCI CQ
parent 6c582778cf
commit 351ac4a009
11 changed files with 186 additions and 52 deletions

View File

@ -37,7 +37,7 @@ Output VarForDynamicIndex::Run(const Program* in, const DataMap&) {
if (auto* access_expr = node->As<ast::ArrayAccessorExpression>()) { if (auto* access_expr = node->As<ast::ArrayAccessorExpression>()) {
// Found an array accessor expression // Found an array accessor expression
auto* index_expr = access_expr->idx_expr(); auto* index_expr = access_expr->idx_expr();
auto* array_expr = access_expr->array(); auto* indexed_expr = access_expr->array();
if (index_expr->Is<ast::ScalarConstructorExpression>()) { if (index_expr->Is<ast::ScalarConstructorExpression>()) {
// Index expression is a literal value. As this isn't a dynamic index, // Index expression is a literal value. As this isn't a dynamic index,
@ -45,20 +45,20 @@ Output VarForDynamicIndex::Run(const Program* in, const DataMap&) {
continue; continue;
} }
auto* array = ctx.src->Sem().Get(array_expr); auto* indexed = ctx.src->Sem().Get(indexed_expr);
if (!array->Type()->Is<sem::Array>()) { if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
// This transform currently only cares about arrays. // This transform currently only cares about array and matrices.
continue; continue;
} }
auto* stmt = array->Stmt(); // Statement that owns the expression auto* stmt = indexed->Stmt(); // Statement that owns the expression
auto* block = stmt->Block(); // Block that owns the statement auto* block = stmt->Block(); // Block that owns the statement
// Construct a `var` declaration to hold the value in memory. // Construct a `var` declaration to hold the value in memory.
auto* ty = CreateASTTypeFor(&ctx, array->Type()); auto* ty = CreateASTTypeFor(&ctx, indexed->Type());
auto var_name = ctx.dst->Symbols().New("var_for_array"); auto var_name = ctx.dst->Symbols().New("var_for_index");
auto* var_decl = ctx.dst->Decl(ctx.dst->Var( auto* var_decl = ctx.dst->Decl(ctx.dst->Var(
var_name, ty, ast::StorageClass::kNone, ctx.Clone(array_expr))); var_name, ty, ast::StorageClass::kNone, ctx.Clone(indexed_expr)));
// Insert the `var` declaration before the statement that performs the // Insert the `var` declaration before the statement that performs the
// indexing. Note that for indexing chains, AST node ordering guarantees // indexing. Note that for indexing chains, AST node ordering guarantees
@ -67,7 +67,7 @@ Output VarForDynamicIndex::Run(const Program* in, const DataMap&) {
var_decl); var_decl);
// Replace the original index expression with the new `var`. // Replace the original index expression with the new `var`.
ctx.Replace(array_expr, ctx.dst->Expr(var_name)); ctx.Replace(indexed_expr, ctx.dst->Expr(var_name));
} }
} }

View File

@ -23,10 +23,10 @@
namespace tint { namespace tint {
namespace transform { namespace transform {
/// A transform that extracts array values that are dynamically indexed to a /// A transform that extracts array and matrix values that are dynamically
/// temporary `var` local before performing the index. This transform is used by /// indexed to a temporary `var` local before performing the index. This
/// the SPIR-V writer for dynamically indexing arrays, as there is no SPIR-V /// transform is used by the SPIR-V writer as there is no SPIR-V instruction
/// instruction that can dynamically index a non-pointer composite. /// that can dynamically index a non-pointer composite.
class VarForDynamicIndex : public Transform { class VarForDynamicIndex : public Transform {
public: public:
/// Constructor /// Constructor

View File

@ -44,8 +44,8 @@ fn f() {
fn f() { fn f() {
var i : i32; var i : i32;
let p : array<i32, 4> = array<i32, 4>(1, 2, 3, 4); let p : array<i32, 4> = array<i32, 4>(1, 2, 3, 4);
var var_for_array : array<i32, 4> = p; var var_for_index : array<i32, 4> = p;
let x : i32 = var_for_array[i]; let x : i32 = var_for_index[i];
} }
)"; )";
@ -67,7 +67,7 @@ fn f() {
// TODO(bclayton): Optimize this case: // TODO(bclayton): Optimize this case:
// This output is not as efficient as it could be. // This output is not as efficient as it could be.
// We only actually need to hoist the inner-most array to a `var` // We only actually need to hoist the inner-most array to a `var`
// (`var_for_array`), as later indexing operations will be working with // (`var_for_index`), as later indexing operations will be working with
// references, not values. // references, not values.
auto* expect = R"( auto* expect = R"(
@ -75,9 +75,9 @@ fn f() {
var i : i32; var i : i32;
var j : i32; var j : i32;
let p : array<array<i32, 2>, 2> = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4)); let p : array<array<i32, 2>, 2> = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
var var_for_array : array<array<i32, 2>, 2> = p; var var_for_index : array<array<i32, 2>, 2> = p;
var var_for_array_1 : array<i32, 2> = var_for_array[i]; var var_for_index_1 : array<i32, 2> = var_for_index[i];
let x : i32 = var_for_array_1[j]; let x : i32 = var_for_index_1[j];
} }
)"; )";

View File

@ -22,9 +22,9 @@ namespace {
using BuilderTest = TestHelper; using BuilderTest = TestHelper;
TEST_F(BuilderTest, ArrayAccessor) { TEST_F(BuilderTest, ArrayAccessor_VectorRef_Literal) {
// vec3<f32> ary; // var ary : vec3<f32>;
// ary[1] -> ptr<f32> // ary[1] -> ref<f32>
auto* var = Var("ary", ty.vec3<f32>()); auto* var = Var("ary", ty.vec3<f32>());
@ -57,10 +57,10 @@ TEST_F(BuilderTest, ArrayAccessor) {
)"); )");
} }
TEST_F(BuilderTest, Accessor_Array_LoadIndex) { TEST_F(BuilderTest, ArrayAccessor_VectorRef_Dynamic) {
// ary : vec3<f32>; // var ary : vec3<f32>;
// idx : i32; // var idx : i32;
// ary[idx] -> ptr<f32> // ary[idx] -> ref<f32>
auto* var = Var("ary", ty.vec3<f32>()); auto* var = Var("ary", ty.vec3<f32>());
auto* idx = Var("idx", ty.i32()); auto* idx = Var("idx", ty.i32());
@ -98,9 +98,9 @@ TEST_F(BuilderTest, Accessor_Array_LoadIndex) {
)"); )");
} }
TEST_F(BuilderTest, ArrayAccessor_Dynamic) { TEST_F(BuilderTest, ArrayAccessor_VectorRef_Dynamic2) {
// vec3<f32> ary; // var ary : vec3<f32>;
// ary[1 + 2] -> ptr<f32> // ary[1 + 2] -> ref<f32>
auto* var = Var("ary", ty.vec3<f32>()); auto* var = Var("ary", ty.vec3<f32>());
@ -134,10 +134,10 @@ TEST_F(BuilderTest, ArrayAccessor_Dynamic) {
)"); )");
} }
TEST_F(BuilderTest, ArrayAccessor_MultiLevel) { TEST_F(BuilderTest, ArrayAccessor_ArrayRef_MultiLevel) {
auto* ary4 = ty.array(ty.vec3<f32>(), 4); auto* ary4 = ty.array(ty.vec3<f32>(), 4);
// ary = array<vec3<f32>, 4> // var ary : array<vec3<f32>, 4>
// ary[3][2]; // ary[3][2];
auto* var = Var("ary", ary4); auto* var = Var("ary", ary4);
@ -172,7 +172,7 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
)"); )");
} }
TEST_F(BuilderTest, Accessor_ArrayWithSwizzle) { TEST_F(BuilderTest, ArrayAccessor_ArrayRef_ArrayWithSwizzle) {
auto* ary4 = ty.array(ty.vec3<f32>(), 4); auto* ary4 = ty.array(ty.vec3<f32>(), 4);
// var a : array<vec3<f32>, 4>; // var a : array<vec3<f32>, 4>;
@ -680,7 +680,7 @@ TEST_F(BuilderTest, MemberAccessor_Array_of_Swizzle) {
)"); )");
} }
TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) { TEST_F(BuilderTest, ArrayAccessor_Mixed_ArrayAndMember) {
// type C = struct { // type C = struct {
// baz : vec3<f32> // baz : vec3<f32>
// } // }
@ -747,7 +747,7 @@ TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) {
)"); )");
} }
TEST_F(BuilderTest, Accessor_Array_Of_Vec) { TEST_F(BuilderTest, ArrayAccessor_Of_Vec) {
// let pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>( // let pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
// vec2<f32>(0.0, 0.5), // vec2<f32>(0.0, 0.5),
// vec2<f32>(-0.5, -0.5), // vec2<f32>(-0.5, -0.5),
@ -790,7 +790,7 @@ TEST_F(BuilderTest, Accessor_Array_Of_Vec) {
Validate(b); Validate(b);
} }
TEST_F(BuilderTest, Accessor_Array_Of_Array_Of_f32) { TEST_F(BuilderTest, ArrayAccessor_Of_Array_Of_f32) {
// let pos : array<array<f32, 2>, 3> = array<vec2<f32, 2>, 3>( // let pos : array<array<f32, 2>, 3> = array<vec2<f32, 2>, 3>(
// array<f32, 2>(0.0, 0.5), // array<f32, 2>(0.0, 0.5),
// array<f32, 2>(-0.5, -0.5), // array<f32, 2>(-0.5, -0.5),
@ -835,7 +835,7 @@ TEST_F(BuilderTest, Accessor_Array_Of_Array_Of_f32) {
Validate(b); Validate(b);
} }
TEST_F(BuilderTest, Accessor_Const_Vec) { TEST_F(BuilderTest, ArrayAccessor_Vec_Literal) {
// let pos : vec2<f32> = vec2<f32>(0.0, 0.5); // let pos : vec2<f32> = vec2<f32>(0.0, 0.5);
// pos[1] // pos[1]
@ -864,7 +864,7 @@ TEST_F(BuilderTest, Accessor_Const_Vec) {
)"); )");
} }
TEST_F(BuilderTest, Accessor_Const_Vec_Dynamic) { TEST_F(BuilderTest, ArrayAccessor_Vec_Dynamic) {
// let pos : vec2<f32> = vec2<f32>(0.0, 0.5); // let pos : vec2<f32> = vec2<f32>(0.0, 0.5);
// idx : i32 // idx : i32
// pos[idx] // pos[idx]
@ -900,7 +900,7 @@ TEST_F(BuilderTest, Accessor_Const_Vec_Dynamic) {
)"); )");
} }
TEST_F(BuilderTest, Accessor_Array_NonPointer) { TEST_F(BuilderTest, ArrayAccessor_Array_Literal) {
// let a : array<f32, 3>; // let a : array<f32, 3>;
// a[2] // a[2]
@ -934,7 +934,7 @@ TEST_F(BuilderTest, Accessor_Array_NonPointer) {
Validate(b); Validate(b);
} }
TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) { TEST_F(BuilderTest, ArrayAccessor_Array_Dynamic) {
// let a : array<f32, 3>; // let a : array<f32, 3>;
// idx : i32 // idx : i32
// a[idx] // a[idx]
@ -982,6 +982,58 @@ TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) {
Validate(b); Validate(b);
} }
TEST_F(BuilderTest, ArrayAccessor_Matrix_Dynamic) {
// let a : mat2x2<f32>(vec2<f32>(1., 2.), vec2<f32>(3., 4.));
// idx : i32
// a[idx]
auto* var =
Const("a", ty.mat2x2<f32>(),
Construct(ty.mat2x2<f32>(), Construct(ty.vec2<f32>(), 1.f, 2.f),
Construct(ty.vec2<f32>(), 3.f, 4.f)));
auto* idx = Var("idx", ty.i32());
auto* expr = IndexAccessor("a", idx);
WrapInFunction(var, idx, expr);
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build());
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeFloat 32
%6 = OpTypeVector %7 2
%5 = OpTypeMatrix %6 2
%8 = OpConstant %7 1
%9 = OpConstant %7 2
%10 = OpConstantComposite %6 %8 %9
%11 = OpConstant %7 3
%12 = OpConstant %7 4
%13 = OpConstantComposite %6 %11 %12
%14 = OpConstantComposite %5 %10 %13
%17 = OpTypeInt 32 1
%16 = OpTypePointer Function %17
%18 = OpConstantNull %17
%20 = OpTypePointer Function %5
%21 = OpConstantNull %5
%23 = OpTypePointer Function %6
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%15 = OpVariable %16 Function %18
%19 = OpVariable %20 Function %21
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpStore %19 %14
%22 = OpLoad %17 %15
%24 = OpAccessChain %23 %19 %22
%25 = OpLoad %6 %24
)");
Validate(b);
}
} // namespace } // namespace
} // namespace spirv } // namespace spirv
} // namespace writer } // namespace writer

View File

@ -17,9 +17,9 @@
OpName %tint_symbol_5 "tint_symbol_5" OpName %tint_symbol_5 "tint_symbol_5"
OpName %tint_symbol_2 "tint_symbol_2" OpName %tint_symbol_2 "tint_symbol_2"
OpName %main "main" OpName %main "main"
OpName %var_for_array "var_for_array" OpName %var_for_index "var_for_index"
OpName %output "output" OpName %output "output"
OpName %var_for_array_1 "var_for_array_1" OpName %var_for_index_1 "var_for_index_1"
OpDecorate %tint_pointsize BuiltIn PointSize OpDecorate %tint_pointsize BuiltIn PointSize
OpDecorate %tint_symbol BuiltIn VertexIndex OpDecorate %tint_symbol BuiltIn VertexIndex
OpDecorate %tint_symbol_1 BuiltIn InstanceIndex OpDecorate %tint_symbol_1 BuiltIn InstanceIndex
@ -88,21 +88,21 @@
OpFunctionEnd OpFunctionEnd
%main = OpFunction %void None %22 %main = OpFunction %void None %22
%24 = OpLabel %24 = OpLabel
%var_for_array = OpVariable %_ptr_Function__arr_v2float_uint_4 Function %40 %var_for_index = OpVariable %_ptr_Function__arr_v2float_uint_4 Function %40
%output = OpVariable %_ptr_Function_Output Function %48 %output = OpVariable %_ptr_Function_Output Function %48
%var_for_array_1 = OpVariable %_ptr_Function__arr_v4float_uint_4 Function %62 %var_for_index_1 = OpVariable %_ptr_Function__arr_v4float_uint_4 Function %62
OpStore %tint_pointsize %float_1 OpStore %tint_pointsize %float_1
OpStore %var_for_array %37 OpStore %var_for_index %37
%41 = OpLoad %uint %tint_symbol_1 %41 = OpLoad %uint %tint_symbol_1
%44 = OpAccessChain %_ptr_Function_float %var_for_array %41 %uint_0 %44 = OpAccessChain %_ptr_Function_float %var_for_index %41 %uint_0
%45 = OpLoad %float %44 %45 = OpLoad %float %44
%50 = OpAccessChain %_ptr_Function_v4float %output %uint_0 %50 = OpAccessChain %_ptr_Function_v4float %output %uint_0
%52 = OpCompositeConstruct %v4float %float_0_5 %float_0_5 %45 %float_1 %52 = OpCompositeConstruct %v4float %float_0_5 %float_0_5 %45 %float_1
OpStore %50 %52 OpStore %50 %52
OpStore %var_for_array_1 %59 OpStore %var_for_index_1 %59
%64 = OpAccessChain %_ptr_Function_v4float %output %uint_1 %64 = OpAccessChain %_ptr_Function_v4float %output %uint_1
%65 = OpLoad %uint %tint_symbol_1 %65 = OpLoad %uint %tint_symbol_1
%66 = OpAccessChain %_ptr_Function_v4float %var_for_array_1 %65 %66 = OpAccessChain %_ptr_Function_v4float %var_for_index_1 %65
%67 = OpLoad %v4float %66 %67 = OpLoad %v4float %66
OpStore %64 %67 OpStore %64 %67
%69 = OpLoad %Output %output %69 = OpLoad %Output %output

6
test/bug/tint/825.wgsl Normal file
View File

@ -0,0 +1,6 @@
fn f() {
var i : i32;
var j : i32;
let m : mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
let f : f32 = m[i][j];
}

View File

@ -0,0 +1,12 @@
void f() {
int i = 0;
int j = 0;
const float2x2 m = float2x2(float2(1.0f, 2.0f), float2(3.0f, 4.0f));
const float f = m[i][j];
}
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}

View File

@ -0,0 +1,10 @@
#include <metal_stdlib>
using namespace metal;
void f() {
int i = 0;
int j = 0;
float2x2 const m = float2x2(float2(1.0f, 2.0f), float2(3.0f, 4.0f));
float const f = m[i][j];
}

View File

@ -0,0 +1,48 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 30
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %unused_entry_point "unused_entry_point"
OpName %f "f"
OpName %i "i"
OpName %j "j"
OpName %var_for_index "var_for_index"
%void = OpTypeVoid
%1 = OpTypeFunction %void
%int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
%10 = OpConstantNull %int
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%mat2v2float = OpTypeMatrix %v2float 2
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%17 = OpConstantComposite %v2float %float_1 %float_2
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%20 = OpConstantComposite %v2float %float_3 %float_4
%21 = OpConstantComposite %mat2v2float %17 %20
%_ptr_Function_mat2v2float = OpTypePointer Function %mat2v2float
%24 = OpConstantNull %mat2v2float
%_ptr_Function_float = OpTypePointer Function %float
%unused_entry_point = OpFunction %void None %1
%4 = OpLabel
OpReturn
OpFunctionEnd
%f = OpFunction %void None %1
%6 = OpLabel
%i = OpVariable %_ptr_Function_int Function %10
%j = OpVariable %_ptr_Function_int Function %10
%var_for_index = OpVariable %_ptr_Function_mat2v2float Function %24
OpStore %var_for_index %21
%25 = OpLoad %int %i
%26 = OpLoad %int %j
%28 = OpAccessChain %_ptr_Function_float %var_for_index %25 %26
%29 = OpLoad %float %28
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,6 @@
fn f() {
var i : i32;
var j : i32;
let m : mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
let f : f32 = m[i][j];
}

View File

@ -16,7 +16,7 @@
OpName %tint_symbol_3 "tint_symbol_3" OpName %tint_symbol_3 "tint_symbol_3"
OpName %tint_symbol_1 "tint_symbol_1" OpName %tint_symbol_1 "tint_symbol_1"
OpName %vtx_main "vtx_main" OpName %vtx_main "vtx_main"
OpName %var_for_array "var_for_array" OpName %var_for_index "var_for_index"
OpName %tint_symbol_6 "tint_symbol_6" OpName %tint_symbol_6 "tint_symbol_6"
OpName %tint_symbol_4 "tint_symbol_4" OpName %tint_symbol_4 "tint_symbol_4"
OpName %frag_main "frag_main" OpName %frag_main "frag_main"
@ -64,11 +64,11 @@
OpFunctionEnd OpFunctionEnd
%vtx_main = OpFunction %void None %29 %vtx_main = OpFunction %void None %29
%31 = OpLabel %31 = OpLabel
%var_for_array = OpVariable %_ptr_Function__arr_v2float_uint_3 Function %35 %var_for_index = OpVariable %_ptr_Function__arr_v2float_uint_3 Function %35
OpStore %tint_pointsize %float_1 OpStore %tint_pointsize %float_1
OpStore %var_for_array %pos OpStore %var_for_index %pos
%37 = OpLoad %int %tint_symbol %37 = OpLoad %int %tint_symbol
%39 = OpAccessChain %_ptr_Function_v2float %var_for_array %37 %39 = OpAccessChain %_ptr_Function_v2float %var_for_index %37
%40 = OpLoad %v2float %39 %40 = OpLoad %v2float %39
%41 = OpCompositeExtract %float %40 0 %41 = OpCompositeExtract %float %40 0
%42 = OpCompositeExtract %float %40 1 %42 = OpCompositeExtract %float %40 1