mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 10:49:14 +00:00
Move vector and matrix to type/.
This CL moves vector and matrix to type/ and updates the namespaces as needed. Bug: tint:1718 Change-Id: I48423b37f15cd69c03ab288143b2d36564789fbf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113423 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Dan Sinclair
parent
100d4bf339
commit
0e780da882
@@ -487,7 +487,7 @@ struct BuiltinPolyfill::State {
|
||||
if (!ty->is_unsigned_integer_scalar_or_vector()) {
|
||||
expr = b.Construct<i32>(expr);
|
||||
}
|
||||
if (ty->Is<sem::Vector>()) {
|
||||
if (ty->Is<type::Vector>()) {
|
||||
expr = b.Construct(T(ty), expr);
|
||||
}
|
||||
return expr;
|
||||
@@ -643,7 +643,7 @@ struct BuiltinPolyfill::State {
|
||||
/// with scalar calls.
|
||||
/// @param vec the vector type
|
||||
/// @return the polyfill function name
|
||||
Symbol quantizeToF16(const sem::Vector* vec) {
|
||||
Symbol quantizeToF16(const type::Vector* vec) {
|
||||
auto name = b.Symbols().New("tint_quantizeToF16");
|
||||
utils::Vector<const ast::Expression*, 4> args;
|
||||
for (uint32_t i = 0; i < vec->Width(); i++) {
|
||||
@@ -673,7 +673,7 @@ struct BuiltinPolyfill::State {
|
||||
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
|
||||
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
||||
if (rhs_ty->Is<sem::Vector>()) {
|
||||
if (rhs_ty->Is<type::Vector>()) {
|
||||
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
|
||||
}
|
||||
auto* lhs = ctx.Clone(bin_op->lhs);
|
||||
@@ -761,7 +761,7 @@ struct BuiltinPolyfill::State {
|
||||
|
||||
/// @returns 1 if `ty` is not a vector, otherwise the vector width
|
||||
uint32_t WidthOf(const type::Type* ty) const {
|
||||
if (auto* v = ty->As<sem::Vector>()) {
|
||||
if (auto* v = ty->As<type::Vector>()) {
|
||||
return v->Width();
|
||||
}
|
||||
return 1;
|
||||
@@ -905,7 +905,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||
break;
|
||||
case sem::BuiltinType::kQuantizeToF16:
|
||||
if (polyfill.quantize_to_vec_f16) {
|
||||
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
|
||||
if (auto* vec = builtin->ReturnType()->As<type::Vector>()) {
|
||||
fn = builtin_polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.quantizeToF16(vec); });
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ bool IntrinsicDataTypeFor(const type::Type* ty, DecomposeMemoryAccess::Intrinsic
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kF16;
|
||||
return true;
|
||||
}
|
||||
if (auto* vec = ty->As<sem::Vector>()) {
|
||||
if (auto* vec = ty->As<type::Vector>()) {
|
||||
switch (vec->Width()) {
|
||||
case 2:
|
||||
if (vec->type()->Is<type::I32>()) {
|
||||
@@ -529,7 +529,7 @@ struct DecomposeMemoryAccess::State {
|
||||
});
|
||||
} else {
|
||||
utils::Vector<const ast::Expression*, 8> values;
|
||||
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
|
||||
if (auto* mat_ty = el_ty->As<type::Matrix>()) {
|
||||
auto* vec_ty = mat_ty->ColumnType();
|
||||
Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
|
||||
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
||||
@@ -629,7 +629,7 @@ struct DecomposeMemoryAccess::State {
|
||||
|
||||
return utils::Vector{b.Decl(array), for_loop};
|
||||
},
|
||||
[&](const sem::Matrix* mat_ty) {
|
||||
[&](const type::Matrix* mat_ty) {
|
||||
auto* vec_ty = mat_ty->ColumnType();
|
||||
Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
|
||||
utils::Vector<const ast::Statement*, 4> stmts;
|
||||
@@ -901,7 +901,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
|
||||
if (swizzle->Indices().Length() == 1) {
|
||||
if (auto access = state.TakeAccess(accessor->structure)) {
|
||||
auto* vec_ty = access.type->As<sem::Vector>();
|
||||
auto* vec_ty = access.type->As<type::Vector>();
|
||||
auto* offset = state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0u]);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
@@ -937,7 +937,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (auto* vec_ty = access.type->As<sem::Vector>()) {
|
||||
if (auto* vec_ty = access.type->As<type::Vector>()) {
|
||||
auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
@@ -946,7 +946,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (auto* mat_ty = access.type->As<sem::Matrix>()) {
|
||||
if (auto* mat_ty = access.type->As<type::Matrix>()) {
|
||||
auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
|
||||
@@ -35,7 +35,7 @@ struct MatrixInfo {
|
||||
/// The stride in bytes between columns of the matrix
|
||||
uint32_t stride = 0;
|
||||
/// The type of the matrix
|
||||
const sem::Matrix* matrix = nullptr;
|
||||
const type::Matrix* matrix = nullptr;
|
||||
|
||||
/// @returns a new ast::Array that holds an vector column for each row of the
|
||||
/// matrix.
|
||||
@@ -77,7 +77,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||
continue;
|
||||
}
|
||||
for (auto* member : str_ty->Members()) {
|
||||
auto* matrix = member->Type()->As<sem::Matrix>();
|
||||
auto* matrix = member->Type()->As<type::Matrix>();
|
||||
if (!matrix) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ struct ExpandCompoundAssignment::State {
|
||||
|
||||
// Helper function that returns `true` if the type of `expr` is a vector.
|
||||
auto is_vec = [&](const ast::Expression* expr) {
|
||||
return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
|
||||
return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<type::Vector>();
|
||||
};
|
||||
|
||||
// Hoist the LHS expression subtree into local constants to produce a new
|
||||
|
||||
@@ -50,7 +50,7 @@ bool ShouldRun(const Program* program) {
|
||||
// Returns `true` if `type` is or contains a matrix type.
|
||||
bool ContainsMatrix(const type::Type* type) {
|
||||
type = type->UnwrapRef();
|
||||
if (type->Is<sem::Matrix>()) {
|
||||
if (type->Is<type::Matrix>()) {
|
||||
return true;
|
||||
} else if (auto* ary = type->As<sem::Array>()) {
|
||||
return ContainsMatrix(ary->ElemType());
|
||||
|
||||
@@ -51,7 +51,7 @@ struct PackedVec3::State {
|
||||
if (auto* str = sem.Get<sem::Struct>(decl)) {
|
||||
if (str->IsHostShareable()) {
|
||||
for (auto* member : str->Members()) {
|
||||
if (auto* vec = member->Type()->As<sem::Vector>()) {
|
||||
if (auto* vec = member->Type()->As<type::Vector>()) {
|
||||
if (vec->Width() == 3) {
|
||||
members.Add(member);
|
||||
|
||||
@@ -121,11 +121,11 @@ struct PackedVec3::State {
|
||||
}
|
||||
|
||||
// Wrap the load expressions with a cast to the unpacked type.
|
||||
utils::Hashmap<const sem::Vector*, Symbol, 3> unpack_fns;
|
||||
utils::Hashmap<const type::Vector*, Symbol, 3> unpack_fns;
|
||||
for (auto* ref : refs) {
|
||||
// ref is either a packed vec3 that needs casting, or a pointer to a vec3 which we just
|
||||
// leave alone.
|
||||
if (auto* vec_ty = ref->Type()->UnwrapRef()->As<sem::Vector>()) {
|
||||
if (auto* vec_ty = ref->Type()->UnwrapRef()->As<type::Vector>()) {
|
||||
auto* expr = ref->Declaration();
|
||||
ctx.Replace(expr, [this, vec_ty, expr] { //
|
||||
auto* packed = ctx.CloneWithoutTransform(expr);
|
||||
|
||||
@@ -86,7 +86,7 @@ struct Robustness::State {
|
||||
|
||||
auto* clamped_idx = Switch(
|
||||
sem->Object()->Type()->UnwrapRef(), //
|
||||
[&](const sem::Vector* vec) -> const ast::Expression* {
|
||||
[&](const type::Vector* vec) -> const ast::Expression* {
|
||||
if (sem->Index()->ConstantValue()) {
|
||||
// Index and size is constant.
|
||||
// Validation will have rejected any OOB accesses.
|
||||
@@ -95,7 +95,7 @@ struct Robustness::State {
|
||||
|
||||
return b.Call("min", idx(), u32(vec->Width() - 1u));
|
||||
},
|
||||
[&](const sem::Matrix* mat) -> const ast::Expression* {
|
||||
[&](const type::Matrix* mat) -> const ast::Expression* {
|
||||
if (sem->Index()->ConstantValue()) {
|
||||
// Index and size is constant.
|
||||
// Validation will have rejected any OOB accesses.
|
||||
@@ -177,7 +177,7 @@ struct Robustness::State {
|
||||
auto* coords_ty = builtin->Parameters()[static_cast<size_t>(coords_idx)]->Type();
|
||||
|
||||
auto width_of = [&](const type::Type* ty) {
|
||||
if (auto* vec = ty->As<sem::Vector>()) {
|
||||
if (auto* vec = ty->As<type::Vector>()) {
|
||||
return vec->Width();
|
||||
}
|
||||
return 1u;
|
||||
|
||||
@@ -132,7 +132,7 @@ struct Std140::State {
|
||||
while (auto* arr = ty->As<sem::Array>()) {
|
||||
ty = arr->ElemType();
|
||||
}
|
||||
if (auto* mat = ty->As<sem::Matrix>()) {
|
||||
if (auto* mat = ty->As<type::Matrix>()) {
|
||||
if (MatrixNeedsDecomposing(mat)) {
|
||||
return true;
|
||||
}
|
||||
@@ -241,7 +241,7 @@ struct Std140::State {
|
||||
};
|
||||
|
||||
// Map of matrix type in src, to decomposed column structure in ctx.dst.
|
||||
utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats;
|
||||
utils::Hashmap<const type::Matrix*, Std140Matrix, 8> std140_mats;
|
||||
|
||||
/// AccessChain describes a chain of access expressions to uniform buffer variable.
|
||||
struct AccessChain {
|
||||
@@ -253,7 +253,7 @@ struct Std140::State {
|
||||
utils::Vector<const sem::Expression*, 8> dynamic_indices;
|
||||
/// The type of the std140-decomposed matrix being accessed.
|
||||
/// May be nullptr if the chain does not pass through a std140-decomposed matrix.
|
||||
const sem::Matrix* std140_mat_ty = nullptr;
|
||||
const type::Matrix* std140_mat_ty = nullptr;
|
||||
/// The index in #indices of the access that resolves to the std140-decomposed matrix.
|
||||
/// May hold no value if the chain does not pass through a std140-decomposed matrix.
|
||||
std::optional<size_t> std140_mat_idx;
|
||||
@@ -266,7 +266,9 @@ struct Std140::State {
|
||||
|
||||
/// @returns true if the given matrix needs decomposing to column vectors for std140 layout.
|
||||
/// Std140 layout require matrix stride to be 16, otherwise decomposing is needed.
|
||||
static bool MatrixNeedsDecomposing(const sem::Matrix* mat) { return mat->ColumnStride() != 16; }
|
||||
static bool MatrixNeedsDecomposing(const type::Matrix* mat) {
|
||||
return mat->ColumnStride() != 16;
|
||||
}
|
||||
|
||||
/// ForkTypes walks the user-declared types in dependency order, forking structures that are
|
||||
/// used as uniform buffers which (transitively) use matrices that need std140 decomposition to
|
||||
@@ -282,7 +284,7 @@ struct Std140::State {
|
||||
bool fork_std140 = false;
|
||||
utils::Vector<const ast::StructMember*, 8> members;
|
||||
for (auto* member : str->Members()) {
|
||||
if (auto* mat = member->Type()->As<sem::Matrix>()) {
|
||||
if (auto* mat = member->Type()->As<type::Matrix>()) {
|
||||
// Is this member a matrix that needs decomposition for std140-layout?
|
||||
if (MatrixNeedsDecomposing(mat)) {
|
||||
// Structure member of matrix type needs decomposition.
|
||||
@@ -406,7 +408,7 @@ struct Std140::State {
|
||||
}
|
||||
return nullptr;
|
||||
},
|
||||
[&](const sem::Matrix* mat) -> const ast::Type* {
|
||||
[&](const type::Matrix* mat) -> const ast::Type* {
|
||||
if (MatrixNeedsDecomposing(mat)) {
|
||||
auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
|
||||
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
|
||||
@@ -453,7 +455,7 @@ struct Std140::State {
|
||||
/// @param size the size in bytes of the matrix.
|
||||
/// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
|
||||
utils::Vector<const ast::StructMember*, 4> DecomposedMatrixStructMembers(
|
||||
const sem::Matrix* mat,
|
||||
const type::Matrix* mat,
|
||||
const std::string& name_prefix,
|
||||
uint32_t align,
|
||||
uint32_t size) {
|
||||
@@ -533,7 +535,7 @@ struct Std140::State {
|
||||
if (std140_mat_members.Contains(a->Member())) {
|
||||
// Record this on the access.
|
||||
access.std140_mat_idx = access.indices.Length();
|
||||
access.std140_mat_ty = expr->Type()->UnwrapRef()->As<sem::Matrix>();
|
||||
access.std140_mat_ty = expr->Type()->UnwrapRef()->As<type::Matrix>();
|
||||
}
|
||||
// Structure member accesses are always statically indexed
|
||||
access.indices.Push(u32(a->Member()->Index()));
|
||||
@@ -551,7 +553,7 @@ struct Std140::State {
|
||||
expr = a->Object();
|
||||
|
||||
// Is the object a std140 decomposed matrix?
|
||||
if (auto* mat = expr->Type()->UnwrapRef()->As<sem::Matrix>()) {
|
||||
if (auto* mat = expr->Type()->UnwrapRef()->As<type::Matrix>()) {
|
||||
if (std140_mats.Contains(mat)) {
|
||||
// Record this on the access.
|
||||
access.std140_mat_idx = access.indices.Length();
|
||||
@@ -641,7 +643,7 @@ struct Std140::State {
|
||||
}
|
||||
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
[&](const type::Matrix* mat) {
|
||||
return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) +
|
||||
"_" + ConvertSuffix(mat->type());
|
||||
},
|
||||
@@ -713,7 +715,7 @@ struct Std140::State {
|
||||
}
|
||||
stmts.Push(b.Return(b.Construct(CreateASTTypeFor(ctx, ty), std::move(args))));
|
||||
}, //
|
||||
[&](const sem::Matrix* mat) {
|
||||
[&](const type::Matrix* mat) {
|
||||
// Reassemble a std140 matrix from the structure of column vector members.
|
||||
if (auto std140_mat = std140_mats.Get(mat)) {
|
||||
utils::Vector<const ast::Expression*, 8> args;
|
||||
@@ -835,7 +837,7 @@ struct Std140::State {
|
||||
auto* mat_member = str->Members()[mat_member_idx];
|
||||
auto mat_columns = *std140_mat_members.Get(mat_member);
|
||||
expr = b.MemberAccessor(expr, mat_columns[column_idx]->symbol);
|
||||
ty = mat_member->Type()->As<sem::Matrix>()->ColumnType();
|
||||
ty = mat_member->Type()->As<type::Matrix>()->ColumnType();
|
||||
} else {
|
||||
// Non-structure-member matrix. The columns are decomposed into a new, bespoke std140
|
||||
// structure.
|
||||
@@ -843,8 +845,8 @@ struct Std140::State {
|
||||
BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
|
||||
expr = new_expr;
|
||||
ty = new_ty;
|
||||
auto* mat = ty->As<sem::Matrix>();
|
||||
auto std140_mat = std140_mats.Get(ty->As<sem::Matrix>());
|
||||
auto* mat = ty->As<type::Matrix>();
|
||||
auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
|
||||
expr = b.MemberAccessor(expr, std140_mat->columns[column_idx]);
|
||||
ty = mat->ColumnType();
|
||||
}
|
||||
@@ -918,7 +920,7 @@ struct Std140::State {
|
||||
}
|
||||
auto mat_columns = *std140_mat_members.Get(mat_member);
|
||||
expr = b.MemberAccessor(expr, mat_columns[column_idx]->symbol);
|
||||
ty = mat_member->Type()->As<sem::Matrix>()->ColumnType();
|
||||
ty = mat_member->Type()->As<type::Matrix>()->ColumnType();
|
||||
} else {
|
||||
// Non-structure-member matrix. The columns are decomposed into a new, bespoke
|
||||
// std140 structure.
|
||||
@@ -929,8 +931,8 @@ struct Std140::State {
|
||||
if (column_idx == 0) {
|
||||
name += "_" + mat_name + "_p" + std::to_string(column_param_idx);
|
||||
}
|
||||
auto* mat = ty->As<sem::Matrix>();
|
||||
auto std140_mat = std140_mats.Get(ty->As<sem::Matrix>());
|
||||
auto* mat = ty->As<type::Matrix>();
|
||||
auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
|
||||
expr = b.MemberAccessor(expr, std140_mat->columns[column_idx]);
|
||||
ty = mat->ColumnType();
|
||||
}
|
||||
@@ -1021,8 +1023,8 @@ struct Std140::State {
|
||||
auto [new_expr, new_ty, mat_name] =
|
||||
BuildAccessExpr(expr, ty, chain, std140_mat_idx, dynamic_index);
|
||||
expr = new_expr;
|
||||
auto* mat = ty->As<sem::Matrix>();
|
||||
auto std140_mat = std140_mats.Get(ty->As<sem::Matrix>());
|
||||
auto* mat = ty->As<type::Matrix>();
|
||||
auto std140_mat = std140_mats.Get(ty->As<type::Matrix>());
|
||||
columns = utils::Transform(std140_mat->columns, [&](auto column_name) {
|
||||
return b.MemberAccessor(b.Deref(let), column_name);
|
||||
});
|
||||
@@ -1084,12 +1086,12 @@ struct Std140::State {
|
||||
auto* expr = b.IndexAccessor(lhs, idx);
|
||||
return {expr, arr->ElemType(), name};
|
||||
}, //
|
||||
[&](const sem::Matrix* mat) -> ExprTypeName {
|
||||
[&](const type::Matrix* mat) -> ExprTypeName {
|
||||
auto* idx = dynamic_index(dyn_idx->slot);
|
||||
auto* expr = b.IndexAccessor(lhs, idx);
|
||||
return {expr, mat->ColumnType(), name};
|
||||
}, //
|
||||
[&](const sem::Vector* vec) -> ExprTypeName {
|
||||
[&](const type::Vector* vec) -> ExprTypeName {
|
||||
auto* idx = dynamic_index(dyn_idx->slot);
|
||||
auto* expr = b.IndexAccessor(lhs, idx);
|
||||
return {expr, vec->type(), name};
|
||||
@@ -1104,13 +1106,13 @@ struct Std140::State {
|
||||
/// The access is a vector swizzle.
|
||||
return Switch(
|
||||
ty, //
|
||||
[&](const sem::Vector* vec) -> ExprTypeName {
|
||||
[&](const type::Vector* vec) -> ExprTypeName {
|
||||
static const char xyzw[] = {'x', 'y', 'z', 'w'};
|
||||
std::string rhs;
|
||||
for (auto el : *swizzle) {
|
||||
rhs += xyzw[el];
|
||||
}
|
||||
auto swizzle_ty = src->Types().Find<sem::Vector>(
|
||||
auto swizzle_ty = src->Types().Find<type::Vector>(
|
||||
vec->type(), static_cast<uint32_t>(swizzle->Length()));
|
||||
auto* expr = b.MemberAccessor(lhs, rhs);
|
||||
return {expr, swizzle_ty, rhs};
|
||||
@@ -1136,11 +1138,11 @@ struct Std140::State {
|
||||
auto* expr = b.IndexAccessor(lhs, idx);
|
||||
return {expr, arr->ElemType(), std::to_string(idx)};
|
||||
}, //
|
||||
[&](const sem::Matrix* mat) -> ExprTypeName {
|
||||
[&](const type::Matrix* mat) -> ExprTypeName {
|
||||
auto* expr = b.IndexAccessor(lhs, idx);
|
||||
return {expr, mat->ColumnType(), std::to_string(idx)};
|
||||
}, //
|
||||
[&](const sem::Vector* vec) -> ExprTypeName {
|
||||
[&](const type::Vector* vec) -> ExprTypeName {
|
||||
auto* expr = b.IndexAccessor(lhs, idx);
|
||||
return {expr, vec->type(), std::to_string(idx)};
|
||||
}, //
|
||||
|
||||
@@ -92,11 +92,11 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type
|
||||
if (ty->Is<type::Bool>()) {
|
||||
return ctx.dst->create<ast::Bool>();
|
||||
}
|
||||
if (auto* m = ty->As<sem::Matrix>()) {
|
||||
if (auto* m = ty->As<type::Matrix>()) {
|
||||
auto* el = CreateASTTypeFor(ctx, m->type());
|
||||
return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
|
||||
}
|
||||
if (auto* v = ty->As<sem::Vector>()) {
|
||||
if (auto* v = ty->As<type::Vector>()) {
|
||||
auto* el = CreateASTTypeFor(ctx, v->type());
|
||||
return ctx.dst->create<ast::Vector>(el, v->Width());
|
||||
}
|
||||
|
||||
@@ -50,8 +50,8 @@ TEST_F(CreateASTTypeForTest, Basic) {
|
||||
|
||||
TEST_F(CreateASTTypeForTest, Matrix) {
|
||||
auto* mat = create([](ProgramBuilder& b) {
|
||||
auto* column_type = b.create<sem::Vector>(b.create<type::F32>(), 2u);
|
||||
return b.create<sem::Matrix>(column_type, 3u);
|
||||
auto* column_type = b.create<type::Vector>(b.create<type::F32>(), 2u);
|
||||
return b.create<type::Matrix>(column_type, 3u);
|
||||
});
|
||||
ASSERT_TRUE(mat->Is<ast::Matrix>());
|
||||
ASSERT_TRUE(mat->As<ast::Matrix>()->type->Is<ast::F32>());
|
||||
@@ -61,7 +61,7 @@ TEST_F(CreateASTTypeForTest, Matrix) {
|
||||
|
||||
TEST_F(CreateASTTypeForTest, Vector) {
|
||||
auto* vec =
|
||||
create([](ProgramBuilder& b) { return b.create<sem::Vector>(b.create<type::F32>(), 2u); });
|
||||
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
|
||||
ASSERT_TRUE(vec->Is<ast::Vector>());
|
||||
ASSERT_TRUE(vec->As<ast::Vector>()->type->Is<ast::F32>());
|
||||
ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);
|
||||
|
||||
@@ -47,7 +47,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
|
||||
}
|
||||
|
||||
auto* indexed = sem.Get(object_expr);
|
||||
if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
|
||||
if (!indexed->Type()->IsAnyOf<sem::Array, type::Matrix>()) {
|
||||
// We only care about array and matrices.
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* sem = program->Sem().Get<sem::Expression>(node)) {
|
||||
if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
|
||||
if (call->Target()->Is<sem::TypeConversion>() && call->Type()->Is<sem::Matrix>()) {
|
||||
if (call->Target()->Is<sem::TypeConversion>() && call->Type()->Is<type::Matrix>()) {
|
||||
auto& args = call->Arguments();
|
||||
if (args.Length() == 1 && args[0]->Type()->UnwrapRef()->is_float_matrix()) {
|
||||
return true;
|
||||
@@ -65,7 +65,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
using HelperFunctionKey =
|
||||
utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>;
|
||||
utils::UnorderedKeyWrapper<std::tuple<const type::Matrix*, const type::Matrix*>>;
|
||||
|
||||
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
|
||||
|
||||
@@ -75,7 +75,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
|
||||
if (!ty_conv) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* dst_type = call->Type()->As<sem::Matrix>();
|
||||
auto* dst_type = call->Type()->As<type::Matrix>();
|
||||
if (!dst_type) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -87,7 +87,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
|
||||
|
||||
auto& matrix = args[0];
|
||||
|
||||
auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>();
|
||||
auto* src_type = matrix->Type()->UnwrapRef()->As<type::Matrix>();
|
||||
if (!src_type) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ namespace {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* call = program->Sem().Get<sem::Call>(node)) {
|
||||
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) {
|
||||
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<type::Matrix>()) {
|
||||
auto& args = call->Arguments();
|
||||
if (!args.IsEmpty() && args[0]->Type()->UnwrapRef()->is_scalar()) {
|
||||
return true;
|
||||
@@ -59,7 +59,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
std::unordered_map<const sem::Matrix*, Symbol> scalar_inits;
|
||||
std::unordered_map<const type::Matrix*, Symbol> scalar_inits;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
@@ -67,7 +67,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
|
||||
if (!ty_init) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* mat_type = call->Type()->As<sem::Matrix>();
|
||||
auto* mat_type = call->Type()->As<type::Matrix>();
|
||||
if (!mat_type) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -84,7 +84,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
|
||||
if (args[0]
|
||||
->Type()
|
||||
->UnwrapRef()
|
||||
->IsAnyOf<sem::Matrix, sem::Vector, type::AbstractNumeric>()) {
|
||||
->IsAnyOf<type::Matrix, type::Vector, type::AbstractNumeric>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ AttributeWGSLType WGSLTypeOf(const type::Type* ty) {
|
||||
[](const type::F16*) -> AttributeWGSLType {
|
||||
return {BaseWGSLType::kF16, 1};
|
||||
},
|
||||
[](const sem::Vector* vec) -> AttributeWGSLType {
|
||||
[](const type::Vector* vec) -> AttributeWGSLType {
|
||||
return {WGSLTypeOf(vec->type()).base_type, vec->Width()};
|
||||
},
|
||||
[](Default) -> AttributeWGSLType {
|
||||
|
||||
Reference in New Issue
Block a user