diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index b2a931e06e..494f78faed 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc @@ -150,6 +150,13 @@ uint32_t GeneratorImplIr::Constant(const constant::Value* constant) { } module_.PushType(spv::Op::OpConstantComposite, operands); }, + [&](const type::Matrix* mat) { + OperandList operands = {Type(ty), id}; + for (uint32_t i = 0; i < mat->columns(); i++) { + operands.push_back(Constant(constant->Index(i))); + } + module_.PushType(spv::Op::OpConstantComposite, operands); + }, [&](Default) { TINT_ICE(Writer, diagnostics_) << "unhandled constant type: " << ty->FriendlyName(); }); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc index 8c7778f2b9..7881c0a16e 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc @@ -129,6 +129,69 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2h) { )"); } +TEST_F(SpvGeneratorImplTest, Constant_Mat2x3f) { + auto* f32 = mod.types.f32(); + auto* v = b.create( + mod.types.mat2x3(f32), + utils::Vector{ + b.create(mod.types.vec3(f32), + utils::Vector{b.F32(42), b.F32(-1), b.F32(0.25)}, false, + false), + b.create(mod.types.vec3(f32), + utils::Vector{b.F32(-42), b.F32(0), b.F32(-0.25)}, false, + true), + }, + false, false); + generator_.Constant(b.Constant(v)); + EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypeMatrix %3 2 +%6 = OpConstant %4 42 +%7 = OpConstant %4 -1 +%8 = OpConstant %4 0.25 +%5 = OpConstantComposite %3 %6 %7 %8 +%10 = OpConstant %4 -42 +%11 = OpConstant %4 0 +%12 = OpConstant %4 -0.25 +%9 = OpConstantComposite %3 %10 %11 %12 +%1 = OpConstantComposite %2 %5 %9 +)"); +} + +TEST_F(SpvGeneratorImplTest, Constant_Mat4x2h) { + auto* f16 = mod.types.f16(); + auto* v = b.create( + mod.types.mat4x2(f16), + utils::Vector{ + b.create(mod.types.vec2(f16), utils::Vector{b.F16(42), b.F16(-1)}, + false, false), + b.create(mod.types.vec2(f16), utils::Vector{b.F16(0), b.F16(0.25)}, + false, true), + b.create(mod.types.vec2(f16), utils::Vector{b.F16(-42), b.F16(1)}, + false, false), + b.create(mod.types.vec2(f16), utils::Vector{b.F16(0.5), b.F16(-0)}, + false, true), + }, + false, false); + generator_.Constant(b.Constant(v)); + EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 16 +%3 = OpTypeVector %4 2 +%2 = OpTypeMatrix %3 4 +%6 = OpConstant %4 0x1.5p+5 +%7 = OpConstant %4 -0x1p+0 +%5 = OpConstantComposite %3 %6 %7 +%9 = OpConstant %4 0x0p+0 +%10 = OpConstant %4 0x1p-2 +%8 = OpConstantComposite %3 %9 %10 +%12 = OpConstant %4 -0x1.5p+5 +%13 = OpConstant %4 0x1p+0 +%11 = OpConstantComposite %3 %12 %13 +%15 = OpConstant %4 0x1p-1 +%14 = OpConstantComposite %3 %15 %9 +%1 = OpConstantComposite %2 %5 %8 %11 %14 +)"); +} + // Test that we do not emit the same constant more than once. TEST_F(SpvGeneratorImplTest, Constant_Deduplicate) { generator_.Constant(b.Constant(i32(42)));