ast: Migrate to using ast::Type

Remove all sem::Type references from the AST.
ConstructedTypes are now all AST types.

The parsers will still create semantic types, but these are now disjoint
and ignored.
The parsers will be updated with future changes to stop creating these
semantic types.

Resolver creates semantic types from the AST types. Most downstream
logic continues to use the semantic types, however transforms will now
need to rebuild AST type information instead of reassigning semantic
information, as semantic nodes are fully rebuilt by the Resolver.

Bug: tint:724
Change-Id: I4ce03a075f13c77648cda5c3691bae202752ecc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49747
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton
2021-05-05 09:09:41 +00:00
committed by Commit Bot service account
parent 781de097eb
commit 02ebf0dcae
72 changed files with 1267 additions and 1091 deletions

View File

@@ -308,7 +308,7 @@ namespace ArrayStrideTests {
namespace {
struct Params {
create_type_func_ptr create_el_type;
create_ast_type_func_ptr create_el_type;
uint32_t stride;
bool should_pass;
};
@@ -318,17 +318,16 @@ struct TestWithParams : ResolverTestWithParam<Params> {};
using ArrayStrideTest = TestWithParams;
TEST_P(ArrayStrideTest, All) {
auto& params = GetParam();
auto el_ty = params.create_el_type(ty);
auto* el_ty = params.create_el_type(ty);
std::stringstream ss;
ss << "el_ty: " << el_ty->FriendlyName(Symbols())
<< ", stride: " << params.stride
ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride
<< ", should_pass: " << params.should_pass;
SCOPED_TRACE(ss.str());
auto arr = ty.array(el_ty, 4, params.stride);
auto arr = ty.array(Source{{12, 34}}, el_ty, 4, params.stride);
Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput);
Global("myarray", arr, ast::StorageClass::kInput);
if (params.should_pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -366,58 +365,58 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(
// Succeed because stride >= element size (while being multiple of
// element alignment)
Params{ty_u32, default_u32.size, true},
Params{ty_i32, default_i32.size, true},
Params{ty_f32, default_f32.size, true},
Params{ty_vec2<f32>, default_vec2.size, true},
Params{ast_u32, default_u32.size, true},
Params{ast_i32, default_i32.size, true},
Params{ast_f32, default_f32.size, true},
Params{ast_vec2<f32>, default_vec2.size, true},
// vec3's default size is not a multiple of its alignment
// Params{ty_vec3<f32>, default_vec3.size, true},
Params{ty_vec4<f32>, default_vec4.size, true},
Params{ty_mat2x2<f32>, default_mat2x2.size, true},
Params{ty_mat3x3<f32>, default_mat3x3.size, true},
Params{ty_mat4x4<f32>, default_mat4x4.size, true},
// Params{ast_vec3<f32>, default_vec3.size, true},
Params{ast_vec4<f32>, default_vec4.size, true},
Params{ast_mat2x2<f32>, default_mat2x2.size, true},
Params{ast_mat3x3<f32>, default_mat3x3.size, true},
Params{ast_mat4x4<f32>, default_mat4x4.size, true},
// Fail because stride is < element size
Params{ty_u32, default_u32.size - 1, false},
Params{ty_i32, default_i32.size - 1, false},
Params{ty_f32, default_f32.size - 1, false},
Params{ty_vec2<f32>, default_vec2.size - 1, false},
Params{ty_vec3<f32>, default_vec3.size - 1, false},
Params{ty_vec4<f32>, default_vec4.size - 1, false},
Params{ty_mat2x2<f32>, default_mat2x2.size - 1, false},
Params{ty_mat3x3<f32>, default_mat3x3.size - 1, false},
Params{ty_mat4x4<f32>, default_mat4x4.size - 1, false},
Params{ast_u32, default_u32.size - 1, false},
Params{ast_i32, default_i32.size - 1, false},
Params{ast_f32, default_f32.size - 1, false},
Params{ast_vec2<f32>, default_vec2.size - 1, false},
Params{ast_vec3<f32>, default_vec3.size - 1, false},
Params{ast_vec4<f32>, default_vec4.size - 1, false},
Params{ast_mat2x2<f32>, default_mat2x2.size - 1, false},
Params{ast_mat3x3<f32>, default_mat3x3.size - 1, false},
Params{ast_mat4x4<f32>, default_mat4x4.size - 1, false},
// Succeed because stride equals multiple of element alignment
Params{ty_u32, default_u32.align * 7, true},
Params{ty_i32, default_i32.align * 7, true},
Params{ty_f32, default_f32.align * 7, true},
Params{ty_vec2<f32>, default_vec2.align * 7, true},
Params{ty_vec3<f32>, default_vec3.align * 7, true},
Params{ty_vec4<f32>, default_vec4.align * 7, true},
Params{ty_mat2x2<f32>, default_mat2x2.align * 7, true},
Params{ty_mat3x3<f32>, default_mat3x3.align * 7, true},
Params{ty_mat4x4<f32>, default_mat4x4.align * 7, true},
Params{ast_u32, default_u32.align * 7, true},
Params{ast_i32, default_i32.align * 7, true},
Params{ast_f32, default_f32.align * 7, true},
Params{ast_vec2<f32>, default_vec2.align * 7, true},
Params{ast_vec3<f32>, default_vec3.align * 7, true},
Params{ast_vec4<f32>, default_vec4.align * 7, true},
Params{ast_mat2x2<f32>, default_mat2x2.align * 7, true},
Params{ast_mat3x3<f32>, default_mat3x3.align * 7, true},
Params{ast_mat4x4<f32>, default_mat4x4.align * 7, true},
// Fail because stride is not multiple of element alignment
Params{ty_u32, (default_u32.align - 1) * 7, false},
Params{ty_i32, (default_i32.align - 1) * 7, false},
Params{ty_f32, (default_f32.align - 1) * 7, false},
Params{ty_vec2<f32>, (default_vec2.align - 1) * 7, false},
Params{ty_vec3<f32>, (default_vec3.align - 1) * 7, false},
Params{ty_vec4<f32>, (default_vec4.align - 1) * 7, false},
Params{ty_mat2x2<f32>, (default_mat2x2.align - 1) * 7, false},
Params{ty_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false},
Params{ty_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}));
Params{ast_u32, (default_u32.align - 1) * 7, false},
Params{ast_i32, (default_i32.align - 1) * 7, false},
Params{ast_f32, (default_f32.align - 1) * 7, false},
Params{ast_vec2<f32>, (default_vec2.align - 1) * 7, false},
Params{ast_vec3<f32>, (default_vec3.align - 1) * 7, false},
Params{ast_vec4<f32>, (default_vec4.align - 1) * 7, false},
Params{ast_mat2x2<f32>, (default_mat2x2.align - 1) * 7, false},
Params{ast_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false},
Params{ast_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}));
TEST_F(ArrayStrideTest, MultipleDecorations) {
auto arr = ty.array(ty.i32(), 4,
auto arr = ty.array(Source{{12, 34}}, ty.i32(), 4,
{
create<ast::StrideDecoration>(4),
create<ast::StrideDecoration>(4),
});
Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput);
Global("myarray", arr, ast::StorageClass::kInput);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),

View File

@@ -40,6 +40,7 @@
#include "src/ast/storage_texture.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type_name.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/ast/vector.h"
@@ -151,8 +152,17 @@ void Resolver::set_referenced_from_function_if_needed(VariableInfo* var,
}
bool Resolver::Resolve() {
if (builder_->Diagnostics().contains_errors()) {
return false;
}
bool result = ResolveInternal();
if (result && diagnostics_.contains_errors()) {
TINT_ICE(diagnostics_) << "resolving failed, but no error was raised";
return false;
}
// Even if resolving failed, create all the semantic nodes for information we
// did generate.
CreateSemanticNodes();
@@ -169,13 +179,15 @@ bool Resolver::IsStorable(const sem::Type* type) {
if (auto* arr = type->As<sem::ArrayType>()) {
return IsStorable(arr->type());
}
if (auto* str = type->As<sem::StructType>()) {
for (const auto* member : str->impl()->members()) {
if (!IsStorable(member->type())) {
return false;
if (auto* str_ty = type->As<sem::StructType>()) {
if (auto* str = Structure(str_ty)) {
for (const auto* member : str->members) {
if (!IsStorable(member->Type())) {
return false;
}
}
return true;
}
return true;
}
return false;
}
@@ -196,8 +208,12 @@ bool Resolver::IsHostShareable(const sem::Type* type) {
return IsHostShareable(arr->type());
}
if (auto* str = type->As<sem::StructType>()) {
for (auto* member : str->impl()->members()) {
if (!IsHostShareable(member->type())) {
auto* info = Structure(str);
if (!info) {
return false;
}
for (auto* member : info->members) {
if (!IsHostShareable(member->Type())) {
return false;
}
}
@@ -225,11 +241,28 @@ bool Resolver::IsValidAssignment(const sem::Type* lhs, const sem::Type* rhs) {
bool Resolver::ResolveInternal() {
Mark(&builder_->AST());
auto register_named_type = [this](Symbol name, const sem::Type* type,
const Source& source) {
auto added = named_types_.emplace(name, type).second;
if (!added) {
diagnostics_.add_error("type with the name '" +
builder_->Symbols().NameFor(name) +
"' was already declared",
source);
return false;
}
return true;
};
// Process everything else in the order they appear in the module. This is
// necessary for validation of use-before-declaration.
for (auto* decl : builder_->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<sem::Type>()) {
if (!Type(ty)) {
if (auto* ty = decl->As<ast::NamedType>()) {
auto* sem_ty = Type(ty);
if (sem_ty == nullptr) {
return false;
}
if (!register_named_type(ty->name(), sem_ty, ty->source())) {
return false;
}
} else if (auto* func = decl->As<ast::Function>()) {
@@ -249,6 +282,8 @@ bool Resolver::ResolveInternal() {
}
}
bool result = true;
for (auto* node : builder_->ASTNodes().Objects()) {
if (marked_.count(node) == 0) {
if (node->IsAnyOf<ast::AccessDecoration, ast::StrideDecoration,
@@ -268,10 +303,11 @@ bool Resolver::ResolveInternal() {
<< "At: " << node->source() << "\n"
<< "Content: " << builder_->str(node) << "\n"
<< "Pointer: " << node;
result = false;
}
}
return true;
return result;
}
const sem::Type* Resolver::Type(const ast::Type* ty) {
@@ -360,6 +396,16 @@ const sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
}
if (auto* t = ty->As<ast::TypeName>()) {
auto it = named_types_.find(t->name());
if (it == named_types_.end()) {
diagnostics_.add_error(
"unknown type '" + builder_->Symbols().NameFor(t->name()) + "'",
t->source());
return nullptr;
}
return it->second;
}
TINT_UNREACHABLE(diagnostics_)
<< "Unhandled ast::Type: " << ty->TypeInfo().name;
return nullptr;
@@ -392,21 +438,26 @@ bool Resolver::Type(const sem::Type* ty, const Source& source /* = {} */) {
Resolver::VariableInfo* Resolver::Variable(
ast::Variable* var,
const sem::Type* type /* = nullptr*/) {
const sem::Type* type, /* = nullptr */
std::string type_name /* = "" */) {
auto it = variable_to_info_.find(var);
if (it != variable_to_info_.end()) {
return it->second;
}
if (!type) {
type = var->declared_type();
if (type == nullptr && var->type()) {
type = Type(var->type());
type_name = var->type()->FriendlyName(builder_->Symbols());
}
if (type == nullptr) {
return nullptr;
}
auto type_name = type->FriendlyName(builder_->Symbols());
auto* ctype = Canonical(type);
auto* info = variable_infos_.Create(var, ctype, type_name);
variable_to_info_.emplace(var, info);
// TODO(bclayton): Why is this here? Needed?
// Resolve variable's type
if (auto* arr = info->type->As<sem::ArrayType>()) {
if (!Array(arr, var->source())) {
@@ -805,12 +856,12 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
if (auto* struct_ty = Canonical(ty)->As<sem::StructType>()) {
// Validate the decorations for each struct members, and also check for
// invalid member types.
for (auto* member : struct_ty->impl()->members()) {
auto* member_ty = Canonical(member->type());
for (auto* member : Structure(struct_ty)->members) {
auto* member_ty = Canonical(member->Type());
if (member_ty->Is<sem::StructType>()) {
diagnostics_.add_error(
"entry point IO types cannot contain nested structures",
member->source());
member->Declaration()->source());
diagnostics_.add_note("while analysing entry point " +
builder_->Symbols().NameFor(func->symbol()),
func->source());
@@ -819,7 +870,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
if (arr->IsRuntimeArray()) {
diagnostics_.add_error(
"entry point IO types cannot contain runtime sized arrays",
member->source());
member->Declaration()->source());
diagnostics_.add_note(
"while analysing entry point " +
builder_->Symbols().NameFor(func->symbol()),
@@ -828,9 +879,9 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
}
}
if (!validate_entry_point_decorations_inner(member->decorations(),
member_ty, member->source(),
param_or_ret, true)) {
if (!validate_entry_point_decorations_inner(
member->Declaration()->decorations(), member_ty,
member->Declaration()->source(), param_or_ret, true)) {
diagnostics_.add_note("while analysing entry point " +
builder_->Symbols().NameFor(func->symbol()),
func->source());
@@ -842,10 +893,10 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
return true;
};
for (auto* param : func->params()) {
for (auto* param : info->parameters) {
if (!validate_entry_point_decorations(
param->decorations(), param->declared_type(), param->source(),
ParamOrRetType::kParameter)) {
param->declaration->decorations(), param->type,
param->declaration->source(), ParamOrRetType::kParameter)) {
return false;
}
}
@@ -943,19 +994,18 @@ bool Resolver::Function(ast::Function* func) {
}
}
if (func->return_type().ast || func->return_type().sem) {
info->return_type = func->return_type();
if (!info->return_type) {
info->return_type = Type(func->return_type().ast);
}
if (auto* ty = func->return_type()) {
info->return_type = Type(ty);
info->return_type_name = ty->FriendlyName(builder_->Symbols());
if (!info->return_type) {
return false;
}
} else {
info->return_type = builder_->create<sem::Void>();
info->return_type_name =
info->return_type->FriendlyName(builder_->Symbols());
}
info->return_type_name = info->return_type->FriendlyName(builder_->Symbols());
info->return_type = Canonical(info->return_type);
if (auto* str = info->return_type->As<sem::StructType>()) {
@@ -1374,17 +1424,16 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
SetType(expr, type_ctor->type());
const sem::Type* type = TypeOf(expr);
// Now that the argument types have been determined, make sure that they
// obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
if (auto* vec_type = type_ctor->type()->As<sem::Vector>()) {
return ValidateVectorConstructor(type_ctor, vec_type,
type_ctor->values());
if (auto* vec_type = type->As<sem::Vector>()) {
return ValidateVectorConstructor(type_ctor, vec_type);
}
if (auto* mat_type = type_ctor->type()->As<sem::Matrix>()) {
auto mat_typename = TypeNameOf(type_ctor);
return ValidateMatrixConstructor(type_ctor, mat_type,
type_ctor->values());
if (auto* mat_type = type->As<sem::Matrix>()) {
return ValidateMatrixConstructor(type_ctor, mat_type);
}
// TODO(crbug.com/tint/634): Validate array constructor
} else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
@@ -1398,8 +1447,8 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
bool Resolver::ValidateVectorConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const ast::ExpressionList& values) {
const sem::Vector* vec_type) {
auto& values = ctor->values();
auto* elem_type = vec_type->type()->UnwrapAll();
size_t value_cardinality_sum = 0;
for (auto* value : values) {
@@ -1467,8 +1516,8 @@ bool Resolver::ValidateVectorConstructor(
bool Resolver::ValidateMatrixConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const ast::ExpressionList& values) {
const sem::Matrix* matrix_type) {
auto& values = ctor->values();
// Zero Value expression
if (values.empty()) {
return true;
@@ -1600,7 +1649,7 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
const sem::StructMember* member = nullptr;
for (auto* m : str->members) {
if (m->Declaration()->symbol() == symbol) {
ret = m->Declaration()->type();
ret = m->Type();
member = m;
break;
}
@@ -1961,7 +2010,16 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
ast::Variable* var = stmt->variable();
Mark(var);
const sem::Type* type = var->declared_type();
// If the variable has a declared type, resolve it.
std::string type_name;
const sem::Type* type = nullptr;
if (auto* ast_ty = var->type()) {
type_name = ast_ty->FriendlyName(builder_->Symbols());
type = Type(ast_ty);
if (!type) {
return false;
}
}
bool is_global = false;
if (variable_stack_.get(var->symbol(), nullptr, &is_global)) {
@@ -1982,14 +2040,15 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
// If the variable has no type, infer it from the rhs
if (type == nullptr) {
type_name = TypeNameOf(ctor);
type = rhs_type->UnwrapPtrIfNeeded();
}
if (!IsValidAssignment(type, rhs_type)) {
diagnostics_.add_error(
"variable of type '" + type->FriendlyName(builder_->Symbols()) +
"variable of type '" + type_name +
"' cannot be initialized with a value of type '" +
rhs_type->FriendlyName(builder_->Symbols()) + "'",
TypeNameOf(ctor) + "'",
stmt->source());
return false;
}
@@ -2000,7 +2059,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
Mark(deco);
}
auto* info = Variable(var, type);
auto* info = Variable(var, type, type_name);
if (!info) {
return false;
}
@@ -2071,13 +2130,19 @@ const sem::Type* Resolver::TypeOf(const ast::Literal* lit) {
return nullptr;
}
void Resolver::SetType(ast::Expression* expr, const sem::Type* type) {
SetType(expr, type, type->FriendlyName(builder_->Symbols()));
void Resolver::SetType(ast::Expression* expr, typ::Type type) {
SetType(expr, type,
type.sem ? type.sem->FriendlyName(builder_->Symbols())
: type.ast->FriendlyName(builder_->Symbols()));
}
void Resolver::SetType(ast::Expression* expr,
const sem::Type* type,
typ::Type type,
const std::string& type_name) {
if (!type.sem) {
type.sem = Type(type.ast);
TINT_ASSERT(type.sem);
}
if (expr_info_.count(expr)) {
TINT_ICE(builder_->Diagnostics())
<< "SetType() called twice for the same expression";
@@ -2195,7 +2260,7 @@ void Resolver::CreateSemanticNodes() const {
}
}
bool Resolver::DefaultAlignAndSize(sem::Type* ty,
bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size,
const Source& source) {
@@ -2363,24 +2428,24 @@ bool Resolver::ValidateArrayStrideDecoration(const ast::StrideDecoration* deco,
return true;
}
bool Resolver::ValidateStructure(const sem::StructType* st) {
for (auto* member : st->impl()->members()) {
if (auto* r = member->type()->UnwrapAll()->As<sem::ArrayType>()) {
bool Resolver::ValidateStructure(const StructInfo* st) {
for (auto* member : st->members) {
if (auto* r = member->Type()->UnwrapAll()->As<sem::ArrayType>()) {
if (r->IsRuntimeArray()) {
if (member != st->impl()->members().back()) {
if (member != st->members.back()) {
diagnostics_.add_error(
"v-0015",
"runtime arrays may only appear as the last member of a struct",
member->source());
member->Declaration()->source());
return false;
}
if (!st->IsBlockDecorated()) {
if (!st->type->impl()->IsBlockDecorated()) {
diagnostics_.add_error(
"v-0015",
"a struct containing a runtime-sized array "
"requires the [[block]] attribute: '" +
builder_->Symbols().NameFor(st->impl()->name()) + "'",
member->source());
builder_->Symbols().NameFor(st->type->impl()->name()) + "'",
member->Declaration()->source());
return false;
}
@@ -2394,7 +2459,7 @@ bool Resolver::ValidateStructure(const sem::StructType* st) {
}
}
for (auto* deco : member->decorations()) {
for (auto* deco : member->Declaration()->decorations()) {
if (!(deco->Is<ast::BuiltinDecoration>() ||
deco->Is<ast::LocationDecoration>() ||
deco->Is<ast::StructMemberOffsetDecoration>() ||
@@ -2407,7 +2472,7 @@ bool Resolver::ValidateStructure(const sem::StructType* st) {
}
}
for (auto* deco : st->impl()->decorations()) {
for (auto* deco : st->type->impl()->decorations()) {
if (!(deco->Is<ast::StructBlockDecoration>())) {
diagnostics_.add_error("decoration is not valid for struct declarations",
deco->source());
@@ -2425,15 +2490,10 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
return info_it->second;
}
Mark(str->impl());
for (auto* deco : str->impl()->decorations()) {
Mark(deco);
}
if (!ValidateStructure(str)) {
return nullptr;
}
sem::StructMemberList sem_members;
sem_members.reserve(str->impl()->members().size());
@@ -2454,12 +2514,16 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
for (auto* member : str->impl()->members()) {
Mark(member);
auto type = member->type();
// Resolve member type
auto* type = Type(member->type());
if (!type) {
return nullptr;
}
// First check the member type is legal
// Validate member type
if (!IsStorable(type)) {
builder_->Diagnostics().add_error(
std::string(type->FriendlyName(builder_->Symbols())) +
type->FriendlyName(builder_->Symbols()) +
" cannot be used as the type of a structure member");
return nullptr;
}
@@ -2518,8 +2582,8 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
offset = utils::RoundUp(align, offset);
auto* sem_member =
builder_->create<sem::StructMember>(member, type, offset, align, size);
auto* sem_member = builder_->create<sem::StructMember>(
member, const_cast<sem::Type*>(type), offset, align, size);
builder_->Sem().Add(member, sem_member);
sem_members.emplace_back(sem_member);
@@ -2531,11 +2595,17 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
struct_size = utils::RoundUp(struct_align, struct_size);
auto* info = struct_infos_.Create();
info->type = str;
info->members = std::move(sem_members);
info->align = struct_align;
info->size = struct_size;
info->size_no_padding = size_no_padding;
struct_info_.emplace(str, info);
if (!ValidateStructure(info)) {
return nullptr;
}
return info;
}
@@ -2745,13 +2815,13 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
return true; // Already applied
}
info->storage_class_usage.emplace(sc);
for (auto* member : str->impl()->members()) {
if (!ApplyStorageClassUsageToType(sc, member->type(), usage)) {
for (auto* member : info->members) {
if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) {
std::stringstream err;
err << "while analysing structure member "
<< str->FriendlyName(builder_->Symbols()) << "."
<< builder_->Symbols().NameFor(member->symbol());
diagnostics_.add_note(err.str(), member->source());
<< builder_->Symbols().NameFor(member->Declaration()->symbol());
diagnostics_.add_note(err.str(), member->Declaration()->source());
return false;
}
}
@@ -2798,6 +2868,11 @@ const sem::Type* Resolver::Canonical(const sem::Type* type) {
using Type = sem::Type;
using Vector = sem::Vector;
if (!type) {
TINT_ICE(diagnostics_) << "Canonical() called with nullptr";
return nullptr;
}
std::function<const Type*(const Type*)> make_canonical;
make_canonical = [&](const Type* t) -> const sem::Type* {
// Unwrap alias sequence

View File

@@ -73,11 +73,11 @@ class Resolver {
/// @param type the given type
/// @returns true if the given type is storable
static bool IsStorable(const sem::Type* type);
bool IsStorable(const sem::Type* type);
/// @param type the given type
/// @returns true if the given type is host-shareable
static bool IsHostShareable(const sem::Type* type);
bool IsHostShareable(const sem::Type* type);
/// @param lhs the assignment store type (non-pointer)
/// @param rhs the assignment source type (non-pointer or pointer with
@@ -148,6 +148,7 @@ class Resolver {
StructInfo();
~StructInfo();
sem::StructType const* type = nullptr;
std::vector<const sem::StructMember*> members;
uint32_t align = 0;
uint32_t size = 0;
@@ -253,16 +254,14 @@ class Resolver {
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const ast::ExpressionList& values);
const sem::Matrix* matrix_type);
bool ValidateParameter(const ast::Variable* param);
bool ValidateReturn(const ast::ReturnStatement* ret);
bool ValidateStructure(const sem::StructType* st);
bool ValidateStructure(const StructInfo* st);
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const ast::Variable* param);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const ast::ExpressionList& values);
const sem::Vector* vec_type);
/// @returns the sem::Type for the ast::Type `ty`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is
@@ -284,9 +283,12 @@ class Resolver {
/// @returns the VariableInfo for the variable `var`, building it if it hasn't
/// been constructed already. If an error is raised, nullptr is returned.
/// @param var the variable to create or return the `VariableInfo` for
/// @param type optional type of `var` to use instead of
/// `var->declared_type()`. For type inference.
VariableInfo* Variable(ast::Variable* var, const sem::Type* type = nullptr);
/// @param type optional type of `var` to use instead of `var->type()`.
/// @param type_name optional type name of `var` to use instead of
/// `var->type()->FriendlyName()`.
VariableInfo* Variable(ast::Variable* var,
const sem::Type* type = nullptr,
std::string type_name = "");
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
@@ -304,7 +306,7 @@ class Resolver {
/// @param size the output default size in bytes for the type `ty`
/// @param source the Source of the variable declaration of type `ty`
/// @returns true on success, false on error
bool DefaultAlignAndSize(sem::Type* ty,
bool DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size,
const Source& source);
@@ -325,7 +327,7 @@ class Resolver {
/// assigns this semantic node to the expression `expr`.
/// @param expr the expression
/// @param type the resolved type
void SetType(ast::Expression* expr, const sem::Type* type);
void SetType(ast::Expression* expr, typ::Type type);
/// Creates a sem::Expression node with the resolved type `type`, the declared
/// type name `type_name` and assigns this semantic node to the expression
@@ -334,7 +336,7 @@ class Resolver {
/// @param type the resolved type
/// @param type_name the declared type name
void SetType(ast::Expression* expr,
const sem::Type* type,
typ::Type type,
const std::string& type_name);
/// Constructs a new BlockInfo with the given type and with #current_block_ as
@@ -369,6 +371,7 @@ class Resolver {
std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<const sem::StructType*, StructInfo*> struct_info_;
std::unordered_map<const sem::Type*, const sem::Type*> type_to_canonical_;
std::unordered_map<Symbol, const sem::Type*> named_types_;
std::unordered_set<const ast::Node*> marked_;
FunctionInfo* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;

View File

@@ -1024,15 +1024,15 @@ namespace ExprBinaryTest {
struct Params {
ast::BinaryOp op;
create_type_func_ptr create_lhs_type;
create_type_func_ptr create_rhs_type;
create_type_func_ptr create_result_type;
create_ast_type_func_ptr create_lhs_type;
create_ast_type_func_ptr create_rhs_type;
create_sem_type_func_ptr create_result_type;
};
static constexpr create_type_func_ptr all_create_type_funcs[] = {
ty_bool_, ty_u32, ty_i32, ty_f32,
ty_vec3<bool>, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<f32>,
ty_mat3x3<i32>, ty_mat3x3<u32>, ty_mat3x3<f32>};
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
ast_bool, ast_u32, ast_i32, ast_f32,
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>};
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
// matNxN, we only test N=3.
@@ -1041,156 +1041,163 @@ static constexpr Params all_valid_cases[] = {
// https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
// Binary logical expressions
Params{Op::kLogicalAnd, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kLogicalOr, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kLogicalAnd, ast_bool, ast_bool, sem_bool},
Params{Op::kLogicalOr, ast_bool, ast_bool, sem_bool},
Params{Op::kAnd, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kOr, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kAnd, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kOr, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kAnd, ast_bool, ast_bool, sem_bool},
Params{Op::kOr, ast_bool, ast_bool, sem_bool},
Params{Op::kAnd, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
Params{Op::kOr, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
// Arithmetic expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
// Binary arithmetic expressions over scalars
Params{Op::kAdd, ty_i32, ty_i32, ty_i32},
Params{Op::kSubtract, ty_i32, ty_i32, ty_i32},
Params{Op::kMultiply, ty_i32, ty_i32, ty_i32},
Params{Op::kDivide, ty_i32, ty_i32, ty_i32},
Params{Op::kModulo, ty_i32, ty_i32, ty_i32},
Params{Op::kAdd, ast_i32, ast_i32, sem_i32},
Params{Op::kSubtract, ast_i32, ast_i32, sem_i32},
Params{Op::kMultiply, ast_i32, ast_i32, sem_i32},
Params{Op::kDivide, ast_i32, ast_i32, sem_i32},
Params{Op::kModulo, ast_i32, ast_i32, sem_i32},
Params{Op::kAdd, ty_u32, ty_u32, ty_u32},
Params{Op::kSubtract, ty_u32, ty_u32, ty_u32},
Params{Op::kMultiply, ty_u32, ty_u32, ty_u32},
Params{Op::kDivide, ty_u32, ty_u32, ty_u32},
Params{Op::kModulo, ty_u32, ty_u32, ty_u32},
Params{Op::kAdd, ast_u32, ast_u32, sem_u32},
Params{Op::kSubtract, ast_u32, ast_u32, sem_u32},
Params{Op::kMultiply, ast_u32, ast_u32, sem_u32},
Params{Op::kDivide, ast_u32, ast_u32, sem_u32},
Params{Op::kModulo, ast_u32, ast_u32, sem_u32},
Params{Op::kAdd, ty_f32, ty_f32, ty_f32},
Params{Op::kSubtract, ty_f32, ty_f32, ty_f32},
Params{Op::kMultiply, ty_f32, ty_f32, ty_f32},
Params{Op::kDivide, ty_f32, ty_f32, ty_f32},
Params{Op::kModulo, ty_f32, ty_f32, ty_f32},
Params{Op::kAdd, ast_f32, ast_f32, sem_f32},
Params{Op::kSubtract, ast_f32, ast_f32, sem_f32},
Params{Op::kMultiply, ast_f32, ast_f32, sem_f32},
Params{Op::kDivide, ast_f32, ast_f32, sem_f32},
Params{Op::kModulo, ast_f32, ast_f32, sem_f32},
// Binary arithmetic expressions over vectors
Params{Op::kAdd, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kSubtract, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kMultiply, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kDivide, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kModulo, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kAdd, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kSubtract, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kMultiply, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kDivide, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kModulo, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kAdd, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kSubtract, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kMultiply, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kDivide, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kModulo, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kAdd, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kSubtract, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kMultiply, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kDivide, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kModulo, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kAdd, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kSubtract, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kDivide, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kModulo, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kAdd, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kSubtract, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
// Binary arithmetic expressions with mixed scalar, vector, and matrix
// operands
Params{Op::kMultiply, ty_vec3<f32>, ty_f32, ty_vec3<f32>},
Params{Op::kMultiply, ty_f32, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_f32, ty_mat3x3<f32>},
Params{Op::kMultiply, ty_f32, ty_mat3x3<f32>, ty_mat3x3<f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ty_vec3<f32>, ty_mat3x3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_mat3x3<f32>, ty_mat3x3<f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},
// Comparison expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
// Comparisons over scalars
Params{Op::kEqual, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kNotEqual, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kEqual, ast_bool, ast_bool, sem_bool},
Params{Op::kNotEqual, ast_bool, ast_bool, sem_bool},
Params{Op::kEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kNotEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kLessThan, ty_i32, ty_i32, ty_bool_},
Params{Op::kLessThanEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kGreaterThan, ty_i32, ty_i32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kNotEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kLessThan, ast_i32, ast_i32, sem_bool},
Params{Op::kLessThanEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kGreaterThan, ast_i32, ast_i32, sem_bool},
Params{Op::kGreaterThanEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kNotEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kLessThan, ty_u32, ty_u32, ty_bool_},
Params{Op::kLessThanEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kGreaterThan, ty_u32, ty_u32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kNotEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kLessThan, ast_u32, ast_u32, sem_bool},
Params{Op::kLessThanEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kGreaterThan, ast_u32, ast_u32, sem_bool},
Params{Op::kGreaterThanEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kNotEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kLessThan, ty_f32, ty_f32, ty_bool_},
Params{Op::kLessThanEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kGreaterThan, ty_f32, ty_f32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kEqual, ast_f32, ast_f32, sem_bool},
Params{Op::kNotEqual, ast_f32, ast_f32, sem_bool},
Params{Op::kLessThan, ast_f32, ast_f32, sem_bool},
Params{Op::kLessThanEqual, ast_f32, ast_f32, sem_bool},
Params{Op::kGreaterThan, ast_f32, ast_f32, sem_bool},
Params{Op::kGreaterThanEqual, ast_f32, ast_f32, sem_bool},
// Comparisons over vectors
Params{Op::kEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
Params{Op::kEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kLessThan, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kLessThanEqual, ast_vec3<i32>, ast_vec3<i32>,
sem_vec3<sem_bool>},
Params{Op::kGreaterThan, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kGreaterThanEqual, ast_vec3<i32>, ast_vec3<i32>,
sem_vec3<sem_bool>},
Params{Op::kEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kLessThan, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kLessThanEqual, ast_vec3<u32>, ast_vec3<u32>,
sem_vec3<sem_bool>},
Params{Op::kGreaterThan, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kGreaterThanEqual, ast_vec3<u32>, ast_vec3<u32>,
sem_vec3<sem_bool>},
Params{Op::kEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kLessThan, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kLessThanEqual, ast_vec3<f32>, ast_vec3<f32>,
sem_vec3<sem_bool>},
Params{Op::kGreaterThan, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kGreaterThanEqual, ast_vec3<f32>, ast_vec3<f32>,
sem_vec3<sem_bool>},
// Bit expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#bit-expr
// Binary bitwise operations
Params{Op::kOr, ty_i32, ty_i32, ty_i32},
Params{Op::kAnd, ty_i32, ty_i32, ty_i32},
Params{Op::kXor, ty_i32, ty_i32, ty_i32},
Params{Op::kOr, ast_i32, ast_i32, sem_i32},
Params{Op::kAnd, ast_i32, ast_i32, sem_i32},
Params{Op::kXor, ast_i32, ast_i32, sem_i32},
Params{Op::kOr, ty_u32, ty_u32, ty_u32},
Params{Op::kAnd, ty_u32, ty_u32, ty_u32},
Params{Op::kXor, ty_u32, ty_u32, ty_u32},
Params{Op::kOr, ast_u32, ast_u32, sem_u32},
Params{Op::kAnd, ast_u32, ast_u32, sem_u32},
Params{Op::kXor, ast_u32, ast_u32, sem_u32},
// Bit shift expressions
Params{Op::kShiftLeft, ty_i32, ty_u32, ty_i32},
Params{Op::kShiftLeft, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
Params{Op::kShiftLeft, ast_i32, ast_u32, sem_i32},
Params{Op::kShiftLeft, ast_vec3<i32>, ast_vec3<u32>, sem_vec3<sem_i32>},
Params{Op::kShiftLeft, ty_u32, ty_u32, ty_u32},
Params{Op::kShiftLeft, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kShiftLeft, ast_u32, ast_u32, sem_u32},
Params{Op::kShiftLeft, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kShiftRight, ty_i32, ty_u32, ty_i32},
Params{Op::kShiftRight, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
Params{Op::kShiftRight, ast_i32, ast_u32, sem_i32},
Params{Op::kShiftRight, ast_vec3<i32>, ast_vec3<u32>, sem_vec3<sem_i32>},
Params{Op::kShiftRight, ty_u32, ty_u32, ty_u32},
Params{Op::kShiftRight, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>}};
Params{Op::kShiftRight, ast_u32, ast_u32, sem_u32},
Params{Op::kShiftRight, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}};
using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
TEST_P(Expr_Binary_Test_Valid, All) {
auto& params = GetParam();
auto lhs_type = params.create_lhs_type(ty);
auto rhs_type = params.create_rhs_type(ty);
auto result_type = params.create_result_type(ty);
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = params.create_rhs_type(ty);
auto* result_type = params.create_result_type(ty);
std::stringstream ss;
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols());
ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type);
SCOPED_TRACE(ss.str());
Global("lhs", lhs_type, ast::StorageClass::kInput);
@@ -1215,27 +1222,28 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
const Params& params = std::get<0>(GetParam());
BinaryExprSide side = std::get<1>(GetParam());
auto lhs_type = params.create_lhs_type(ty);
auto rhs_type = params.create_rhs_type(ty);
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = params.create_rhs_type(ty);
std::stringstream ss;
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols());
ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type);
// For vectors and matrices, wrap the sub type in an alias
auto make_alias = [this](sem::Type* type) -> sem::Type* {
sem::Type* result;
if (auto* v = type->As<sem::Vector>()) {
result = create<sem::Vector>(
create<sem::Alias>(Symbols().New(), v->type()), v->size());
} else if (auto* m = type->As<sem::Matrix>()) {
result =
create<sem::Matrix>(create<sem::Alias>(Symbols().New(), m->type()),
m->rows(), m->columns());
} else {
result = create<sem::Alias>(Symbols().New(), type);
auto make_alias = [this](ast::Type* type) -> ast::Type* {
if (auto* v = type->As<ast::Vector>()) {
auto alias = ty.alias(Symbols().New(), v->type());
AST().AddConstructedType(alias);
return ty.vec(alias, v->size());
}
return result;
if (auto* m = type->As<ast::Matrix>()) {
auto alias = ty.alias(Symbols().New(), m->type());
AST().AddConstructedType(alias);
return ty.mat(alias, m->columns(), m->rows());
}
auto alias = ty.alias(Symbols().New(), type);
AST().AddConstructedType(alias);
return ty.type_name(alias.ast->name());
};
// Wrap in alias
@@ -1246,8 +1254,8 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
rhs_type = make_alias(rhs_type);
}
ss << ", After aliasing: " << lhs_type->FriendlyName(Symbols()) << " "
<< params.op << " " << rhs_type->FriendlyName(Symbols());
ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op
<< " " << FriendlyName(rhs_type);
SCOPED_TRACE(ss.str());
Global("lhs", lhs_type, ast::StorageClass::kInput);
@@ -1261,7 +1269,7 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
ASSERT_NE(TypeOf(expr), nullptr);
// TODO(amaiorano): Bring this back once we have a way to get the canonical
// type
// auto* result_type = params.create_result_type(ty);
// auto* *result_type = params.create_result_type(ty);
// ASSERT_TRUE(TypeOf(expr) == result_type);
}
INSTANTIATE_TEST_SUITE_P(
@@ -1273,10 +1281,10 @@ INSTANTIATE_TEST_SUITE_P(
BinaryExprSide::Both)));
using Expr_Binary_Test_Invalid =
ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
ResolverTestWithParam<std::tuple<Params, create_ast_type_func_ptr>>;
TEST_P(Expr_Binary_Test_Invalid, All) {
const Params& params = std::get<0>(GetParam());
const create_type_func_ptr& create_type_func = std::get<1>(GetParam());
auto& create_type_func = std::get<1>(GetParam());
// Currently, for most operations, for a given lhs type, there is exactly one
// rhs type allowed. The only exception is for multiplication, which allows
@@ -1290,8 +1298,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
return;
}
auto lhs_type = params.create_lhs_type(ty);
auto rhs_type = create_type_func(ty);
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = create_type_func(ty);
// Skip exceptions: multiplication of f32, vecN<f32>, and matNxN<f32>
if (params.op == Op::kMultiply &&
@@ -1301,8 +1309,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
}
std::stringstream ss;
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols());
ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type);
SCOPED_TRACE(ss.str());
Global("lhs", lhs_type, ast::StorageClass::kInput);
@@ -1316,9 +1324,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op()) +
" " + FriendlyName(rhs_type));
}
INSTANTIATE_TEST_SUITE_P(
ResolverTest,
@@ -1365,9 +1372,8 @@ TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op()) +
" " + FriendlyName(rhs_type));
}
}
auto all_dimension_values = testing::Values(2u, 3u, 4u);
@@ -1405,9 +1411,8 @@ TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op()) +
" " + FriendlyName(rhs_type));
}
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,

View File

@@ -16,6 +16,7 @@
#define SRC_RESOLVER_RESOLVER_TEST_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@@ -95,6 +96,14 @@ class TestHelper : public ProgramBuilder {
return true;
}
/// @param type a type
/// @returns the name for `type` that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName(typ::Type type) {
return type.ast ? type.ast->FriendlyName(Symbols())
: type.sem->FriendlyName(Symbols());
}
private:
std::unique_ptr<Resolver> resolver_;
};
@@ -105,94 +114,151 @@ template <typename T>
class ResolverTestWithParam : public TestHelper,
public testing::TestWithParam<T> {};
inline typ::Type ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_bool(const ProgramBuilder::TypesBuilder& ty) {
return ty.bool_();
}
inline typ::Type ty_i32(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_i32(const ProgramBuilder::TypesBuilder& ty) {
return ty.i32();
}
inline typ::Type ty_u32(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.u32();
}
inline typ::Type ty_f32(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.f32();
}
using create_type_func_ptr =
typ::Type (*)(const ProgramBuilder::TypesBuilder& ty);
using create_ast_type_func_ptr =
ast::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <typename T>
typ::Type ty_vec2(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec2<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_vec2(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec2(create_type(ty));
}
template <typename T>
typ::Type ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3(create_type(ty));
}
template <typename T>
typ::Type ty_vec4(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec4<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_vec4(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec4(create_type(ty));
}
template <typename T>
typ::Type ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x2<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x2(create_type(ty));
}
template <typename T>
typ::Type ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3(create_type(ty));
}
template <typename T>
typ::Type ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat4x4<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat4x4(create_type(ty));
}
template <create_type_func_ptr create_type>
typ::Type ty_alias(const ProgramBuilder::TypesBuilder& ty) {
auto type = create_type(ty);
return ty.alias("alias_" + type->type_name(), type);
template <create_ast_type_func_ptr create_type>
ast::Type* ast_alias(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
auto name = ty.builder->Symbols().Register("alias_" + type->type_name());
if (!ty.builder->AST().LookupType(name)) {
ty.builder->AST().AddConstructedType(ty.alias(name, type));
}
return ty.builder->create<ast::TypeName>(name);
}
template <create_type_func_ptr create_type>
typ::Type ty_access(const ProgramBuilder::TypesBuilder& ty) {
auto type = create_type(ty);
template <create_ast_type_func_ptr create_type>
ast::Type* ast_access(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.access(ast::AccessControl::kReadOnly, type);
}
inline sem::Type* sem_bool(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Bool>();
}
inline sem::Type* sem_i32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::I32>();
}
inline sem::Type* sem_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::U32>();
}
inline sem::Type* sem_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::F32>();
}
using create_sem_type_func_ptr =
sem::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <create_sem_type_func_ptr create_type>
sem::Type* sem_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Vector>(create_type(ty), 2);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Vector>(create_type(ty), 3);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Vector>(create_type(ty), 4);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 2, 2);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 3, 3);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 4, 4);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_access(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.builder->create<sem::AccessControl>(ast::AccessControl::kReadOnly,
type);
}
} // namespace resolver
} // namespace tint

View File

@@ -19,11 +19,11 @@ namespace resolver {
namespace {
/// @return the element type of `type` for vec and mat, otherwise `type` itself
sem::Type* ElementTypeOf(sem::Type* type) {
if (auto* v = type->As<sem::Vector>()) {
ast::Type* ElementTypeOf(ast::Type* type) {
if (auto* v = type->As<ast::Vector>()) {
return v->type();
}
if (auto* m = type->As<sem::Matrix>()) {
if (auto* m = type->As<ast::Matrix>()) {
return m->type();
}
return type;
@@ -34,7 +34,8 @@ class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
namespace InferTypeTest {
struct Params {
create_type_func_ptr create_rhs_type;
create_ast_type_func_ptr create_rhs_ast_type;
create_sem_type_func_ptr create_rhs_sem_type;
};
// Helpers and typedefs
@@ -66,7 +67,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
// }
auto& params = GetParam();
auto rhs_type = params.create_rhs_type(ty);
auto* rhs_type = params.create_rhs_ast_type(ty);
auto* constructor_expr = ConstructValueFilledWith(rhs_type, 0);
auto sc = ast::StorageClass::kFunction;
@@ -77,30 +78,33 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type->UnwrapAliasIfNeeded(), sc));
auto* got = TypeOf(a_ident);
auto* expected = ty.pointer(params.create_rhs_sem_type(ty), sc).sem;
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
static constexpr Params from_constructor_expression_cases[] = {
Params{ty_bool_},
Params{ty_i32},
Params{ty_u32},
Params{ty_f32},
Params{ty_vec3<i32>},
Params{ty_vec3<u32>},
Params{ty_vec3<f32>},
Params{ty_mat3x3<i32>},
Params{ty_mat3x3<u32>},
Params{ty_mat3x3<f32>},
Params{ty_alias<ty_bool_>},
Params{ty_alias<ty_i32>},
Params{ty_alias<ty_u32>},
Params{ty_alias<ty_f32>},
Params{ty_alias<ty_vec3<i32>>},
Params{ty_alias<ty_vec3<u32>>},
Params{ty_alias<ty_vec3<f32>>},
Params{ty_alias<ty_mat3x3<i32>>},
Params{ty_alias<ty_mat3x3<u32>>},
Params{ty_alias<ty_mat3x3<f32>>},
Params{ast_bool, sem_bool},
Params{ast_i32, sem_i32},
Params{ast_u32, sem_u32},
Params{ast_f32, sem_f32},
Params{ast_vec3<i32>, sem_vec3<sem_i32>},
Params{ast_vec3<u32>, sem_vec3<sem_u32>},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<i32>, sem_mat3x3<sem_i32>},
Params{ast_mat3x3<u32>, sem_mat3x3<sem_u32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_bool>, sem_bool},
Params{ast_alias<ast_i32>, sem_i32},
Params{ast_alias<ast_u32>, sem_u32},
Params{ast_alias<ast_f32>, sem_f32},
Params{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>},
Params{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>},
Params{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>},
Params{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>},
Params{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>},
};
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromConstructorExpression,
@@ -114,7 +118,7 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) {
// }
auto& params = GetParam();
auto rhs_type = params.create_rhs_type(ty);
auto* rhs_type = params.create_rhs_ast_type(ty);
auto* arith_lhs_expr = ConstructValueFilledWith(rhs_type, 2);
auto* arith_rhs_expr = ConstructValueFilledWith(ElementTypeOf(rhs_type), 3);
@@ -128,11 +132,17 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
auto* got = TypeOf(a_ident);
auto* expected = ty.pointer(params.create_rhs_sem_type(ty), sc).sem;
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
static constexpr Params from_arithmetic_expression_cases[] = {
Params{ty_i32}, Params{ty_u32}, Params{ty_f32},
Params{ty_vec3<f32>}, Params{ty_mat3x3<f32>},
Params{ast_i32, sem_i32},
Params{ast_u32, sem_u32},
Params{ast_f32, sem_f32},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
// TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed
// Params{ty_alias<ty_i32>},
@@ -159,43 +169,44 @@ TEST_P(InferTypeTest_FromCallExpression, All) {
// }
auto& params = GetParam();
auto rhs_type = params.create_rhs_type(ty);
Func("foo", {}, rhs_type, {Return(ConstructValueFilledWith(rhs_type, 0))},
Func("foo", {}, params.create_rhs_ast_type(ty),
{Return(ConstructValueFilledWith(params.create_rhs_ast_type(ty), 0))},
{});
auto* constructor_expr = Call(Expr("foo"));
auto sc = ast::StorageClass::kFunction;
auto* a = Var("a", nullptr, sc, constructor_expr);
auto* a = Var("a", nullptr, sc, Call(Expr("foo")));
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
auto* a_ident = Expr("a");
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type->UnwrapAliasIfNeeded(), sc));
auto* got = TypeOf(a_ident);
auto* expected = ty.pointer(params.create_rhs_sem_type(ty), sc).sem;
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
static constexpr Params from_call_expression_cases[] = {
Params{ty_bool_},
Params{ty_i32},
Params{ty_u32},
Params{ty_f32},
Params{ty_vec3<i32>},
Params{ty_vec3<u32>},
Params{ty_vec3<f32>},
Params{ty_mat3x3<i32>},
Params{ty_mat3x3<u32>},
Params{ty_mat3x3<f32>},
Params{ty_alias<ty_bool_>},
Params{ty_alias<ty_i32>},
Params{ty_alias<ty_u32>},
Params{ty_alias<ty_f32>},
Params{ty_alias<ty_vec3<i32>>},
Params{ty_alias<ty_vec3<u32>>},
Params{ty_alias<ty_vec3<f32>>},
Params{ty_alias<ty_mat3x3<i32>>},
Params{ty_alias<ty_mat3x3<u32>>},
Params{ty_alias<ty_mat3x3<f32>>},
Params{ast_bool, sem_bool},
Params{ast_i32, sem_i32},
Params{ast_u32, sem_u32},
Params{ast_f32, sem_f32},
Params{ast_vec3<i32>, sem_vec3<sem_i32>},
Params{ast_vec3<u32>, sem_vec3<sem_u32>},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<i32>, sem_mat3x3<sem_i32>},
Params{ast_mat3x3<u32>, sem_mat3x3<sem_u32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_bool>, sem_bool},
Params{ast_alias<ast_i32>, sem_i32},
Params{ast_alias<ast_u32>, sem_u32},
Params{ast_alias<ast_f32>, sem_f32},
Params{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>},
Params{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>},
Params{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>},
Params{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>},
Params{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>},
};
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromCallExpression,

View File

@@ -445,48 +445,57 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
namespace GetCanonicalTests {
struct Params {
create_type_func_ptr create_type;
create_type_func_ptr create_canonical_type;
create_ast_type_func_ptr create_ast_type;
create_sem_type_func_ptr create_sem_type;
};
static constexpr Params cases[] = {
Params{ty_bool_, ty_bool_},
Params{ty_alias<ty_bool_>, ty_bool_},
Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
Params{ast_bool, sem_bool},
Params{ast_alias<ast_bool>, sem_bool},
Params{ast_alias<ast_alias<ast_bool>>, sem_bool},
Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
Params{ast_vec3<ast_f32>, sem_vec3<sem_f32>},
Params{ast_alias<ast_vec3<ast_f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_alias<ast_vec3<ast_f32>>>, sem_vec3<sem_f32>},
Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
ty_vec3<ty_f32>},
Params{ast_vec3<ast_alias<ast_f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_vec3<ast_alias<ast_f32>>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_alias<ast_vec3<ast_alias<ast_f32>>>>,
sem_vec3<sem_f32>},
Params{ast_alias<ast_alias<ast_vec3<ast_alias<ast_alias<ast_f32>>>>>,
sem_vec3<sem_f32>},
Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
ty_mat3x3<ty_f32>},
Params{ast_mat3x3<ast_alias<ast_f32>>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_mat3x3<ast_alias<ast_f32>>>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_alias<ast_mat3x3<ast_alias<ast_f32>>>>,
sem_mat3x3<sem_f32>},
Params{ast_alias<ast_alias<ast_mat3x3<ast_alias<ast_alias<ast_f32>>>>>,
sem_mat3x3<sem_f32>},
Params{ty_alias<ty_access<ty_alias<ty_bool_>>>, ty_access<ty_bool_>},
Params{ty_alias<ty_access<ty_alias<ty_vec3<ty_access<ty_f32>>>>>,
ty_access<ty_vec3<ty_access<ty_f32>>>},
Params{ty_alias<ty_access<ty_alias<ty_mat3x3<ty_access<ty_f32>>>>>,
ty_access<ty_mat3x3<ty_access<ty_f32>>>},
Params{ast_alias<ast_access<ast_alias<ast_bool>>>, sem_access<sem_bool>},
Params{ast_alias<ast_access<ast_alias<ast_vec3<ast_access<ast_f32>>>>>,
sem_access<sem_vec3<sem_access<sem_f32>>>},
Params{ast_alias<ast_access<ast_alias<ast_mat3x3<ast_access<ast_f32>>>>>,
sem_access<sem_mat3x3<sem_access<sem_f32>>>},
};
using CanonicalTest = ResolverTestWithParam<Params>;
TEST_P(CanonicalTest, All) {
auto& params = GetParam();
auto type = params.create_type(ty);
auto expected_canonical_type = params.create_canonical_type(ty);
auto* type = params.create_ast_type(ty);
auto* canonical_type = r()->Canonical(type);
auto* var = Var("v", type, ast::StorageClass::kFunction);
auto* expr = Expr("v");
WrapInFunction(var, expr);
EXPECT_EQ(canonical_type, expected_canonical_type);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(expr)->UnwrapPtrIfNeeded();
auto* expected = params.create_sem_type(ty);
EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
CanonicalTest,
@@ -529,26 +538,26 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
testing::ValuesIn(dimension_cases));
struct TypeParams {
create_type_func_ptr type_func;
create_ast_type_func_ptr type_func;
bool is_valid;
};
static constexpr TypeParams type_cases[] = {
TypeParams{ty_bool_, false},
TypeParams{ty_i32, true},
TypeParams{ty_u32, true},
TypeParams{ty_f32, true},
TypeParams{ast_bool, false},
TypeParams{ast_i32, true},
TypeParams{ast_u32, true},
TypeParams{ast_f32, true},
TypeParams{ty_alias<ty_bool_>, false},
TypeParams{ty_alias<ty_i32>, true},
TypeParams{ty_alias<ty_u32>, true},
TypeParams{ty_alias<ty_f32>, true},
TypeParams{ast_alias<ast_bool>, false},
TypeParams{ast_alias<ast_i32>, true},
TypeParams{ast_alias<ast_u32>, true},
TypeParams{ast_alias<ast_f32>, true},
TypeParams{ty_vec3<ty_f32>, false},
TypeParams{ty_mat3x3<ty_f32>, false},
TypeParams{ast_vec3<ast_f32>, false},
TypeParams{ast_mat3x3<ast_f32>, false},
TypeParams{ty_alias<ty_vec3<ty_f32>>, false},
TypeParams{ty_alias<ty_mat3x3<ty_f32>>, false}};
TypeParams{ast_alias<ast_vec3<ast_f32>>, false},
TypeParams{ast_alias<ast_mat3x3<ast_f32>>, false}};
using MultisampledTextureTypeTest = ResolverTestWithParam<TypeParams>;
TEST_P(MultisampledTextureTypeTest, All) {

View File

@@ -2041,9 +2041,10 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
TEST_F(ResolverValidationTest, Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
auto alias = ty.alias("VectorUnsigned2", ty.vec2<u32>());
AST().AddConstructedType(alias);
auto* tc = mat2x2<f32>(create<ast::TypeConstructorExpression>(
Source{{12, 34}}, alias, ExprList()),
vec2<f32>());
auto* tc = mat2x2<f32>(
create<ast::TypeConstructorExpression>(
Source{{12, 34}}, ty.MaybeCreateTypename(alias), ExprList()),
vec2<f32>());
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
@@ -2062,7 +2063,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_alias, ExprList()));
Source{{12, i}}, ty.MaybeCreateTypename(vec_alias), ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,