Move vector and matrix to type/.

This CL moves vector and matrix to type/ and updates the namespaces as
needed.

Bug: tint:1718
Change-Id: I48423b37f15cd69c03ab288143b2d36564789fbf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113423
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair
2022-12-08 22:21:24 +00:00
committed by Dan Sinclair
parent 100d4bf339
commit 0e780da882
63 changed files with 1144 additions and 1129 deletions

View File

@@ -487,7 +487,7 @@ struct BuiltinPolyfill::State {
if (!ty->is_unsigned_integer_scalar_or_vector()) {
expr = b.Construct<i32>(expr);
}
if (ty->Is<sem::Vector>()) {
if (ty->Is<type::Vector>()) {
expr = b.Construct(T(ty), expr);
}
return expr;
@@ -643,7 +643,7 @@ struct BuiltinPolyfill::State {
/// with scalar calls.
/// @param vec the vector type
/// @return the polyfill function name
Symbol quantizeToF16(const sem::Vector* vec) {
Symbol quantizeToF16(const type::Vector* vec) {
auto name = b.Symbols().New("tint_quantizeToF16");
utils::Vector<const ast::Expression*, 4> args;
for (uint32_t i = 0; i < vec->Width(); i++) {
@@ -673,7 +673,7 @@ struct BuiltinPolyfill::State {
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<sem::Vector>()) {
if (rhs_ty->Is<type::Vector>()) {
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
}
auto* lhs = ctx.Clone(bin_op->lhs);
@@ -761,7 +761,7 @@ struct BuiltinPolyfill::State {
/// @returns 1 if `ty` is not a vector, otherwise the vector width
uint32_t WidthOf(const type::Type* ty) const {
if (auto* v = ty->As<sem::Vector>()) {
if (auto* v = ty->As<type::Vector>()) {
return v->Width();
}
return 1;
@@ -905,7 +905,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
break;
case sem::BuiltinType::kQuantizeToF16:
if (polyfill.quantize_to_vec_f16) {
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
if (auto* vec = builtin->ReturnType()->As<type::Vector>()) {
fn = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.quantizeToF16(vec); });
}

View File

@@ -157,7 +157,7 @@ bool IntrinsicDataTypeFor(const type::Type* ty, DecomposeMemoryAccess::Intrinsic
out = DecomposeMemoryAccess::Intrinsic::DataType::kF16;
return true;
}
if (auto* vec = ty->As<sem::Vector>()) {
if (auto* vec = ty->As<type::Vector>()) {
switch (vec->Width()) {
case 2:
if (vec->type()->Is<type::I32>()) {
@@ -529,7 +529,7 @@ struct DecomposeMemoryAccess::State {
});
} else {
utils::Vector<const ast::Expression*, 8> values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
if (auto* mat_ty = el_ty->As<type::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
@@ -629,7 +629,7 @@ struct DecomposeMemoryAccess::State {
return utils::Vector{b.Decl(array), for_loop};
},
[&](const sem::Matrix* mat_ty) {
[&](const type::Matrix* mat_ty) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
utils::Vector<const ast::Statement*, 4> stmts;
@@ -901,7 +901,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
if (swizzle->Indices().Length() == 1) {
if (auto access = state.TakeAccess(accessor->structure)) {
auto* vec_ty = access.type->As<sem::Vector>();
auto* vec_ty = access.type->As<type::Vector>();
auto* offset = state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0u]);
state.AddAccess(accessor, {
access.var,
@@ -937,7 +937,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
});
continue;
}
if (auto* vec_ty = access.type->As<sem::Vector>()) {
if (auto* vec_ty = access.type->As<type::Vector>()) {
auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
state.AddAccess(accessor, {
access.var,
@@ -946,7 +946,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
});
continue;
}
if (auto* mat_ty = access.type->As<sem::Matrix>()) {
if (auto* mat_ty = access.type->As<type::Matrix>()) {
auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
state.AddAccess(accessor, {
access.var,

View File

@@ -35,7 +35,7 @@ struct MatrixInfo {
/// The stride in bytes between columns of the matrix
uint32_t stride = 0;
/// The type of the matrix
const sem::Matrix* matrix = nullptr;
const type::Matrix* matrix = nullptr;
/// @returns a new ast::Array that holds an vector column for each row of the
/// matrix.
@@ -77,7 +77,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
continue;
}
for (auto* member : str_ty->Members()) {
auto* matrix = member->Type()->As<sem::Matrix>();
auto* matrix = member->Type()->As<type::Matrix>();
if (!matrix) {
continue;
}

View File

@@ -86,7 +86,7 @@ struct ExpandCompoundAssignment::State {
// Helper function that returns `true` if the type of `expr` is a vector.
auto is_vec = [&](const ast::Expression* expr) {
return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<type::Vector>();
};
// Hoist the LHS expression subtree into local constants to produce a new

View File

@@ -50,7 +50,7 @@ bool ShouldRun(const Program* program) {
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const type::Type* type) {
type = type->UnwrapRef();
if (type->Is<sem::Matrix>()) {
if (type->Is<type::Matrix>()) {
return true;
} else if (auto* ary = type->As<sem::Array>()) {
return ContainsMatrix(ary->ElemType());

View File

@@ -51,7 +51,7 @@ struct PackedVec3::State {
if (auto* str = sem.Get<sem::Struct>(decl)) {
if (str->IsHostShareable()) {
for (auto* member : str->Members()) {
if (auto* vec = member->Type()->As<sem::Vector>()) {
if (auto* vec = member->Type()->As<type::Vector>()) {
if (vec->Width() == 3) {
members.Add(member);
@@ -121,11 +121,11 @@ struct PackedVec3::State {
}
// Wrap the load expressions with a cast to the unpacked type.
utils::Hashmap<const sem::Vector*, Symbol, 3> unpack_fns;
utils::Hashmap<const type::Vector*, Symbol, 3> unpack_fns;
for (auto* ref : refs) {
// ref is either a packed vec3 that needs casting, or a pointer to a vec3 which we just
// leave alone.
if (auto* vec_ty = ref->Type()->UnwrapRef()->As<sem::Vector>()) {
if (auto* vec_ty = ref->Type()->UnwrapRef()->As<type::Vector>()) {
auto* expr = ref->Declaration();
ctx.Replace(expr, [this, vec_ty, expr] { //
auto* packed = ctx.CloneWithoutTransform(expr);

View File

@@ -86,7 +86,7 @@ struct Robustness::State {
auto* clamped_idx = Switch(
sem->Object()->Type()->UnwrapRef(), //
[&](const sem::Vector* vec) -> const ast::Expression* {
[&](const type::Vector* vec) -> const ast::Expression* {
if (sem->Index()->ConstantValue()) {
// Index and size is constant.
// Validation will have rejected any OOB accesses.
@@ -95,7 +95,7 @@ struct Robustness::State {
return b.Call("min", idx(), u32(vec->Width() - 1u));
},
[&](const sem::Matrix* mat) -> const ast::Expression* {
[&](const type::Matrix* mat) -> const ast::Expression* {
if (sem->Index()->ConstantValue()) {
// Index and size is constant.
// Validation will have rejected any OOB accesses.
@@ -177,7 +177,7 @@ struct Robustness::State {
auto* coords_ty = builtin->Parameters()[static_cast<size_t>(coords_idx)]->Type();
auto width_of = [&](const type::Type* ty) {
if (auto* vec = ty->As<sem::Vector>()) {
if (auto* vec = ty->As<type::Vector>()) {
return vec->Width();
}
return 1u;

View File

@@ -132,7 +132,7 @@ struct Std140::State {
while (auto* arr = ty->As<sem::Array>()) {
ty = arr->ElemType();
}
if (auto* mat = ty->As<sem::Matrix>()) {
if (auto* mat = ty->As<type::Matrix>()) {
if (MatrixNeedsDecomposing(mat)) {
return true;
}
@@ -241,7 +241,7 @@ struct Std140::State {
};
// Map of matrix type in src, to decomposed column structure in ctx.dst.
utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats;
utils::Hashmap<const type::Matrix*, Std140Matrix, 8> std140_mats;
/// AccessChain describes a chain of access expressions to uniform buffer variable.
struct AccessChain {
@@ -253,7 +253,7 @@ struct Std140::State {
utils::Vector<const sem::Expression*, 8> dynamic_indices;
/// The type of the std140-decomposed matrix being accessed.
/// May be nullptr if the chain does not pass through a std140-decomposed matrix.
const sem::Matrix* std140_mat_ty = nullptr;
const type::Matrix* std140_mat_ty = nullptr;
/// The index in #indices of the access that resolves to the std140-decomposed matrix.
/// May hold no value if the chain does not pass through a std140-decomposed matrix.
std::optional<size_t> std140_mat_idx;
@@ -266,7 +266,9 @@ struct Std140::State {
/// @returns true if the given matrix needs decomposing to column vectors for std140 layout.
/// Std140 layout require matrix stride to be 16, otherwise decomposing is needed.
static bool MatrixNeedsDecomposing(const sem::Matrix* mat) { return mat->ColumnStride() != 16; }
static bool MatrixNeedsDecomposing(const type::Matrix* mat) {
return mat->ColumnStride() != 16;
}
/// ForkTypes walks the user-declared types in dependency order, forking structures that are
/// used as uniform buffers which (transitively) use matrices that need std140 decomposition to
@@ -282,7 +284,7 @@ struct Std140::State {
bool fork_std140 = false;
utils::Vector<const ast::StructMember*, 8> members;
for (auto* member : str->Members()) {
if (auto* mat = member->Type()->As<sem::Matrix>()) {
if (auto* mat = member->Type()->As<type::Matrix>()) {
// Is this member a matrix that needs decomposition for std140-layout?
if (MatrixNeedsDecomposing(mat)) {
// Structure member of matrix type needs decomposition.
@@ -406,7 +408,7 @@ struct Std140::State {
}
return nullptr;
},
[&](const sem::Matrix* mat) -> const ast::Type* {
[&](const type::Matrix* mat) -> const ast::Type* {
if (MatrixNeedsDecomposing(mat)) {
auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
@@ -453,7 +455,7 @@ struct Std140::State {
/// @param size the size in bytes of the matrix.
/// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
utils::Vector<const ast::StructMember*, 4> DecomposedMatrixStructMembers(
const sem::Matrix* mat,
const type::Matrix* mat,
const std::string& name_prefix,
uint32_t align,
uint32_t size) {
@@ -533,7 +535,7 @@ struct Std140::State {
if (std140_mat_members.Contains(a->Member())) {
// Record this on the access.
access.std140_mat_idx = access.indices.Length();
access.std140_mat_ty = expr->Type()->UnwrapRef()->As<sem::Matrix>();
access.std140_mat_ty = expr->Type()->UnwrapRef()->As<type::Matrix>();
}
// Structure member accesses are always statically indexed
access.indices.Push(u32(a->Member()->Index()));
@@ -551,7 +553,7 @@ struct Std140::State {
expr = a->Object();
// Is the object a std140 decomposed matrix?
if (auto* mat = expr->Type()->UnwrapRef()->As<sem::Matrix>()) {
if (auto* mat = expr->Type()->UnwrapRef()->As<type::Matrix>()) {
if (std140_mats.Contains(mat)) {
// Record this on the access.
access.std140_mat_idx = access.indices.Length();
@@ -641,7 +643,7 @@ struct Std140::State {
}
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
},
[&](const sem::Matrix* mat) {
[&](const type::Matrix* mat) {
return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) +
"_" + ConvertSuffix(mat->type());
},
@@ -713,7 +715,7 @@ struct Std140::State {
}
stmts.Push(b.Return(b.Construct(CreateASTTypeFor(ctx, ty), std::move(args))));
}, //
[&](const sem::Matrix* mat) {
[&](const type::Matrix* mat) {
// Reassemble a std140 matrix from the structure of column vector members.
if (auto std140_mat = std140_mats.Get(mat)) {
utils::Vector<const ast::Expression*, 8> args;
@@ -835,7 +837,7 @@ struct Std140::State {
auto* mat_member = str->Members()[mat_member_idx];
auto mat_columns = *std140_mat_members.Get(mat_member);
expr = b.MemberAccessor(expr, mat_columns[column_idx]->symbol);
ty = mat_member->Type()->As<sem::Matrix>()->ColumnType();
ty = mat_member->Type()->As<type::Matrix>()->ColumnType();
} else {
// Non-structure-member matrix. The columns are decomposed into a new, bespoke std140
// structure.
@@ -843,8 +845,8 @@ struct Std140::State {
BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
expr = new_expr;
ty = new_ty;
auto* mat = ty->As<sem::Matrix>();
auto std140_mat = std140_mats.Get(ty->As<sem::Matrix>());
auto* mat = ty->As<type::Matrix>();
auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
expr = b.MemberAccessor(expr, std140_mat->columns[column_idx]);
ty = mat->ColumnType();
}
@@ -918,7 +920,7 @@ struct Std140::State {
}
auto mat_columns = *std140_mat_members.Get(mat_member);
expr = b.MemberAccessor(expr, mat_columns[column_idx]->symbol);
ty = mat_member->Type()->As<sem::Matrix>()->ColumnType();
ty = mat_member->Type()->As<type::Matrix>()->ColumnType();
} else {
// Non-structure-member matrix. The columns are decomposed into a new, bespoke
// std140 structure.
@@ -929,8 +931,8 @@ struct Std140::State {
if (column_idx == 0) {
name += "_" + mat_name + "_p" + std::to_string(column_param_idx);
}
auto* mat = ty->As<sem::Matrix>();
auto std140_mat = std140_mats.Get(ty->As<sem::Matrix>());
auto* mat = ty->As<type::Matrix>();
auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
expr = b.MemberAccessor(expr, std140_mat->columns[column_idx]);
ty = mat->ColumnType();
}
@@ -1021,8 +1023,8 @@ struct Std140::State {
auto [new_expr, new_ty, mat_name] =
BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
expr = new_expr;
auto* mat = ty->As<sem::Matrix>();
auto std140_mat = std140_mats.Get(ty->As<sem::Matrix>());
auto* mat = ty->As<type::Matrix>();
auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
columns = utils::Transform(std140_mat->columns, [&](auto column_name) {
return b.MemberAccessor(b.Deref(let), column_name);
});
@@ -1084,12 +1086,12 @@ struct Std140::State {
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, arr->ElemType(), name};
}, //
[&](const sem::Matrix* mat) -> ExprTypeName {
[&](const type::Matrix* mat) -> ExprTypeName {
auto* idx = dynamic_index(dyn_idx->slot);
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, mat->ColumnType(), name};
}, //
[&](const sem::Vector* vec) -> ExprTypeName {
[&](const type::Vector* vec) -> ExprTypeName {
auto* idx = dynamic_index(dyn_idx->slot);
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, vec->type(), name};
@@ -1104,13 +1106,13 @@ struct Std140::State {
/// The access is a vector swizzle.
return Switch(
ty, //
[&](const sem::Vector* vec) -> ExprTypeName {
[&](const type::Vector* vec) -> ExprTypeName {
static const char xyzw[] = {'x', 'y', 'z', 'w'};
std::string rhs;
for (auto el : *swizzle) {
rhs += xyzw[el];
}
auto swizzle_ty = src->Types().Find<sem::Vector>(
auto swizzle_ty = src->Types().Find<type::Vector>(
vec->type(), static_cast<uint32_t>(swizzle->Length()));
auto* expr = b.MemberAccessor(lhs, rhs);
return {expr, swizzle_ty, rhs};
@@ -1136,11 +1138,11 @@ struct Std140::State {
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, arr->ElemType(), std::to_string(idx)};
}, //
[&](const sem::Matrix* mat) -> ExprTypeName {
[&](const type::Matrix* mat) -> ExprTypeName {
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, mat->ColumnType(), std::to_string(idx)};
}, //
[&](const sem::Vector* vec) -> ExprTypeName {
[&](const type::Vector* vec) -> ExprTypeName {
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, vec->type(), std::to_string(idx)};
}, //

View File

@@ -92,11 +92,11 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type
if (ty->Is<type::Bool>()) {
return ctx.dst->create<ast::Bool>();
}
if (auto* m = ty->As<sem::Matrix>()) {
if (auto* m = ty->As<type::Matrix>()) {
auto* el = CreateASTTypeFor(ctx, m->type());
return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
}
if (auto* v = ty->As<sem::Vector>()) {
if (auto* v = ty->As<type::Vector>()) {
auto* el = CreateASTTypeFor(ctx, v->type());
return ctx.dst->create<ast::Vector>(el, v->Width());
}

View File

@@ -50,8 +50,8 @@ TEST_F(CreateASTTypeForTest, Basic) {
TEST_F(CreateASTTypeForTest, Matrix) {
auto* mat = create([](ProgramBuilder& b) {
auto* column_type = b.create<sem::Vector>(b.create<type::F32>(), 2u);
return b.create<sem::Matrix>(column_type, 3u);
auto* column_type = b.create<type::Vector>(b.create<type::F32>(), 2u);
return b.create<type::Matrix>(column_type, 3u);
});
ASSERT_TRUE(mat->Is<ast::Matrix>());
ASSERT_TRUE(mat->As<ast::Matrix>()->type->Is<ast::F32>());
@@ -61,7 +61,7 @@ TEST_F(CreateASTTypeForTest, Matrix) {
TEST_F(CreateASTTypeForTest, Vector) {
auto* vec =
create([](ProgramBuilder& b) { return b.create<sem::Vector>(b.create<type::F32>(), 2u); });
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
ASSERT_TRUE(vec->Is<ast::Vector>());
ASSERT_TRUE(vec->As<ast::Vector>()->type->Is<ast::F32>());
ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);

View File

@@ -47,7 +47,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
}
auto* indexed = sem.Get(object_expr);
if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
if (!indexed->Type()->IsAnyOf<sem::Array, type::Matrix>()) {
// We only care about array and matrices.
return true;
}

View File

@@ -36,7 +36,7 @@ bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* sem = program->Sem().Get<sem::Expression>(node)) {
if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
if (call->Target()->Is<sem::TypeConversion>() && call->Type()->Is<sem::Matrix>()) {
if (call->Target()->Is<sem::TypeConversion>() && call->Type()->Is<type::Matrix>()) {
auto& args = call->Arguments();
if (args.Length() == 1 && args[0]->Type()->UnwrapRef()->is_float_matrix()) {
return true;
@@ -65,7 +65,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
using HelperFunctionKey =
utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>;
utils::UnorderedKeyWrapper<std::tuple<const type::Matrix*, const type::Matrix*>>;
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
@@ -75,7 +75,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
if (!ty_conv) {
return nullptr;
}
auto* dst_type = call->Type()->As<sem::Matrix>();
auto* dst_type = call->Type()->As<type::Matrix>();
if (!dst_type) {
return nullptr;
}
@@ -87,7 +87,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
auto& matrix = args[0];
auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>();
auto* src_type = matrix->Type()->UnwrapRef()->As<type::Matrix>();
if (!src_type) {
return nullptr;
}

View File

@@ -32,7 +32,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) {
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<type::Matrix>()) {
auto& args = call->Arguments();
if (!args.IsEmpty() && args[0]->Type()->UnwrapRef()->is_scalar()) {
return true;
@@ -59,7 +59,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
std::unordered_map<const sem::Matrix*, Symbol> scalar_inits;
std::unordered_map<const type::Matrix*, Symbol> scalar_inits;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
@@ -67,7 +67,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
if (!ty_init) {
return nullptr;
}
auto* mat_type = call->Type()->As<sem::Matrix>();
auto* mat_type = call->Type()->As<type::Matrix>();
if (!mat_type) {
return nullptr;
}
@@ -84,7 +84,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
if (args[0]
->Type()
->UnwrapRef()
->IsAnyOf<sem::Matrix, sem::Vector, type::AbstractNumeric>()) {
->IsAnyOf<type::Matrix, type::Vector, type::AbstractNumeric>()) {
return nullptr;
}

View File

@@ -165,7 +165,7 @@ AttributeWGSLType WGSLTypeOf(const type::Type* ty) {
[](const type::F16*) -> AttributeWGSLType {
return {BaseWGSLType::kF16, 1};
},
[](const sem::Vector* vec) -> AttributeWGSLType {
[](const type::Vector* vec) -> AttributeWGSLType {
return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
},
[](Default) -> AttributeWGSLType {