[tint] Move validation code into a Validator class.

This CL moves the Validate methods from the Resolver into a specific
Validator class used by the Resolver.

Bug: tint:1313
Change-Id: Ida21a0cc65f2679739c8499de7065ff8b58c4efc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/87150
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-04-21 13:40:16 +00:00 committed by Dawn LUCI CQ
parent 4091e0fa9c
commit f05575bb21
12 changed files with 1066 additions and 449 deletions

View File

@ -376,9 +376,10 @@ libtint_source_set("libtint_core_all_src") {
"resolver/resolver.cc", "resolver/resolver.cc",
"resolver/resolver.h", "resolver/resolver.h",
"resolver/resolver_constants.cc", "resolver/resolver_constants.cc",
"resolver/resolver_validation.cc",
"resolver/sem_helper.cc", "resolver/sem_helper.cc",
"resolver/sem_helper.h", "resolver/sem_helper.h",
"resolver/validator.cc",
"resolver/validator.h",
"scope_stack.h", "scope_stack.h",
"sem/array.h", "sem/array.h",
"sem/atomic_type.h", "sem/atomic_type.h",

View File

@ -256,10 +256,11 @@ set(TINT_LIB_SRCS
resolver/dependency_graph.h resolver/dependency_graph.h
resolver/resolver.cc resolver/resolver.cc
resolver/resolver_constants.cc resolver/resolver_constants.cc
resolver/resolver_validation.cc
resolver/resolver.h resolver/resolver.h
resolver/sem_helper.cc resolver/sem_helper.cc
resolver/sem_helper.h resolver/sem_helper.h
resolver/validator.cc
resolver/validator.h
scope_stack.h scope_stack.h
sem/array.cc sem/array.cc
sem/array.h sem/array.h

View File

@ -85,7 +85,8 @@ Resolver::Resolver(ProgramBuilder* builder)
: builder_(builder), : builder_(builder),
diagnostics_(builder->Diagnostics()), diagnostics_(builder->Diagnostics()),
builtin_table_(BuiltinTable::Create(*builder)), builtin_table_(BuiltinTable::Create(*builder)),
sem_(builder) {} sem_(builder, dependencies_),
validator_(builder, sem_) {}
Resolver::~Resolver() = default; Resolver::~Resolver() = default;
@ -138,7 +139,7 @@ bool Resolver::ResolveInternal() {
SetShadows(); SetShadows();
if (!ValidatePipelineStages()) { if (!validator_.PipelineStages(entry_points_)) {
return false; return false;
} }
@ -172,7 +173,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
} }
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
if (auto* vector = builder_->create<sem::Vector>(el, t->width)) { if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
if (ValidateVector(vector, t->source)) { if (validator_.Vector(vector, t->source)) {
return vector; return vector;
} }
} }
@ -188,7 +189,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) { if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
if (auto* matrix = if (auto* matrix =
builder_->create<sem::Matrix>(column_type, t->columns)) { builder_->create<sem::Matrix>(column_type, t->columns)) {
if (ValidateMatrix(matrix, t->source)) { if (validator_.Matrix(matrix, t->source)) {
return matrix; return matrix;
} }
} }
@ -200,7 +201,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
[&](const ast::Atomic* t) -> sem::Atomic* { [&](const ast::Atomic* t) -> sem::Atomic* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
auto* a = builder_->create<sem::Atomic>(el); auto* a = builder_->create<sem::Atomic>(el);
if (!ValidateAtomic(t, a)) { if (!validator_.Atomic(t, a)) {
return nullptr; return nullptr;
} }
return a; return a;
@ -240,7 +241,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}, },
[&](const ast::StorageTexture* t) -> sem::StorageTexture* { [&](const ast::StorageTexture* t) -> sem::StorageTexture* {
if (auto* el = Type(t->type)) { if (auto* el = Type(t->type)) {
if (!ValidateStorageTexture(t)) { if (!validator_.StorageTexture(t)) {
return nullptr; return nullptr;
} }
return builder_->create<sem::StorageTexture>(t->dim, t->format, return builder_->create<sem::StorageTexture>(t->dim, t->format,
@ -252,7 +253,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
return builder_->create<sem::ExternalTexture>(); return builder_->create<sem::ExternalTexture>();
}, },
[&](Default) { [&](Default) {
auto* resolved = ResolvedSymbol(ty); auto* resolved = sem_.ResolvedSymbol(ty);
return Switch( return Switch(
resolved, // resolved, //
[&](sem::Type* type) { return type; }, [&](sem::Type* type) { return type; },
@ -366,8 +367,8 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
if (kind == VariableKind::kLocal && !var->is_const && if (kind == VariableKind::kLocal && !var->is_const &&
storage_class != ast::StorageClass::kFunction && storage_class != ast::StorageClass::kFunction &&
IsValidationEnabled(var->attributes, validator_.IsValidationEnabled(
ast::DisabledValidation::kIgnoreStorageClass)) { var->attributes, ast::DisabledValidation::kIgnoreStorageClass)) {
AddError("function variable has a non-function storage class", var->source); AddError("function variable has a non-function storage class", var->source);
return nullptr; return nullptr;
} }
@ -385,8 +386,8 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
builder_->create<sem::Reference>(storage_ty, storage_class, access); builder_->create<sem::Reference>(storage_ty, storage_class, access);
} }
if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty, if (rhs && !validator_.VariableConstructorOrCast(var, storage_class,
rhs->Type())) { storage_ty, rhs->Type())) {
return nullptr; return nullptr;
} }
@ -547,17 +548,17 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) {
} }
} }
if (!ValidateNoDuplicateAttributes(var->attributes)) { if (!validator_.NoDuplicateAttributes(var->attributes)) {
return nullptr; return nullptr;
} }
if (!ValidateGlobalVariable(sem)) { if (!validator_.GlobalVariable(sem, constant_ids_, atomic_composite_info_)) {
return nullptr; return nullptr;
} }
// TODO(bclayton): Call this at the end of resolve on all uniform and storage // TODO(bclayton): Call this at the end of resolve on all uniform and storage
// referenced structs // referenced structs
if (!ValidateStorageClassLayout(sem, valid_type_storage_layouts_)) { if (!validator_.StorageClassLayout(sem, valid_type_storage_layouts_)) {
return nullptr; return nullptr;
} }
@ -592,7 +593,7 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
for (auto* attr : param->attributes) { for (auto* attr : param->attributes) {
Mark(attr); Mark(attr);
} }
if (!ValidateNoDuplicateAttributes(param->attributes)) { if (!validator_.NoDuplicateAttributes(param->attributes)) {
return nullptr; return nullptr;
} }
@ -691,21 +692,21 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
for (auto* attr : decl->attributes) { for (auto* attr : decl->attributes) {
Mark(attr); Mark(attr);
} }
if (!ValidateNoDuplicateAttributes(decl->attributes)) { if (!validator_.NoDuplicateAttributes(decl->attributes)) {
return nullptr; return nullptr;
} }
for (auto* attr : decl->return_type_attributes) { for (auto* attr : decl->return_type_attributes) {
Mark(attr); Mark(attr);
} }
if (!ValidateNoDuplicateAttributes(decl->return_type_attributes)) { if (!validator_.NoDuplicateAttributes(decl->return_type_attributes)) {
return nullptr; return nullptr;
} }
auto stage = current_function_ auto stage = current_function_
? current_function_->Declaration()->PipelineStage() ? current_function_->Declaration()->PipelineStage()
: ast::PipelineStage::kNone; : ast::PipelineStage::kNone;
if (!ValidateFunction(func, stage)) { if (!validator_.Function(func, stage)) {
return nullptr; return nullptr;
} }
@ -809,7 +810,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
<< "could not resolve constant workgroup_size constant value"; << "could not resolve constant workgroup_size constant value";
continue; continue;
} }
// Validate and set the default value for this dimension. // validator_.Validate and set the default value for this dimension.
if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) { if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source); AddError("workgroup_size argument must be at least 1", values[i]->source);
return false; return false;
@ -843,7 +844,7 @@ bool Resolver::Statements(const ast::StatementList& stmts) {
current_statement_->Behaviors() = behaviors; current_statement_->Behaviors() = behaviors;
if (!ValidateStatements(stmts)) { if (!validator_.Statements(stmts)) {
return false; return false;
} }
@ -958,7 +959,7 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
sem->Behaviors().Add(sem::Behavior::kNext); sem->Behaviors().Add(sem::Behavior::kNext);
} }
return ValidateIfStatement(sem); return validator_.IfStatement(sem);
}); });
} }
@ -989,7 +990,7 @@ sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
} }
sem->Behaviors().Add(body->Behaviors()); sem->Behaviors().Add(body->Behaviors());
return ValidateElseStatement(sem); return validator_.ElseStatement(sem);
}); });
} }
@ -1039,7 +1040,7 @@ sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
} }
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue); behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
return ValidateLoopStatement(sem); return validator_.LoopStatement(sem);
}); });
}); });
} }
@ -1095,7 +1096,7 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(
} }
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue); behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
return ValidateForLoopStatement(sem); return validator_.ForLoopStatement(sem);
}); });
} }
@ -1226,7 +1227,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
sem->Behaviors() = inner->Behaviors(); sem->Behaviors() = inner->Behaviors();
if (!ValidateBitcast(expr, ty)) { if (!validator_.Bitcast(expr, ty)) {
return nullptr; return nullptr;
} }
@ -1316,7 +1317,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
Mark(vec); Mark(vec);
auto* v = builder_->create<sem::Vector>( auto* v = builder_->create<sem::Vector>(
arg_el_ty, static_cast<uint32_t>(vec->width)); arg_el_ty, static_cast<uint32_t>(vec->width));
if (!ValidateVector(v, vec->source)) { if (!validator_.Vector(v, vec->source)) {
return nullptr; return nullptr;
} }
builder_->Sem().Add(vec, v); builder_->Sem().Add(vec, v);
@ -1337,7 +1338,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
auto* column_type = auto* column_type =
builder_->create<sem::Vector>(arg_el_ty, mat->rows); builder_->create<sem::Vector>(arg_el_ty, mat->rows);
auto* m = builder_->create<sem::Matrix>(column_type, mat->columns); auto* m = builder_->create<sem::Matrix>(column_type, mat->columns);
if (!ValidateMatrix(m, mat->source)) { if (!validator_.Matrix(m, mat->source)) {
return nullptr; return nullptr;
} }
builder_->Sem().Add(mat, m); builder_->Sem().Add(mat, m);
@ -1359,7 +1360,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
auto* ident = expr->target.name; auto* ident = expr->target.name;
Mark(ident); Mark(ident);
auto* resolved = ResolvedSymbol(ident); auto* resolved = sem_.ResolvedSymbol(ident);
return Switch( return Switch(
resolved, // resolved, //
[&](sem::Type* type) { return type_ctor_or_conv(type); }, [&](sem::Type* type) { return type_ctor_or_conv(type); },
@ -1414,7 +1415,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
current_function_->AddDirectlyCalledBuiltin(builtin); current_function_->AddDirectlyCalledBuiltin(builtin);
if (IsTextureBuiltin(builtin_type)) { if (IsTextureBuiltin(builtin_type)) {
if (!ValidateTextureBuiltinFunction(call)) { if (!validator_.TextureBuiltinFunction(call)) {
return nullptr; return nullptr;
} }
// Collect a texture/sampler pair for this builtin. // Collect a texture/sampler pair for this builtin.
@ -1436,7 +1437,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
} }
} }
if (!ValidateBuiltinCall(call)) { if (!validator_.BuiltinCall(call)) {
return nullptr; return nullptr;
} }
@ -1500,7 +1501,7 @@ sem::Call* Resolver::FunctionCall(
call->Behaviors() = arg_behaviors + target->Behaviors(); call->Behaviors() = arg_behaviors + target->Behaviors();
if (!ValidateFunctionCall(call)) { if (!validator_.FunctionCall(call, current_statement_)) {
return nullptr; return nullptr;
} }
@ -1527,23 +1528,23 @@ sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
bool ok = Switch( bool ok = Switch(
target, target,
[&](const sem::Vector* vec_type) { [&](const sem::Vector* vec_type) {
return ValidateVectorConstructorOrCast(expr, vec_type); return validator_.VectorConstructorOrCast(expr, vec_type);
}, },
[&](const sem::Matrix* mat_type) { [&](const sem::Matrix* mat_type) {
// Note: Matrix types currently cannot be converted (the element // Note: Matrix types currently cannot be converted (the element
// type must only be f32). We implement this for the day we // type must only be f32). We implement this for the day we
// support other matrix element types. // support other matrix element types.
return ValidateMatrixConstructorOrCast(expr, mat_type); return validator_.MatrixConstructorOrCast(expr, mat_type);
}, },
[&](const sem::Array* arr_type) { [&](const sem::Array* arr_type) {
return ValidateArrayConstructorOrCast(expr, arr_type); return validator_.ArrayConstructorOrCast(expr, arr_type);
}, },
[&](const sem::Struct* struct_type) { [&](const sem::Struct* struct_type) {
return ValidateStructureConstructorOrCast(expr, struct_type); return validator_.StructureConstructorOrCast(expr, struct_type);
}, },
[&](Default) { [&](Default) {
if (target->is_scalar()) { if (target->is_scalar()) {
return ValidateScalarConstructorOrCast(expr, target); return validator_.ScalarConstructorOrCast(expr, target);
} }
AddError("type is not constructible", expr->source); AddError("type is not constructible", expr->source);
return false; return false;
@ -1593,20 +1594,20 @@ sem::Call* Resolver::TypeConstructor(
bool ok = Switch( bool ok = Switch(
ty, ty,
[&](const sem::Vector* vec_type) { [&](const sem::Vector* vec_type) {
return ValidateVectorConstructorOrCast(expr, vec_type); return validator_.VectorConstructorOrCast(expr, vec_type);
}, },
[&](const sem::Matrix* mat_type) { [&](const sem::Matrix* mat_type) {
return ValidateMatrixConstructorOrCast(expr, mat_type); return validator_.MatrixConstructorOrCast(expr, mat_type);
}, },
[&](const sem::Array* arr_type) { [&](const sem::Array* arr_type) {
return ValidateArrayConstructorOrCast(expr, arr_type); return validator_.ArrayConstructorOrCast(expr, arr_type);
}, },
[&](const sem::Struct* struct_type) { [&](const sem::Struct* struct_type) {
return ValidateStructureConstructorOrCast(expr, struct_type); return validator_.StructureConstructorOrCast(expr, struct_type);
}, },
[&](Default) { [&](Default) {
if (ty->is_scalar()) { if (ty->is_scalar()) {
return ValidateScalarConstructorOrCast(expr, ty); return validator_.ScalarConstructorOrCast(expr, ty);
} }
AddError("type is not constructible", expr->source); AddError("type is not constructible", expr->source);
return false; return false;
@ -1652,7 +1653,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
auto symbol = expr->symbol; auto symbol = expr->symbol;
auto* resolved = ResolvedSymbol(expr); auto* resolved = sem_.ResolvedSymbol(expr);
if (auto* var = As<sem::Variable>(resolved)) { if (auto* var = As<sem::Variable>(resolved)) {
auto* user = auto* user =
builder_->create<sem::VariableUser>(expr, current_statement_, var); builder_->create<sem::VariableUser>(expr, current_statement_, var);
@ -2156,7 +2157,8 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr; return nullptr;
} }
if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize() if (!validator_.IsPlain(
elem_type)) { // Check must come before GetDefaultAlignAndSize()
AddError(sem_.TypeNameOf(elem_type) + AddError(sem_.TypeNameOf(elem_type) +
" cannot be used as an element type of an array", " cannot be used as an element type of an array",
source); source);
@ -2166,7 +2168,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
uint32_t el_align = elem_type->Align(); uint32_t el_align = elem_type->Align();
uint32_t el_size = elem_type->Size(); uint32_t el_size = elem_type->Size();
if (!ValidateNoDuplicateAttributes(arr->attributes)) { if (!validator_.NoDuplicateAttributes(arr->attributes)) {
return nullptr; return nullptr;
} }
@ -2176,7 +2178,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
Mark(attr); Mark(attr);
if (auto* sd = attr->As<ast::StrideAttribute>()) { if (auto* sd = attr->As<ast::StrideAttribute>()) {
explicit_stride = sd->stride; explicit_stride = sd->stride;
if (!ValidateArrayStrideAttribute(sd, el_size, el_align, source)) { if (!validator_.ArrayStrideAttribute(sd, el_size, el_align, source)) {
return nullptr; return nullptr;
} }
continue; continue;
@ -2210,7 +2212,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) { if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant. // Make sure the identifier is a non-overridable module-scope constant.
auto* var = ResolvedSymbol<sem::GlobalVariable>(ident); auto* var = sem_.ResolvedSymbol<sem::GlobalVariable>(ident);
if (!var || !var->Declaration()->is_const) { if (!var || !var->Declaration()->is_const) {
AddError("array size identifier must be a module-scope constant", AddError("array size identifier must be a module-scope constant",
size_source); size_source);
@ -2266,7 +2268,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
elem_type, count, el_align, static_cast<uint32_t>(size), elem_type, count, el_align, static_cast<uint32_t>(size),
static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride)); static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
if (!ValidateArray(out, source)) { if (!validator_.Array(out, source)) {
return nullptr; return nullptr;
} }
@ -2287,14 +2289,14 @@ sem::Type* Resolver::Alias(const ast::Alias* alias) {
if (!ty) { if (!ty) {
return nullptr; return nullptr;
} }
if (!ValidateAlias(alias)) { if (!validator_.Alias(alias)) {
return nullptr; return nullptr;
} }
return ty; return ty;
} }
sem::Struct* Resolver::Structure(const ast::Struct* str) { sem::Struct* Resolver::Structure(const ast::Struct* str) {
if (!ValidateNoDuplicateAttributes(str->attributes)) { if (!validator_.NoDuplicateAttributes(str->attributes)) {
return nullptr; return nullptr;
} }
for (auto* attr : str->attributes) { for (auto* attr : str->attributes) {
@ -2335,8 +2337,8 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
return nullptr; return nullptr;
} }
// Validate member type // validator_.Validate member type
if (!IsPlain(type)) { if (!validator_.IsPlain(type)) {
AddError(sem_.TypeNameOf(type) + AddError(sem_.TypeNameOf(type) +
" cannot be used as the type of a structure member", " cannot be used as the type of a structure member",
member->source); member->source);
@ -2347,7 +2349,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
uint64_t align = type->Align(); uint64_t align = type->Align();
uint64_t size = type->Size(); uint64_t size = type->Size();
if (!ValidateNoDuplicateAttributes(member->attributes)) { if (!validator_.NoDuplicateAttributes(member->attributes)) {
return nullptr; return nullptr;
} }
@ -2453,7 +2455,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
auto stage = current_function_ auto stage = current_function_
? current_function_->Declaration()->PipelineStage() ? current_function_->Declaration()->PipelineStage()
: ast::PipelineStage::kNone; : ast::PipelineStage::kNone;
if (!ValidateStructure(out, stage)) { if (!validator_.Structure(out, stage)) {
return nullptr; return nullptr;
} }
@ -2479,7 +2481,8 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
// is available for validation. // is available for validation.
auto* ret_type = stmt->value ? sem_.TypeOf(stmt->value)->UnwrapRef() auto* ret_type = stmt->value ? sem_.TypeOf(stmt->value)->UnwrapRef()
: builder_->create<sem::Void>(); : builder_->create<sem::Void>();
return ValidateReturn(stmt, current_function_->ReturnType(), ret_type); return validator_.Return(stmt, current_function_->ReturnType(), ret_type,
current_statement_);
}); });
} }
@ -2510,7 +2513,7 @@ sem::SwitchStatement* Resolver::SwitchStatement(
} }
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough); behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
return ValidateSwitch(stmt); return validator_.SwitchStatement(stmt);
}); });
} }
@ -2542,7 +2545,7 @@ sem::Statement* Resolver::VariableDeclStatement(
sem->Behaviors() = ctor->Behaviors(); sem->Behaviors() = ctor->Behaviors();
} }
return ValidateVariable(var); return validator_.Variable(var);
}); });
} }
@ -2567,7 +2570,7 @@ sem::Statement* Resolver::AssignmentStatement(
behaviors.Add(lhs->Behaviors()); behaviors.Add(lhs->Behaviors());
} }
return ValidateAssignment(stmt, sem_.TypeOf(stmt->rhs)); return validator_.Assignment(stmt, sem_.TypeOf(stmt->rhs));
}); });
} }
@ -2577,7 +2580,7 @@ sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
return StatementScope(stmt, sem, [&] { return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kBreak; sem->Behaviors() = sem::Behavior::kBreak;
return ValidateBreakStatement(sem); return validator_.BreakStatement(sem, current_statement_);
}); });
} }
@ -2620,7 +2623,7 @@ sem::Statement* Resolver::CompoundAssignmentStatement(
stmt->source); stmt->source);
return false; return false;
} }
return ValidateAssignment(stmt, ty); return validator_.Assignment(stmt, ty);
}); });
} }
@ -2639,7 +2642,7 @@ sem::Statement* Resolver::ContinueStatement(
} }
} }
return ValidateContinueStatement(sem); return validator_.ContinueStatement(sem, current_statement_);
}); });
} }
@ -2650,7 +2653,7 @@ sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
sem->Behaviors() = sem::Behavior::kDiscard; sem->Behaviors() = sem::Behavior::kDiscard;
current_function_->SetHasDiscard(); current_function_->SetHasDiscard();
return ValidateDiscardStatement(sem); return validator_.DiscardStatement(sem, current_statement_);
}); });
} }
@ -2661,7 +2664,7 @@ sem::Statement* Resolver::FallthroughStatement(
return StatementScope(stmt, sem, [&] { return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kFallthrough; sem->Behaviors() = sem::Behavior::kFallthrough;
return ValidateFallthroughStatement(sem); return validator_.FallthroughStatement(sem);
}); });
} }
@ -2676,7 +2679,7 @@ sem::Statement* Resolver::IncrementDecrementStatement(
} }
sem->Behaviors() = lhs->Behaviors(); sem->Behaviors() = lhs->Behaviors();
return ValidateIncrementDecrementStatement(stmt); return validator_.IncrementDecrementStatement(stmt);
}); });
} }
@ -2718,7 +2721,7 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
sc, const_cast<sem::Type*>(arr->ElemType()), usage); sc, const_cast<sem::Type*>(arr->ElemType()), usage);
} }
if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) { if (ast::IsHostShareable(sc) && !validator_.IsHostShareable(ty)) {
std::stringstream err; std::stringstream err;
err << "Type '" << sem_.TypeNameOf(ty) err << "Type '" << sem_.TypeNameOf(ty)
<< "' cannot be used in storage class '" << sc << "' cannot be used in storage class '" << sc
@ -2782,62 +2785,6 @@ void Resolver::AddNote(const std::string& msg, const Source& source) const {
diagnostics_.add_note(diag::System::Resolver, msg, source); diagnostics_.add_note(diag::System::Resolver, msg, source);
} }
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
bool Resolver::IsPlain(const sem::Type* type) const {
return type->is_scalar() ||
type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
sem::Struct>();
}
// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
bool Resolver::IsFixedFootprint(const sem::Type* type) const {
return Switch(
type, //
[&](const sem::Vector*) { return true; }, //
[&](const sem::Matrix*) { return true; }, //
[&](const sem::Atomic*) { return true; },
[&](const sem::Array* arr) {
return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
},
[&](const sem::Struct* str) {
for (auto* member : str->Members()) {
if (!IsFixedFootprint(member->Type())) {
return false;
}
}
return true;
},
[&](Default) { return type->is_scalar(); });
}
// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
bool Resolver::IsStorable(const sem::Type* type) const {
return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
}
// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
bool Resolver::IsHostShareable(const sem::Type* type) const {
if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
return true;
}
return Switch(
type, //
[&](const sem::Vector* vec) { return IsHostShareable(vec->type()); },
[&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); },
[&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); },
[&](const sem::Struct* str) {
for (auto* member : str->Members()) {
if (!IsHostShareable(member->Type())) {
return false;
}
}
return true;
},
[&](const sem::Atomic* atomic) {
return IsHostShareable(atomic->Type());
});
}
bool Resolver::IsBuiltin(Symbol symbol) const { bool Resolver::IsBuiltin(Symbol symbol) const {
std::string name = builder_->Symbols().NameFor(symbol); std::string name = builder_->Symbols().NameFor(symbol);
return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone; return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone;
@ -2849,26 +2796,6 @@ bool Resolver::IsCallStatement(const ast::Expression* expr) const {
[&](auto* stmt) { return stmt->expr == expr; }); [&](auto* stmt) { return stmt->expr == expr; });
} }
const ast::Statement* Resolver::ClosestContinuing(bool stop_at_loop) const {
for (const auto* s = current_statement_; s != nullptr; s = s->Parent()) {
if (stop_at_loop && s->Is<sem::LoopStatement>()) {
break;
}
if (s->Is<sem::LoopContinuingBlockStatement>()) {
return s->Declaration();
}
if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
if (f->Declaration()->continuing == s->Declaration()) {
return s->Declaration();
}
if (stop_at_loop) {
break;
}
}
}
return nullptr;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Resolver::TypeConversionSig // Resolver::TypeConversionSig
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -16,7 +16,6 @@
#define SRC_TINT_RESOLVER_RESOLVER_H_ #define SRC_TINT_RESOLVER_RESOLVER_H_
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -27,13 +26,13 @@
#include "src/tint/program_builder.h" #include "src/tint/program_builder.h"
#include "src/tint/resolver/dependency_graph.h" #include "src/tint/resolver/dependency_graph.h"
#include "src/tint/resolver/sem_helper.h" #include "src/tint/resolver/sem_helper.h"
#include "src/tint/resolver/validator.h"
#include "src/tint/scope_stack.h" #include "src/tint/scope_stack.h"
#include "src/tint/sem/binding_point.h" #include "src/tint/sem/binding_point.h"
#include "src/tint/sem/block_statement.h" #include "src/tint/sem/block_statement.h"
#include "src/tint/sem/constant.h" #include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h" #include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h" #include "src/tint/sem/struct.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/unique_vector.h" #include "src/tint/utils/unique_vector.h"
// Forward declarations // Forward declarations
@ -89,27 +88,31 @@ class Resolver {
/// @param type the given type /// @param type the given type
/// @returns true if the given type is a plain type /// @returns true if the given type is a plain type
bool IsPlain(const sem::Type* type) const; bool IsPlain(const sem::Type* type) const { return validator_.IsPlain(type); }
/// @param type the given type /// @param type the given type
/// @returns true if the given type is a fixed-footprint type /// @returns true if the given type is a fixed-footprint type
bool IsFixedFootprint(const sem::Type* type) const; bool IsFixedFootprint(const sem::Type* type) const {
return validator_.IsFixedFootprint(type);
}
/// @param type the given type /// @param type the given type
/// @returns true if the given type is storable /// @returns true if the given type is storable
bool IsStorable(const sem::Type* type) const; bool IsStorable(const sem::Type* type) const {
return validator_.IsStorable(type);
}
/// @param type the given type /// @param type the given type
/// @returns true if the given type is host-shareable /// @returns true if the given type is host-shareable
bool IsHostShareable(const sem::Type* type) const; bool IsHostShareable(const sem::Type* type) const {
return validator_.IsHostShareable(type);
}
private: private:
/// Describes the context in which a variable is declared /// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal }; enum class VariableKind { kParameter, kLocal, kGlobal };
using ValidTypeStorageLayouts = Validator::ValidTypeStorageLayouts valid_type_storage_layouts_;
std::set<std::pair<const sem::Type*, ast::StorageClass>>;
ValidTypeStorageLayouts valid_type_storage_layouts_;
/// Structure holding semantic information about a block (i.e. scope), such as /// Structure holding semantic information about a block (i.e. scope), such as
/// parent block and variables declared in the block. /// parent block and variables declared in the block.
@ -237,106 +240,6 @@ class Resolver {
const sem::Type* rhs_ty, const sem::Type* rhs_ty,
ast::BinaryOp op); ast::BinaryOp op);
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidatePipelineStages() const;
bool ValidateAlias(const ast::Alias*) const;
bool ValidateArray(const sem::Array* arr, const Source& source) const;
bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size,
uint32_t el_align,
const Source& source) const;
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const;
bool ValidateAtomicVariable(const sem::Variable* var) const;
bool ValidateAssignment(const ast::Statement* a,
const sem::Type* rhs_ty) const;
bool ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) const;
bool ValidateBreakStatement(const sem::Statement* stmt) const;
bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_type,
ast::PipelineStage stage,
const bool is_input) const;
bool ValidateContinueStatement(const sem::Statement* stmt) const;
bool ValidateDiscardStatement(const sem::Statement* stmt) const;
bool ValidateElseStatement(const sem::ElseStatement* stmt) const;
bool ValidateEntryPoint(const sem::Function* func,
ast::PipelineStage stage) const;
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const;
bool ValidateFallthroughStatement(const sem::Statement* stmt) const;
bool ValidateFunction(const sem::Function* func,
ast::PipelineStage stage) const;
bool ValidateFunctionCall(const sem::Call* call) const;
bool ValidateGlobalVariable(const sem::Variable* var) const;
bool ValidateIfStatement(const sem::IfStatement* stmt) const;
bool ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt) const;
bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
const sem::Type* storage_type) const;
bool ValidateBuiltinCall(const sem::Call* call) const;
bool ValidateLocationAttribute(const ast::LocationAttribute* location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false) const;
bool ValidateLoopStatement(const sem::LoopStatement* stmt) const;
bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const;
bool ValidateFunctionParameter(const ast::Function* func,
const sem::Variable* var) const;
bool ValidateReturn(const ast::ReturnStatement* ret,
const sem::Type* func_type,
const sem::Type* ret_type) const;
bool ValidateStatements(const ast::StatementList& stmts) const;
bool ValidateStorageTexture(const ast::StorageTexture* t) const;
bool ValidateStructure(const sem::Struct* str,
ast::PipelineStage stage) const;
bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
const sem::Struct* struct_type) const;
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const sem::Variable* var) const;
bool ValidateVariableConstructorOrCast(const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_type,
const sem::Type* rhs_type) const;
bool ValidateVector(const sem::Vector* ty, const Source& source) const;
bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type) const;
bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_type) const;
bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* type) const;
bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* arr_type) const;
bool ValidateTextureBuiltinFunction(const sem::Call* call) const;
bool ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) const;
bool ValidateStorageClassLayout(const sem::Type* type,
ast::StorageClass sc,
Source source,
ValidTypeStorageLayouts& layouts) const;
bool ValidateStorageClassLayout(const sem::Variable* var,
ValidTypeStorageLayouts& layouts) const;
/// @returns true if the attribute list contains a
/// ast::DisableValidationAttribute with the validation mode equal to
/// `validation`
bool IsValidationDisabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
/// @returns true if the attribute list does not contains a
/// ast::DisableValidationAttribute with the validation mode equal to
/// `validation`
bool IsValidationEnabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
/// Resolves the WorkgroupSize for the given function, assigning it to /// Resolves the WorkgroupSize for the given function, assigning it to
/// current_function_ /// current_function_
bool WorkgroupSize(const ast::Function*); bool WorkgroupSize(const ast::Function*);
@ -457,23 +360,6 @@ class Resolver {
/// @returns true if `expr` is the current CallStatement's CallExpression /// @returns true if `expr` is the current CallStatement's CallExpression
bool IsCallStatement(const ast::Expression* expr) const; bool IsCallStatement(const ast::Expression* expr) const;
/// Searches the current statement and up through parents of the current
/// statement looking for a loop or for-loop continuing statement.
/// @returns the closest continuing statement to the current statement that
/// (transitively) owns the current statement.
/// @param stop_at_loop if true then the function will return nullptr if a
/// loop or for-loop was found before the continuing.
const ast::Statement* ClosestContinuing(bool stop_at_loop) const;
/// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type.
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
: nullptr;
}
struct TypeConversionSig { struct TypeConversionSig {
const sem::Type* target; const sem::Type* target;
const sem::Type* source; const sem::Type* source;
@ -511,6 +397,7 @@ class Resolver {
std::unique_ptr<BuiltinTable> const builtin_table_; std::unique_ptr<BuiltinTable> const builtin_table_;
DependencyGraph dependencies_; DependencyGraph dependencies_;
SemHelper sem_; SemHelper sem_;
Validator validator_;
std::vector<sem::Function*> entry_points_; std::vector<sem::Function*> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_; std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
std::unordered_set<const ast::Node*> marked_; std::unordered_set<const ast::Node*> marked_;

View File

@ -0,0 +1,79 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/atomic_type.h"
namespace tint::resolver {
namespace {
using ResolverIsStorableTest = ResolverTest;
TEST_F(ResolverIsStorableTest, Struct_AllMembersStorable) {
Structure("S", {
Member("a", ty.i32()),
Member("b", ty.f32()),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIsStorableTest, Struct_SomeMembersNonStorable) {
Structure("S", {
Member("a", ty.i32()),
Member("b", ty.pointer<i32>(ast::StorageClass::kPrivate)),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(error: ptr<private, i32, read_write> cannot be used as the type of a structure member)");
}
TEST_F(ResolverIsStorableTest, Struct_NestedStorable) {
auto* storable = Structure("Storable", {
Member("a", ty.i32()),
Member("b", ty.f32()),
});
Structure("S", {
Member("a", ty.i32()),
Member("b", ty.Of(storable)),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverIsStorableTest, Struct_NestedNonStorable) {
auto* non_storable =
Structure("nonstorable",
{
Member("a", ty.i32()),
Member("b", ty.pointer<i32>(ast::StorageClass::kPrivate)),
});
Structure("S", {
Member("a", ty.i32()),
Member("b", ty.Of(non_storable)),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(error: ptr<private, i32, read_write> cannot be used as the type of a structure member)");
}
} // namespace
} // namespace tint::resolver

View File

@ -18,7 +18,8 @@
namespace tint::resolver { namespace tint::resolver {
SemHelper::SemHelper(ProgramBuilder* builder) : builder_(builder) {} SemHelper::SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies)
: builder_(builder), dependencies_(dependencies) {}
SemHelper::~SemHelper() = default; SemHelper::~SemHelper() = default;

View File

@ -19,6 +19,8 @@
#include "src/tint/diagnostic/diagnostic.h" #include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/program_builder.h" #include "src/tint/program_builder.h"
#include "src/tint/resolver/dependency_graph.h"
#include "src/tint/utils/map.h"
namespace tint::resolver { namespace tint::resolver {
@ -27,7 +29,8 @@ class SemHelper {
public: public:
/// Constructor /// Constructor
/// @param builder the program builder /// @param builder the program builder
explicit SemHelper(ProgramBuilder* builder); /// @param dependencies the program dependency graph
explicit SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies);
~SemHelper(); ~SemHelper();
/// Get is a helper for obtaining the semantic node for the given AST node. /// Get is a helper for obtaining the semantic node for the given AST node.
@ -47,6 +50,16 @@ class SemHelper {
return const_cast<T*>(As<T>(sem)); return const_cast<T*>(As<T>(sem));
} }
/// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type.
/// @param node the node to retrieve
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
: nullptr;
}
/// @returns the resolved type of the ast::Expression `expr` /// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression /// @param expr the expression
sem::Type* TypeOf(const ast::Expression* expr) const; sem::Type* TypeOf(const ast::Expression* expr) const;
@ -67,6 +80,7 @@ class SemHelper {
private: private:
ProgramBuilder* builder_; ProgramBuilder* builder_;
DependencyGraph& dependencies_;
}; };
} // namespace tint::resolver } // namespace tint::resolver

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/tint/resolver/resolver.h" #include "src/tint/resolver/validator.h"
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
@ -149,8 +149,104 @@ void TraverseCallChain(diag::List& diagnostics,
} // namespace } // namespace
bool Resolver::ValidateAtomic(const ast::Atomic* a, Validator::Validator(ProgramBuilder* builder, SemHelper& sem)
const sem::Atomic* s) const { : symbols_(builder->Symbols()),
diagnostics_(builder->Diagnostics()),
sem_(sem) {}
Validator::~Validator() = default;
void Validator::AddError(const std::string& msg, const Source& source) const {
diagnostics_.add_error(diag::System::Resolver, msg, source);
}
void Validator::AddWarning(const std::string& msg, const Source& source) const {
diagnostics_.add_warning(diag::System::Resolver, msg, source);
}
void Validator::AddNote(const std::string& msg, const Source& source) const {
diagnostics_.add_note(diag::System::Resolver, msg, source);
}
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
bool Validator::IsPlain(const sem::Type* type) const {
return type->is_scalar() ||
type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
sem::Struct>();
}
// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
bool Validator::IsFixedFootprint(const sem::Type* type) const {
return Switch(
type, //
[&](const sem::Vector*) { return true; }, //
[&](const sem::Matrix*) { return true; }, //
[&](const sem::Atomic*) { return true; },
[&](const sem::Array* arr) {
return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
},
[&](const sem::Struct* str) {
for (auto* member : str->Members()) {
if (!IsFixedFootprint(member->Type())) {
return false;
}
}
return true;
},
[&](Default) { return type->is_scalar(); });
}
// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
bool Validator::IsHostShareable(const sem::Type* type) const {
if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
return true;
}
return Switch(
type, //
[&](const sem::Vector* vec) { return IsHostShareable(vec->type()); },
[&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); },
[&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); },
[&](const sem::Struct* str) {
for (auto* member : str->Members()) {
if (!IsHostShareable(member->Type())) {
return false;
}
}
return true;
},
[&](const sem::Atomic* atomic) {
return IsHostShareable(atomic->Type());
});
}
// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
bool Validator::IsStorable(const sem::Type* type) const {
return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
}
const ast::Statement* Validator::ClosestContinuing(
bool stop_at_loop,
sem::Statement* current_statement) const {
for (const auto* s = current_statement; s != nullptr; s = s->Parent()) {
if (stop_at_loop && s->Is<sem::LoopStatement>()) {
break;
}
if (s->Is<sem::LoopContinuingBlockStatement>()) {
return s->Declaration();
}
if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
if (f->Declaration()->continuing == s->Declaration()) {
return s->Declaration();
}
if (stop_at_loop) {
break;
}
}
}
return nullptr;
}
bool Validator::Atomic(const ast::Atomic* a, const sem::Atomic* s) const {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// T must be either u32 or i32. // T must be either u32 or i32.
if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) { if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) {
@ -161,7 +257,7 @@ bool Resolver::ValidateAtomic(const ast::Atomic* a,
return true; return true;
} }
bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const { bool Validator::StorageTexture(const ast::StorageTexture* t) const {
switch (t->access) { switch (t->access) {
case ast::Access::kWrite: case ast::Access::kWrite:
break; break;
@ -190,11 +286,10 @@ bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const {
return true; return true;
} }
bool Resolver::ValidateVariableConstructorOrCast( bool Validator::VariableConstructorOrCast(const ast::Variable* var,
const ast::Variable* var, ast::StorageClass storage_class,
ast::StorageClass storage_class, const sem::Type* storage_ty,
const sem::Type* storage_ty, const sem::Type* rhs_ty) const {
const sem::Type* rhs_ty) const {
auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type // Value type has to match storage type
@ -229,11 +324,10 @@ bool Resolver::ValidateVariableConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateStorageClassLayout( bool Validator::StorageClassLayout(const sem::Type* store_ty,
const sem::Type* store_ty, ast::StorageClass sc,
ast::StorageClass sc, Source source,
Source source, ValidTypeStorageLayouts& layouts) const {
ValidTypeStorageLayouts& layouts) const {
// https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints // https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints
auto is_uniform_struct_or_array = [sc](const sem::Type* ty) { auto is_uniform_struct_or_array = [sc](const sem::Type* ty) {
@ -255,7 +349,7 @@ bool Resolver::ValidateStorageClassLayout(
}; };
auto member_name_of = [this](const sem::StructMember* sm) { auto member_name_of = [this](const sem::StructMember* sm) {
return builder_->Symbols().NameFor(sm->Declaration()->symbol); return symbols_.NameFor(sm->Declaration()->symbol);
}; };
// Cache result of type + storage class pair. // Cache result of type + storage class pair.
@ -273,9 +367,9 @@ bool Resolver::ValidateStorageClassLayout(
uint32_t required_align = required_alignment_of(m->Type()); uint32_t required_align = required_alignment_of(m->Type());
// Recurse into the member type. // Recurse into the member type.
if (!ValidateStorageClassLayout( if (!StorageClassLayout(m->Type(), sc, m->Declaration()->type->source,
m->Type(), sc, m->Declaration()->type->source, layouts)) { layouts)) {
AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()), AddNote("see layout of struct:\n" + str->Layout(symbols_),
str->Declaration()->source); str->Declaration()->source);
return false; return false;
} }
@ -283,7 +377,7 @@ bool Resolver::ValidateStorageClassLayout(
// Validate that member is at a valid byte offset // Validate that member is at a valid byte offset
if (m->Offset() % required_align != 0) { if (m->Offset() % required_align != 0) {
AddError("the offset of a struct member of type '" + AddError("the offset of a struct member of type '" +
m->Type()->UnwrapRef()->FriendlyName(builder_->Symbols()) + m->Type()->UnwrapRef()->FriendlyName(symbols_) +
"' in storage class '" + ast::ToString(sc) + "' in storage class '" + ast::ToString(sc) +
"' must be a multiple of " + "' must be a multiple of " +
std::to_string(required_align) + " bytes, but '" + std::to_string(required_align) + " bytes, but '" +
@ -293,13 +387,13 @@ bool Resolver::ValidateStorageClassLayout(
std::to_string(required_align) + ") on this member", std::to_string(required_align) + ") on this member",
m->Declaration()->source); m->Declaration()->source);
AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()), AddNote("see layout of struct:\n" + str->Layout(symbols_),
str->Declaration()->source); str->Declaration()->source);
if (auto* member_str = m->Type()->As<sem::Struct>()) { if (auto* member_str = m->Type()->As<sem::Struct>()) {
AddNote("and layout of struct member:\n" + AddNote(
member_str->Layout(builder_->Symbols()), "and layout of struct member:\n" + member_str->Layout(symbols_),
member_str->Declaration()->source); member_str->Declaration()->source);
} }
return false; return false;
@ -322,12 +416,12 @@ bool Resolver::ValidateStorageClassLayout(
"'. Consider setting @align(16) on this member", "'. Consider setting @align(16) on this member",
m->Declaration()->source); m->Declaration()->source);
AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()), AddNote("see layout of struct:\n" + str->Layout(symbols_),
str->Declaration()->source); str->Declaration()->source);
auto* prev_member_str = prev_member->Type()->As<sem::Struct>(); auto* prev_member_str = prev_member->Type()->As<sem::Struct>();
AddNote("and layout of previous member struct:\n" + AddNote("and layout of previous member struct:\n" +
prev_member_str->Layout(builder_->Symbols()), prev_member_str->Layout(symbols_),
prev_member_str->Declaration()->source); prev_member_str->Declaration()->source);
return false; return false;
} }
@ -342,7 +436,7 @@ bool Resolver::ValidateStorageClassLayout(
// TODO(crbug.com/tint/1388): Ideally we'd pass the source for nested // TODO(crbug.com/tint/1388): Ideally we'd pass the source for nested
// element type here, but we can't easily get that from the semantic node. // element type here, but we can't easily get that from the semantic node.
// We should consider recursing through the AST type nodes instead. // We should consider recursing through the AST type nodes instead.
if (!ValidateStorageClassLayout(arr->ElemType(), sc, source, layouts)) { if (!StorageClassLayout(arr->ElemType(), sc, source, layouts)) {
return false; return false;
} }
@ -384,12 +478,11 @@ bool Resolver::ValidateStorageClassLayout(
return true; return true;
} }
bool Resolver::ValidateStorageClassLayout( bool Validator::StorageClassLayout(const sem::Variable* var,
const sem::Variable* var, ValidTypeStorageLayouts& layouts) const {
ValidTypeStorageLayouts& layouts) const {
if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) { if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
if (!ValidateStorageClassLayout(str, var->StorageClass(), if (!StorageClassLayout(str, var->StorageClass(),
str->Declaration()->source, layouts)) { str->Declaration()->source, layouts)) {
AddNote("see declaration of variable", var->Declaration()->source); AddNote("see declaration of variable", var->Declaration()->source);
return false; return false;
} }
@ -398,8 +491,8 @@ bool Resolver::ValidateStorageClassLayout(
if (var->Declaration()->type) { if (var->Declaration()->type) {
source = var->Declaration()->type->source; source = var->Declaration()->type->source;
} }
if (!ValidateStorageClassLayout(var->Type()->UnwrapRef(), if (!StorageClassLayout(var->Type()->UnwrapRef(), var->StorageClass(),
var->StorageClass(), source, layouts)) { source, layouts)) {
return false; return false;
} }
} }
@ -407,9 +500,13 @@ bool Resolver::ValidateStorageClassLayout(
return true; return true;
} }
bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const { bool Validator::GlobalVariable(
const sem::Variable* var,
std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info)
const {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
if (!ValidateNoDuplicateAttributes(decl->attributes)) { if (!NoDuplicateAttributes(decl->attributes)) {
return false; return false;
} }
@ -417,8 +514,8 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const {
if (decl->is_const) { if (decl->is_const) {
if (auto* id_attr = attr->As<ast::IdAttribute>()) { if (auto* id_attr = attr->As<ast::IdAttribute>()) {
uint32_t id = id_attr->value; uint32_t id = id_attr->value;
auto it = constant_ids_.find(id); auto it = constant_ids.find(id);
if (it != constant_ids_.end() && it->second != var) { if (it != constant_ids.end() && it->second != var) {
AddError("pipeline constant IDs must be unique", attr->source); AddError("pipeline constant IDs must be unique", attr->source);
AddNote("a pipeline constant with an ID of " + std::to_string(id) + AddNote("a pipeline constant with an ID of " + std::to_string(id) +
" was previously declared " " was previously declared "
@ -502,18 +599,21 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const {
} }
if (!decl->is_const) { if (!decl->is_const) {
if (!ValidateAtomicVariable(var)) { if (!AtomicVariable(var, atomic_composite_info)) {
return false; return false;
} }
} }
return ValidateVariable(var); return Variable(var);
} }
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// Atomic types may only be instantiated by variables in the workgroup storage // Atomic types may only be instantiated by variables in the workgroup storage
// class or by storage buffer variables with a read_write access mode. // class or by storage buffer variables with a read_write access mode.
bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { bool Validator::AtomicVariable(
const sem::Variable* var,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info)
const {
auto sc = var->StorageClass(); auto sc = var->StorageClass();
auto* decl = var->Declaration(); auto* decl = var->Declaration();
auto access = var->Access(); auto access = var->Access();
@ -529,8 +629,8 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const {
return false; return false;
} }
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) { } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
auto found = atomic_composite_info_.find(type); auto found = atomic_composite_info.find(type);
if (found != atomic_composite_info_.end()) { if (found != atomic_composite_info.end()) {
if (sc != ast::StorageClass::kStorage && if (sc != ast::StorageClass::kStorage &&
sc != ast::StorageClass::kWorkgroup) { sc != ast::StorageClass::kWorkgroup) {
AddError( AddError(
@ -557,12 +657,12 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const {
return true; return true;
} }
bool Resolver::ValidateVariable(const sem::Variable* var) const { bool Validator::Variable(const sem::Variable* var) const {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
auto* storage_ty = var->Type()->UnwrapRef(); auto* storage_ty = var->Type()->UnwrapRef();
if (var->Is<sem::GlobalVariable>()) { if (var->Is<sem::GlobalVariable>()) {
auto name = builder_->Symbols().NameFor(decl->symbol); auto name = symbols_.NameFor(decl->symbol);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
auto* kind = var->Declaration()->is_const ? "let" : "var"; auto* kind = var->Declaration()->is_const ? "let" : "var";
AddError( AddError(
@ -634,9 +734,9 @@ bool Resolver::ValidateVariable(const sem::Variable* var) const {
return true; return true;
} }
bool Resolver::ValidateFunctionParameter(const ast::Function* func, bool Validator::FunctionParameter(const ast::Function* func,
const sem::Variable* var) const { const sem::Variable* var) const {
if (!ValidateVariable(var)) { if (!Variable(var)) {
return false; return false;
} }
@ -697,10 +797,10 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func,
return true; return true;
} }
bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, bool Validator::BuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_ty, const sem::Type* storage_ty,
ast::PipelineStage stage, ast::PipelineStage stage,
const bool is_input) const { const bool is_input) const {
auto* type = storage_ty->UnwrapRef(); auto* type = storage_ty->UnwrapRef();
std::stringstream stage_name; std::stringstream stage_name;
stage_name << stage; stage_name << stage;
@ -816,9 +916,8 @@ bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
return true; return true;
} }
bool Resolver::ValidateInterpolateAttribute( bool Validator::InterpolateAttribute(const ast::InterpolateAttribute* attr,
const ast::InterpolateAttribute* attr, const sem::Type* storage_ty) const {
const sem::Type* storage_ty) const {
auto* type = storage_ty->UnwrapRef(); auto* type = storage_ty->UnwrapRef();
if (type->is_integer_scalar_or_vector() && if (type->is_integer_scalar_or_vector() &&
@ -839,11 +938,11 @@ bool Resolver::ValidateInterpolateAttribute(
return true; return true;
} }
bool Resolver::ValidateFunction(const sem::Function* func, bool Validator::Function(const sem::Function* func,
ast::PipelineStage stage) const { ast::PipelineStage stage) const {
auto* decl = func->Declaration(); auto* decl = func->Declaration();
auto name = builder_->Symbols().NameFor(decl->symbol); auto name = symbols_.NameFor(decl->symbol);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError( AddError(
"'" + name + "' is a builtin and cannot be redeclared as a function", "'" + name + "' is a builtin and cannot be redeclared as a function",
@ -873,7 +972,7 @@ bool Resolver::ValidateFunction(const sem::Function* func,
} }
for (size_t i = 0; i < decl->params.size(); i++) { for (size_t i = 0; i < decl->params.size(); i++) {
if (!ValidateFunctionParameter(decl, func->Parameters()[i])) { if (!FunctionParameter(decl, func->Parameters()[i])) {
return false; return false;
} }
} }
@ -898,8 +997,7 @@ bool Resolver::ValidateFunction(const sem::Function* func,
decl->attributes, decl->attributes,
ast::DisabledValidation::kFunctionHasNoBody)) { ast::DisabledValidation::kFunctionHasNoBody)) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "Function " << builder_->Symbols().NameFor(decl->symbol) << "Function " << symbols_.NameFor(decl->symbol) << " has no body";
<< " has no body";
} }
for (auto* attr : decl->return_type_attributes) { for (auto* attr : decl->return_type_attributes) {
@ -925,7 +1023,7 @@ bool Resolver::ValidateFunction(const sem::Function* func,
} }
if (decl->IsEntryPoint()) { if (decl->IsEntryPoint()) {
if (!ValidateEntryPoint(func, stage)) { if (!EntryPoint(func, stage)) {
return false; return false;
} }
} }
@ -945,8 +1043,8 @@ bool Resolver::ValidateFunction(const sem::Function* func,
return true; return true;
} }
bool Resolver::ValidateEntryPoint(const sem::Function* func, bool Validator::EntryPoint(const sem::Function* func,
ast::PipelineStage stage) const { ast::PipelineStage stage) const {
auto* decl = func->Declaration(); auto* decl = func->Declaration();
// Use a lambda to validate the entry point attributes for a type. // Use a lambda to validate the entry point attributes for a type.
@ -994,7 +1092,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
return false; return false;
} }
if (!ValidateBuiltinAttribute( if (!BuiltinAttribute(
builtin, ty, stage, builtin, ty, stage,
/* is_input */ param_or_ret == ParamOrRetType::kParameter)) { /* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
return false; return false;
@ -1011,14 +1109,14 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
bool is_input = param_or_ret == ParamOrRetType::kParameter; bool is_input = param_or_ret == ParamOrRetType::kParameter;
if (!ValidateLocationAttribute(location, ty, locations, stage, source, if (!LocationAttribute(location, ty, locations, stage, source,
is_input)) { is_input)) {
return false; return false;
} }
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
if (decl->PipelineStage() == ast::PipelineStage::kCompute) { if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_attribute = true; is_invalid_compute_shader_attribute = true;
} else if (!ValidateInterpolateAttribute(interpolate, ty)) { } else if (!InterpolateAttribute(interpolate, ty)) {
return false; return false;
} }
interpolate_attribute = interpolate; interpolate_attribute = interpolate;
@ -1122,7 +1220,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
member->Declaration()->source, param_or_ret, member->Declaration()->source, param_or_ret,
/*is_struct_member*/ true)) { /*is_struct_member*/ true)) {
AddNote("while analysing entry point '" + AddNote("while analysing entry point '" +
builder_->Symbols().NameFor(decl->symbol) + "'", symbols_.NameFor(decl->symbol) + "'",
decl->source); decl->source);
return false; return false;
} }
@ -1206,7 +1304,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
// variables in the resource interface of a given shader must not have // variables in the resource interface of a given shader must not have
// the same group and binding values, when considered as a pair of // the same group and binding values, when considered as a pair of
// values. // values.
auto func_name = builder_->Symbols().NameFor(decl->symbol); auto func_name = symbols_.NameFor(decl->symbol);
AddError("entry point '" + func_name + AddError("entry point '" + func_name +
"' references multiple variables that use the " "' references multiple variables that use the "
"same resource binding @group(" + "same resource binding @group(" +
@ -1222,7 +1320,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
return true; return true;
} }
bool Resolver::ValidateStatements(const ast::StatementList& stmts) const { bool Validator::Statements(const ast::StatementList& stmts) const {
for (auto* stmt : stmts) { for (auto* stmt : stmts) {
if (!sem_.Get(stmt)->IsReachable()) { if (!sem_.Get(stmt)->IsReachable()) {
/// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to
@ -1234,8 +1332,8 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) const {
return true; return true;
} }
bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, bool Validator::Bitcast(const ast::BitcastExpression* cast,
const sem::Type* to) const { const sem::Type* to) const {
auto* from = sem_.TypeOf(cast->expr)->UnwrapRef(); auto* from = sem_.TypeOf(cast->expr)->UnwrapRef();
if (!from->is_numeric_scalar_or_vector()) { if (!from->is_numeric_scalar_or_vector()) {
AddError("'" + sem_.TypeNameOf(from) + "' cannot be bitcast", AddError("'" + sem_.TypeNameOf(from) + "' cannot be bitcast",
@ -1265,13 +1363,15 @@ bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
return true; return true;
} }
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const { bool Validator::BreakStatement(const sem::Statement* stmt,
sem::Statement* current_statement) const {
if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) { if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case", AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source); stmt->Declaration()->source);
return false; return false;
} }
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) { if (auto* continuing =
ClosestContinuing(/*stop_at_loop*/ true, current_statement)) {
auto fail = [&](const char* note_msg, const Source& note_src) { auto fail = [&](const char* note_msg, const Source& note_src) {
constexpr const char* kErrorMsg = constexpr const char* kErrorMsg =
"break statement in a continuing block must be the single statement " "break statement in a continuing block must be the single statement "
@ -1332,8 +1432,10 @@ bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const {
return true; return true;
} }
bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const { bool Validator::ContinueStatement(const sem::Statement* stmt,
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) { sem::Statement* current_statement) const {
if (auto* continuing =
ClosestContinuing(/*stop_at_loop*/ true, current_statement)) {
AddError("continuing blocks must not contain a continue statement", AddError("continuing blocks must not contain a continue statement",
stmt->Declaration()->source); stmt->Declaration()->source);
if (continuing != stmt->Declaration() && if (continuing != stmt->Declaration() &&
@ -1352,8 +1454,10 @@ bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const {
return true; return true;
} }
bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const { bool Validator::DiscardStatement(const sem::Statement* stmt,
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { sem::Statement* current_statement) const {
if (auto* continuing =
ClosestContinuing(/*stop_at_loop*/ false, current_statement)) {
AddError("continuing blocks must not contain a discard statement", AddError("continuing blocks must not contain a discard statement",
stmt->Declaration()->source); stmt->Declaration()->source);
if (continuing != stmt->Declaration() && if (continuing != stmt->Declaration() &&
@ -1365,7 +1469,7 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const {
return true; return true;
} }
bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const { bool Validator::FallthroughStatement(const sem::Statement* stmt) const {
if (auto* block = As<sem::BlockStatement>(stmt->Parent())) { if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
if (auto* c = As<sem::CaseStatement>(block->Parent())) { if (auto* c = As<sem::CaseStatement>(block->Parent())) {
if (block->Declaration()->Last() == stmt->Declaration()) { if (block->Declaration()->Last() == stmt->Declaration()) {
@ -1388,7 +1492,7 @@ bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const {
return false; return false;
} }
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const { bool Validator::ElseStatement(const sem::ElseStatement* stmt) const {
if (auto* cond = stmt->Condition()) { if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef(); auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) { if (!cond_ty->Is<sem::Bool>()) {
@ -1401,7 +1505,7 @@ bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const {
return true; return true;
} }
bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const { bool Validator::LoopStatement(const sem::LoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) { if (stmt->Behaviors().Empty()) {
AddError("loop does not exit", stmt->Declaration()->source.Begin()); AddError("loop does not exit", stmt->Declaration()->source.Begin());
return false; return false;
@ -1409,8 +1513,7 @@ bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const {
return true; return true;
} }
bool Resolver::ValidateForLoopStatement( bool Validator::ForLoopStatement(const sem::ForLoopStatement* stmt) const {
const sem::ForLoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) { if (stmt->Behaviors().Empty()) {
AddError("for-loop does not exit", stmt->Declaration()->source.Begin()); AddError("for-loop does not exit", stmt->Declaration()->source.Begin());
return false; return false;
@ -1427,7 +1530,7 @@ bool Resolver::ValidateForLoopStatement(
return true; return true;
} }
bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const { bool Validator::IfStatement(const sem::IfStatement* stmt) const {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef(); auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) { if (!cond_ty->Is<sem::Bool>()) {
AddError( AddError(
@ -1438,7 +1541,7 @@ bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const {
return true; return true;
} }
bool Resolver::ValidateBuiltinCall(const sem::Call* call) const { bool Validator::BuiltinCall(const sem::Call* call) const {
if (call->Type()->Is<sem::Void>()) { if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false; bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) { if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
@ -1451,7 +1554,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) const {
// If the called function does not return a value, a function call // If the called function does not return a value, a function call
// statement should be used instead. // statement should be used instead.
auto* ident = call->Declaration()->target.name; auto* ident = call->Declaration()->target.name;
auto name = builder_->Symbols().NameFor(ident->symbol); auto name = symbols_.NameFor(ident->symbol);
AddError("builtin '" + name + "' does not return a value", AddError("builtin '" + name + "' does not return a value",
call->Declaration()->source); call->Declaration()->source);
return false; return false;
@ -1461,7 +1564,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) const {
return true; return true;
} }
bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const { bool Validator::TextureBuiltinFunction(const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::Builtin>(); auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin) { if (!builtin) {
return false; return false;
@ -1533,11 +1636,12 @@ bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const {
check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3); check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3);
} }
bool Resolver::ValidateFunctionCall(const sem::Call* call) const { bool Validator::FunctionCall(const sem::Call* call,
sem::Statement* current_statement) const {
auto* decl = call->Declaration(); auto* decl = call->Declaration();
auto* target = call->Target()->As<sem::Function>(); auto* target = call->Target()->As<sem::Function>();
auto sym = decl->target.name->symbol; auto sym = decl->target.name->symbol;
auto name = builder_->Symbols().NameFor(sym); auto name = symbols_.NameFor(sym);
if (target->Declaration()->IsEntryPoint()) { if (target->Declaration()->IsEntryPoint()) {
// https://www.w3.org/TR/WGSL/#function-restriction // https://www.w3.org/TR/WGSL/#function-restriction
@ -1575,7 +1679,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
if (param_type->Is<sem::Pointer>()) { if (param_type->Is<sem::Pointer>()) {
auto is_valid = false; auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) { if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
auto* var = ResolvedSymbol<sem::Variable>(ident_expr); auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_expr);
if (!var) { if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false; return false;
@ -1587,7 +1691,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
if (unary->op == ast::UnaryOp::kAddressOf) { if (unary->op == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary = if (auto* ident_unary =
unary->expr->As<ast::IdentifierExpression>()) { unary->expr->As<ast::IdentifierExpression>()) {
auto* var = ResolvedSymbol<sem::Variable>(ident_unary); auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_unary);
if (!var) { if (!var) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve identifier"; << "failed to resolve identifier";
@ -1634,7 +1738,8 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
} }
if (call->Behaviors().Contains(sem::Behavior::kDiscard)) { if (call->Behaviors().Contains(sem::Behavior::kDiscard)) {
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { if (auto* continuing =
ClosestContinuing(/*stop_at_loop*/ false, current_statement)) {
AddError( AddError(
"cannot call a function that may discard inside a continuing block", "cannot call a function that may discard inside a continuing block",
call->Declaration()->source); call->Declaration()->source);
@ -1649,7 +1754,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
return true; return true;
} }
bool Resolver::ValidateStructureConstructorOrCast( bool Validator::StructureConstructorOrCast(
const ast::CallExpression* ctor, const ast::CallExpression* ctor,
const sem::Struct* struct_type) const { const sem::Struct* struct_type) const {
if (!struct_type->IsConstructible()) { if (!struct_type->IsConstructible()) {
@ -1684,9 +1789,8 @@ bool Resolver::ValidateStructureConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateArrayConstructorOrCast( bool Validator::ArrayConstructorOrCast(const ast::CallExpression* ctor,
const ast::CallExpression* ctor, const sem::Array* array_type) const {
const sem::Array* array_type) const {
auto& values = ctor->args; auto& values = ctor->args;
auto* elem_ty = array_type->ElemType(); auto* elem_ty = array_type->ElemType();
for (auto* value : values) { for (auto* value : values) {
@ -1726,9 +1830,8 @@ bool Resolver::ValidateArrayConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateVectorConstructorOrCast( bool Validator::VectorConstructorOrCast(const ast::CallExpression* ctor,
const ast::CallExpression* ctor, const sem::Vector* vec_type) const {
const sem::Vector* vec_type) const {
auto& values = ctor->args; auto& values = ctor->args;
auto* elem_ty = vec_type->type(); auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0; size_t value_cardinality_sum = 0;
@ -1790,8 +1893,7 @@ bool Resolver::ValidateVectorConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateVector(const sem::Vector* ty, bool Validator::Vector(const sem::Vector* ty, const Source& source) const {
const Source& source) const {
if (!ty->type()->is_scalar()) { if (!ty->type()->is_scalar()) {
AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'", AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
source); source);
@ -1800,8 +1902,7 @@ bool Resolver::ValidateVector(const sem::Vector* ty,
return true; return true;
} }
bool Resolver::ValidateMatrix(const sem::Matrix* ty, bool Validator::Matrix(const sem::Matrix* ty, const Source& source) const {
const Source& source) const {
if (!ty->is_float_matrix()) { if (!ty->is_float_matrix()) {
AddError("matrix element type must be 'f32'", source); AddError("matrix element type must be 'f32'", source);
return false; return false;
@ -1809,16 +1910,15 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty,
return true; return true;
} }
bool Resolver::ValidateMatrixConstructorOrCast( bool Validator::MatrixConstructorOrCast(const ast::CallExpression* ctor,
const ast::CallExpression* ctor, const sem::Matrix* matrix_ty) const {
const sem::Matrix* matrix_ty) const {
auto& values = ctor->args; auto& values = ctor->args;
// Zero Value expression // Zero Value expression
if (values.empty()) { if (values.empty()) {
return true; return true;
} }
if (!ValidateMatrix(matrix_ty, ctor->source)) { if (!Matrix(matrix_ty, ctor->source)) {
return false; return false;
} }
@ -1844,7 +1944,7 @@ bool Resolver::ValidateMatrixConstructorOrCast(
if (i > 0) { if (i > 0) {
ss << ", "; ss << ", ";
} }
ss << arg_tys[i]->FriendlyName(builder_->Symbols()); ss << arg_tys[i]->FriendlyName(symbols_);
} }
ss << ")" << std::endl << std::endl; ss << ")" << std::endl << std::endl;
ss << "3 candidates available:" << std::endl; ss << "3 candidates available:" << std::endl;
@ -1885,8 +1985,8 @@ bool Resolver::ValidateMatrixConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, bool Validator::ScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* ty) const { const sem::Type* ty) const {
if (ctor->args.size() == 0) { if (ctor->args.size() == 0) {
return true; return true;
} }
@ -1921,7 +2021,8 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
return true; return true;
} }
bool Resolver::ValidatePipelineStages() const { bool Validator::PipelineStages(
const std::vector<sem::Function*>& entry_points) const {
auto check_workgroup_storage = [&](const sem::Function* func, auto check_workgroup_storage = [&](const sem::Function* func,
const sem::Function* entry_point) { const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage(); auto stage = entry_point->Declaration()->PipelineStage();
@ -1940,17 +2041,14 @@ bool Resolver::ValidatePipelineStages() const {
} }
AddNote("variable is declared here", var->Declaration()->source); AddNote("variable is declared here", var->Declaration()->source);
if (func != entry_point) { if (func != entry_point) {
TraverseCallChain(diagnostics_, entry_point, func, TraverseCallChain(
[&](const sem::Function* f) { diagnostics_, entry_point, func, [&](const sem::Function* f) {
AddNote("called by function '" + AddNote("called by function '" +
builder_->Symbols().NameFor( symbols_.NameFor(f->Declaration()->symbol) + "'",
f->Declaration()->symbol) + f->Declaration()->source);
"'", });
f->Declaration()->source);
});
AddNote("called by entry point '" + AddNote("called by entry point '" +
builder_->Symbols().NameFor( symbols_.NameFor(entry_point->Declaration()->symbol) +
entry_point->Declaration()->symbol) +
"'", "'",
entry_point->Declaration()->source); entry_point->Declaration()->source);
} }
@ -1961,7 +2059,7 @@ bool Resolver::ValidatePipelineStages() const {
return true; return true;
}; };
for (auto* entry_point : entry_points_) { for (auto* entry_point : entry_points) {
if (!check_workgroup_storage(entry_point, entry_point)) { if (!check_workgroup_storage(entry_point, entry_point)) {
return false; return false;
} }
@ -1985,15 +2083,12 @@ bool Resolver::ValidatePipelineStages() const {
if (func != entry_point) { if (func != entry_point) {
TraverseCallChain( TraverseCallChain(
diagnostics_, entry_point, func, [&](const sem::Function* f) { diagnostics_, entry_point, func, [&](const sem::Function* f) {
AddNote( AddNote("called by function '" +
"called by function '" + symbols_.NameFor(f->Declaration()->symbol) + "'",
builder_->Symbols().NameFor(f->Declaration()->symbol) + f->Declaration()->source);
"'",
f->Declaration()->source);
}); });
AddNote("called by entry point '" + AddNote("called by entry point '" +
builder_->Symbols().NameFor( symbols_.NameFor(entry_point->Declaration()->symbol) +
entry_point->Declaration()->symbol) +
"'", "'",
entry_point->Declaration()->source); entry_point->Declaration()->source);
} }
@ -2003,7 +2098,7 @@ bool Resolver::ValidatePipelineStages() const {
return true; return true;
}; };
for (auto* entry_point : entry_points_) { for (auto* entry_point : entry_points) {
if (!check_builtin_calls(entry_point, entry_point)) { if (!check_builtin_calls(entry_point, entry_point)) {
return false; return false;
} }
@ -2016,8 +2111,7 @@ bool Resolver::ValidatePipelineStages() const {
return true; return true;
} }
bool Resolver::ValidateArray(const sem::Array* arr, bool Validator::Array(const sem::Array* arr, const Source& source) const {
const Source& source) const {
auto* el_ty = arr->ElemType(); auto* el_ty = arr->ElemType();
if (!IsFixedFootprint(el_ty)) { if (!IsFixedFootprint(el_ty)) {
@ -2028,10 +2122,10 @@ bool Resolver::ValidateArray(const sem::Array* arr,
return true; return true;
} }
bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, bool Validator::ArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size, uint32_t el_size,
uint32_t el_align, uint32_t el_align,
const Source& source) const { const Source& source) const {
auto stride = attr->stride; auto stride = attr->stride;
bool is_valid_stride = bool is_valid_stride =
(stride >= el_size) && (stride >= el_align) && (stride % el_align == 0); (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
@ -2050,8 +2144,8 @@ bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
return true; return true;
} }
bool Resolver::ValidateAlias(const ast::Alias* alias) const { bool Validator::Alias(const ast::Alias* alias) const {
auto name = builder_->Symbols().NameFor(alias->name); auto name = symbols_.NameFor(alias->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as an alias", AddError("'" + name + "' is a builtin and cannot be redeclared as an alias",
alias->source); alias->source);
@ -2061,9 +2155,9 @@ bool Resolver::ValidateAlias(const ast::Alias* alias) const {
return true; return true;
} }
bool Resolver::ValidateStructure(const sem::Struct* str, bool Validator::Structure(const sem::Struct* str,
ast::PipelineStage stage) const { ast::PipelineStage stage) const {
auto name = builder_->Symbols().NameFor(str->Declaration()->name); auto name = symbols_.NameFor(str->Declaration()->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as a struct", AddError("'" + name + "' is a builtin and cannot be redeclared as a struct",
str->Declaration()->source); str->Declaration()->source);
@ -2122,13 +2216,13 @@ bool Resolver::ValidateStructure(const sem::Struct* str,
invariant_attribute = invariant; invariant_attribute = invariant;
} else if (auto* location = attr->As<ast::LocationAttribute>()) { } else if (auto* location = attr->As<ast::LocationAttribute>()) {
has_location = true; has_location = true;
if (!ValidateLocationAttribute(location, member->Type(), locations, if (!LocationAttribute(location, member->Type(), locations, stage,
stage, member->Declaration()->source)) { member->Declaration()->source)) {
return false; return false;
} }
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { } else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
if (!ValidateBuiltinAttribute(builtin, member->Type(), stage, if (!BuiltinAttribute(builtin, member->Type(), stage,
/* is_input */ false)) { /* is_input */ false)) {
return false; return false;
} }
if (builtin->builtin == ast::Builtin::kPosition) { if (builtin->builtin == ast::Builtin::kPosition) {
@ -2136,7 +2230,7 @@ bool Resolver::ValidateStructure(const sem::Struct* str,
} }
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
interpolate_attribute = interpolate; interpolate_attribute = interpolate;
if (!ValidateInterpolateAttribute(interpolate, member->Type())) { if (!InterpolateAttribute(interpolate, member->Type())) {
return false; return false;
} }
} }
@ -2165,13 +2259,12 @@ bool Resolver::ValidateStructure(const sem::Struct* str,
return true; return true;
} }
bool Resolver::ValidateLocationAttribute( bool Validator::LocationAttribute(const ast::LocationAttribute* location,
const ast::LocationAttribute* location, const sem::Type* type,
const sem::Type* type, std::unordered_set<uint32_t>& locations,
std::unordered_set<uint32_t>& locations, ast::PipelineStage stage,
ast::PipelineStage stage, const Source& source,
const Source& source, const bool is_input) const {
const bool is_input) const {
std::string inputs_or_output = is_input ? "inputs" : "output"; std::string inputs_or_output = is_input ? "inputs" : "output";
if (stage == ast::PipelineStage::kCompute) { if (stage == ast::PipelineStage::kCompute) {
AddError("attribute is not valid for compute shader " + inputs_or_output, AddError("attribute is not valid for compute shader " + inputs_or_output,
@ -2201,9 +2294,10 @@ bool Resolver::ValidateLocationAttribute(
return true; return true;
} }
bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, bool Validator::Return(const ast::ReturnStatement* ret,
const sem::Type* func_type, const sem::Type* func_type,
const sem::Type* ret_type) const { const sem::Type* ret_type,
sem::Statement* current_statement) const {
if (func_type->UnwrapRef() != ret_type) { if (func_type->UnwrapRef() != ret_type) {
AddError( AddError(
"return statement type must match its function " "return statement type must match its function "
@ -2215,7 +2309,8 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret,
} }
auto* sem = sem_.Get(ret); auto* sem = sem_.Get(ret);
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { if (auto* continuing =
ClosestContinuing(/*stop_at_loop*/ false, current_statement)) {
AddError("continuing blocks must not contain a return statement", AddError("continuing blocks must not contain a return statement",
ret->source); ret->source);
if (continuing != sem->Declaration() && if (continuing != sem->Declaration() &&
@ -2228,7 +2323,7 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret,
return true; return true;
} }
bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
auto* cond_ty = sem_.TypeOf(s->condition)->UnwrapRef(); auto* cond_ty = sem_.TypeOf(s->condition)->UnwrapRef();
if (!cond_ty->is_integer_scalar()) { if (!cond_ty->is_integer_scalar()) {
AddError( AddError(
@ -2284,8 +2379,8 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
return true; return true;
} }
bool Resolver::ValidateAssignment(const ast::Statement* a, bool Validator::Assignment(const ast::Statement* a,
const sem::Type* rhs_ty) const { const sem::Type* rhs_ty) const {
const ast::Expression* lhs; const ast::Expression* lhs;
const ast::Expression* rhs; const ast::Expression* rhs;
if (auto* assign = a->As<ast::AssignmentStatement>()) { if (auto* assign = a->As<ast::AssignmentStatement>()) {
@ -2317,19 +2412,17 @@ bool Resolver::ValidateAssignment(const ast::Statement* a,
// https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
auto const* lhs_ty = sem_.TypeOf(lhs); auto const* lhs_ty = sem_.TypeOf(lhs);
if (auto* var = ResolvedSymbol<sem::Variable>(lhs)) { if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) { if (var->Is<sem::Parameter>()) {
AddError("cannot assign to function parameter", lhs->source); AddError("cannot assign to function parameter", lhs->source);
AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
"' is declared here:",
decl->source); decl->source);
return false; return false;
} }
if (decl->is_const) { if (decl->is_const) {
AddError("cannot assign to const", lhs->source); AddError("cannot assign to const", lhs->source);
AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
"' is declared here:",
decl->source); decl->source);
return false; return false;
} }
@ -2366,25 +2459,23 @@ bool Resolver::ValidateAssignment(const ast::Statement* a,
return true; return true;
} }
bool Resolver::ValidateIncrementDecrementStatement( bool Validator::IncrementDecrementStatement(
const ast::IncrementDecrementStatement* inc) const { const ast::IncrementDecrementStatement* inc) const {
const ast::Expression* lhs = inc->lhs; const ast::Expression* lhs = inc->lhs;
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
if (auto* var = ResolvedSymbol<sem::Variable>(lhs)) { if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) { if (var->Is<sem::Parameter>()) {
AddError("cannot modify function parameter", lhs->source); AddError("cannot modify function parameter", lhs->source);
AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
"' is declared here:",
decl->source); decl->source);
return false; return false;
} }
if (decl->is_const) { if (decl->is_const) {
AddError("cannot modify constant value", lhs->source); AddError("cannot modify constant value", lhs->source);
AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
"' is declared here:",
decl->source); decl->source);
return false; return false;
} }
@ -2415,7 +2506,7 @@ bool Resolver::ValidateIncrementDecrementStatement(
return true; return true;
} }
bool Resolver::ValidateNoDuplicateAttributes( bool Validator::NoDuplicateAttributes(
const ast::AttributeList& attributes) const { const ast::AttributeList& attributes) const {
std::unordered_map<const TypeInfo*, Source> seen; std::unordered_map<const TypeInfo*, Source> seen;
for (auto* d : attributes) { for (auto* d : attributes) {
@ -2429,8 +2520,8 @@ bool Resolver::ValidateNoDuplicateAttributes(
return true; return true;
} }
bool Resolver::IsValidationDisabled(const ast::AttributeList& attributes, bool Validator::IsValidationDisabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const { ast::DisabledValidation validation) const {
for (auto* attribute : attributes) { for (auto* attribute : attributes) {
if (auto* dv = attribute->As<ast::DisableValidationAttribute>()) { if (auto* dv = attribute->As<ast::DisableValidationAttribute>()) {
if (dv->validation == validation) { if (dv->validation == validation) {
@ -2441,15 +2532,15 @@ bool Resolver::IsValidationDisabled(const ast::AttributeList& attributes,
return false; return false;
} }
bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes, bool Validator::IsValidationEnabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const { ast::DisabledValidation validation) const {
return !IsValidationDisabled(attributes, validation); return !IsValidationDisabled(attributes, validation);
} }
std::string Resolver::VectorPretty(uint32_t size, std::string Validator::VectorPretty(uint32_t size,
const sem::Type* element_type) const { const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size); sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols()); return vec_type.FriendlyName(symbols_);
} }
} // namespace tint::resolver } // namespace tint::resolver

View File

@ -0,0 +1,457 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_RESOLVER_VALIDATOR_H_
#define SRC_TINT_RESOLVER_VALIDATOR_H_
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/tint/ast/pipeline_stage.h"
#include "src/tint/program_builder.h"
#include "src/tint/resolver/sem_helper.h"
#include "src/tint/source.h"
// Forward declarations
namespace tint::ast {
class IndexAccessorExpression;
class BinaryExpression;
class BitcastExpression;
class CallExpression;
class CallStatement;
class CaseStatement;
class ForLoopStatement;
class Function;
class IdentifierExpression;
class LoopStatement;
class MemberAccessorExpression;
class ReturnStatement;
class SwitchStatement;
class UnaryOpExpression;
class Variable;
} // namespace tint::ast
namespace tint::sem {
class Array;
class Atomic;
class BlockStatement;
class Builtin;
class CaseStatement;
class ElseStatement;
class ForLoopStatement;
class IfStatement;
class LoopStatement;
class Statement;
class SwitchStatement;
class TypeConstructor;
} // namespace tint::sem
namespace tint::resolver {
/// Validation logic for various ast nodes. The validations in general should
/// be shallow and depend on the resolver to call on children. The validations
/// also assume that sem changes have already been made. The validation checks
/// should not alter the AST or SEM trees.
class Validator {
public:
/// The valid type storage layouts typedef
using ValidTypeStorageLayouts =
std::set<std::pair<const sem::Type*, ast::StorageClass>>;
/// Constructor
/// @param builder the program builder
/// @param helper the SEM helper to validate with
Validator(ProgramBuilder* builder, SemHelper& helper);
~Validator();
/// Adds the given error message to the diagnostics
/// @param msg the error message
/// @param source the error source
void AddError(const std::string& msg, const Source& source) const;
/// Adds the given warning message to the diagnostics
/// @param msg the warning message
/// @param source the warning source
void AddWarning(const std::string& msg, const Source& source) const;
/// Adds the given note message to the diagnostics
/// @param msg the note message
/// @param source the note source
void AddNote(const std::string& msg, const Source& source) const;
/// @param type the given type
/// @returns true if the given type is a plain type
bool IsPlain(const sem::Type* type) const;
/// @param type the given type
/// @returns true if the given type is a fixed-footprint type
bool IsFixedFootprint(const sem::Type* type) const;
/// @param type the given type
/// @returns true if the given type is storable
bool IsStorable(const sem::Type* type) const;
/// @param type the given type
/// @returns true if the given type is host-shareable
bool IsHostShareable(const sem::Type* type) const;
/// Validates pipeline stages
/// @param entry_points the entry points to the module
/// @returns true on success, false otherwise.
bool PipelineStages(const std::vector<sem::Function*>& entry_points) const;
/// Validates aliases
/// @param alias the alias to validate
/// @returns true on success, false otherwise.
bool Alias(const ast::Alias* alias) const;
/// Validates the array
/// @param arr the array to validate
/// @param source the source of the array
/// @returns true on success, false otherwise.
bool Array(const sem::Array* arr, const Source& source) const;
/// Validates an array stride attribute
/// @param attr the stride attribute to validate
/// @param el_size the element size
/// @param el_align the element alignment
/// @param source the source of the attribute
/// @returns true on success, false otherwise
bool ArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size,
uint32_t el_align,
const Source& source) const;
/// Validates an atomic
/// @param a the atomic ast node to validate
/// @param s the atomic sem node
/// @returns true on success, false otherwise.
bool Atomic(const ast::Atomic* a, const sem::Atomic* s) const;
/// Validates an atoic variable
/// @param var the variable to validate
/// @param atomic_composite_info store atomic information
/// @returns true on success, false otherwise.
bool AtomicVariable(const sem::Variable* var,
std::unordered_map<const sem::Type*, const Source&>
atomic_composite_info) const;
/// Validates an assignment
/// @param a the assignment statement
/// @param rhs_ty the type of the right hand side
/// @returns true on success, false otherwise.
bool Assignment(const ast::Statement* a, const sem::Type* rhs_ty) const;
/// Validates a bitcase
/// @param cast the bitcast expression
/// @param to the destination type
/// @returns true on success, false otherwise
bool Bitcast(const ast::BitcastExpression* cast, const sem::Type* to) const;
/// Validates a break statement
/// @param stmt the break statement to validate
/// @param current_statement the current statement being resolved
/// @returns true on success, false otherwise.
bool BreakStatement(const sem::Statement* stmt,
sem::Statement* current_statement) const;
/// Validates a builtin attribute
/// @param attr the attribute to validate
/// @param storage_type the attribute storage type
/// @param stage the current pipeline stage
/// @param is_input true if this is an input attribute
/// @returns true on success, false otherwise.
bool BuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_type,
ast::PipelineStage stage,
const bool is_input) const;
/// Validates a continue statement
/// @param stmt the continue statement to validate
/// @param current_statement the current statement being resolved
/// @returns true on success, false otherwise
bool ContinueStatement(const sem::Statement* stmt,
sem::Statement* current_statement) const;
/// Validates a discard statement
/// @param stmt the statement to validate
/// @param current_statement the current statement being resolved
/// @returns true on success, false otherwise
bool DiscardStatement(const sem::Statement* stmt,
sem::Statement* current_statement) const;
/// Validates an else statement
/// @param stmt the else statement to validate
/// @returns true on success, false otherwise
bool ElseStatement(const sem::ElseStatement* stmt) const;
/// Validates an entry point
/// @param func the entry point function to validate
/// @param stage the pipeline stage for the entry point
/// @returns true on success, false otherwise
bool EntryPoint(const sem::Function* func, ast::PipelineStage stage) const;
/// Validates a for loop
/// @param stmt the for loop statement to validate
/// @returns true on success, false otherwise
bool ForLoopStatement(const sem::ForLoopStatement* stmt) const;
/// Validates a fallthrough statement
/// @param stmt the fallthrough to validate
/// @returns true on success, false otherwise
bool FallthroughStatement(const sem::Statement* stmt) const;
/// Validates a function
/// @param func the function to validate
/// @param stage the current pipeline stage
/// @returns true on success, false otherwise.
bool Function(const sem::Function* func, ast::PipelineStage stage) const;
/// Validates a function call
/// @param call the function call to validate
/// @param current_statement the current statement being resolved
/// @returns true on success, false otherwise
bool FunctionCall(const sem::Call* call,
sem::Statement* current_statement) const;
/// Validates a global variable
/// @param var the global variable to validate
/// @param constant_ids the set of constant ids in the module
/// @param atomic_composite_info atomic composite info in the module
/// @returns true on success, false otherwise
bool GlobalVariable(
const sem::Variable* var,
std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info)
const;
/// Validates an if statement
/// @param stmt the statement to validate
/// @returns true on success, false otherwise
bool IfStatement(const sem::IfStatement* stmt) const;
/// Validates an increment or decrement statement
/// @param stmt the statement to validate
/// @returns true on success, false otherwise
bool IncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt) const;
/// Validates an interpolate attribute
/// @param attr the interpolation attribute to validate
/// @param storage_type the storage type of the attached variable
/// @returns true on succes, false otherwise
bool InterpolateAttribute(const ast::InterpolateAttribute* attr,
const sem::Type* storage_type) const;
/// Validates a builtin call
/// @param call the builtin call to validate
/// @returns true on success, false otherwise.
bool BuiltinCall(const sem::Call* call) const;
/// Validates a location attribute
/// @param location the location attribute to validate
/// @param type the variable type
/// @param locations the set of locations in the module
/// @param stage the current pipeline stage
/// @param source the source of the attribute
/// @param is_input true if this is an input variable
/// @returns true on success, false otherwise.
bool LocationAttribute(const ast::LocationAttribute* location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false) const;
/// Validates a loop statement
/// @param stmt the loop statement
/// @returns true on success, false otherwise.
bool LoopStatement(const sem::LoopStatement* stmt) const;
/// Validates a matrix
/// @param ty the matrix to validate
/// @param source the source of the matrix
/// @returns true on success, false otherwise
bool Matrix(const sem::Matrix* ty, const Source& source) const;
/// Validates a function parameter
/// @param func the function the variable is for
/// @param var the variable to validate
/// @returns true on success, false otherwise
bool FunctionParameter(const ast::Function* func,
const sem::Variable* var) const;
/// Validates a return
/// @param ret the return statement to validate
/// @param func_type the return type of the curreunt function
/// @param ret_type the return type
/// @param current_statement the current statement being resolved
/// @returns true on success, false otherwise
bool Return(const ast::ReturnStatement* ret,
const sem::Type* func_type,
const sem::Type* ret_type,
sem::Statement* current_statement) const;
/// Validates a list of statements
/// @param stmts the statements to validate
/// @returns true on success, false otherwise
bool Statements(const ast::StatementList& stmts) const;
/// Validates a storage texture
/// @param t the texture to validate
/// @returns true on success, false otherwise
bool StorageTexture(const ast::StorageTexture* t) const;
/// Validates a structure
/// @param str the structure to validate
/// @param stage the current pipeline stage
/// @returns true on success, false otherwise.
bool Structure(const sem::Struct* str, ast::PipelineStage stage) const;
/// Validates a structure constructor or cast
/// @param ctor the call expression to validate
/// @param struct_type the type of the structure
/// @returns true on success, false otherwise
bool StructureConstructorOrCast(const ast::CallExpression* ctor,
const sem::Struct* struct_type) const;
/// Validates a switch statement
/// @param s the switch to validate
/// @returns true on success, false otherwise
bool SwitchStatement(const ast::SwitchStatement* s);
/// Validates a variable
/// @param var the variable to validate
/// @returns true on success, false otherwise.
bool Variable(const sem::Variable* var) const;
/// Validates a variable constructor or cast
/// @param var the variable to validate
/// @param storage_class the storage class of the variable
/// @param storage_type the type of the storage
/// @param rhs_type the right hand side of the expression
/// @returns true on succes, false otherwise
bool VariableConstructorOrCast(const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_type,
const sem::Type* rhs_type) const;
/// Validates a vector
/// @param ty the vector to validate
/// @param source the source of the vector
/// @returns true on success, false otherwise
bool Vector(const sem::Vector* ty, const Source& source) const;
/// Validates a vector constructor or cast
/// @param ctor the call expression to validate
/// @param vec_type the vector type
/// @returns true on success, false otherwise
bool VectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type) const;
/// Validates a matrix constructor or cast
/// @param ctor the call expression to validate
/// @param matrix_type the type of the matrix
/// @returns true on success, false otherwise
bool MatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_type) const;
/// Validates a scalar constructor or cast
/// @param ctor the call expression to validate
/// @param type the type of the scalar
/// @returns true on success, false otherwise.
bool ScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* type) const;
/// Validates an array constructor or cast
/// @param ctor the call expresion to validate
/// @param arr_type the type of the array
/// @returns true on success, false otherwise
bool ArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* arr_type) const;
/// Validates a texture builtin function
/// @param call the builtin call to validate
/// @returns true on success, false otherwise
bool TextureBuiltinFunction(const sem::Call* call) const;
/// Validates there are no duplicate attributes
/// @param attributes the list of attributes to validate
/// @returns true on success, false otherwise.
bool NoDuplicateAttributes(const ast::AttributeList& attributes) const;
/// Validates a storage class layout
/// @param type the type to validate
/// @param sc the storage class
/// @param source the source of the type
/// @param layouts previously validated storage layouts
/// @returns true on success, false otherwise
bool StorageClassLayout(const sem::Type* type,
ast::StorageClass sc,
Source source,
ValidTypeStorageLayouts& layouts) const;
/// Validates a storage class layout
/// @param var the variable to validate
/// @param layouts previously validated storage layouts
/// @returns true on success, false otherwise.
bool StorageClassLayout(const sem::Variable* var,
ValidTypeStorageLayouts& layouts) const;
/// @returns true if the attribute list contains a
/// ast::DisableValidationAttribute with the validation mode equal to
/// `validation`
/// @param attributes the attribute list to check
/// @param validation the validation mode to check
bool IsValidationDisabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
/// @returns true if the attribute list does not contains a
/// ast::DisableValidationAttribute with the validation mode equal to
/// `validation`
/// @param attributes the attribute list to check
/// @param validation the validation mode to check
bool IsValidationEnabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
private:
/// Searches the current statement and up through parents of the current
/// statement looking for a loop or for-loop continuing statement.
/// @returns the closest continuing statement to the current statement that
/// (transitively) owns the current statement.
/// @param stop_at_loop if true then the function will return nullptr if a
/// loop or for-loop was found before the continuing.
/// @param current_statement the current statement being resolved
const ast::Statement* ClosestContinuing(
bool stop_at_loop,
sem::Statement* current_statement) const;
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
SymbolTable& symbols_;
diag::List& diagnostics_;
SemHelper& sem_;
};
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_VALIDATOR_H_

View File

@ -0,0 +1,86 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/validator.h"
#include "gmock/gmock.h"
#include "src/tint/resolver/validator_test_helper.h"
#include "src/tint/sem/atomic_type.h"
namespace tint::resolver {
namespace {
using ValidatorIsStorableTest = ValidatorTest;
TEST_F(ValidatorIsStorableTest, Void) {
EXPECT_FALSE(v()->IsStorable(create<sem::Void>()));
}
TEST_F(ValidatorIsStorableTest, Scalar) {
EXPECT_TRUE(v()->IsStorable(create<sem::Bool>()));
EXPECT_TRUE(v()->IsStorable(create<sem::I32>()));
EXPECT_TRUE(v()->IsStorable(create<sem::U32>()));
EXPECT_TRUE(v()->IsStorable(create<sem::F32>()));
}
TEST_F(ValidatorIsStorableTest, Vector) {
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::I32>(), 2u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::I32>(), 3u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::I32>(), 4u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::U32>(), 2u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::U32>(), 3u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::U32>(), 4u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::F32>(), 2u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::F32>(), 3u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::F32>(), 4u)));
}
TEST_F(ValidatorIsStorableTest, Matrix) {
auto* vec2 = create<sem::Vector>(create<sem::F32>(), 2u);
auto* vec3 = create<sem::Vector>(create<sem::F32>(), 3u);
auto* vec4 = create<sem::Vector>(create<sem::F32>(), 4u);
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec2, 2u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec2, 3u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec2, 4u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec3, 2u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec3, 3u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec3, 4u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec4, 2u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec4, 3u)));
EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec4, 4u)));
}
TEST_F(ValidatorIsStorableTest, Pointer) {
auto* ptr = create<sem::Pointer>(
create<sem::I32>(), ast::StorageClass::kPrivate, ast::Access::kReadWrite);
EXPECT_FALSE(v()->IsStorable(ptr));
}
TEST_F(ValidatorIsStorableTest, Atomic) {
EXPECT_TRUE(v()->IsStorable(create<sem::Atomic>(create<sem::I32>())));
EXPECT_TRUE(v()->IsStorable(create<sem::Atomic>(create<sem::U32>())));
}
TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
} // namespace
} // namespace tint::resolver

View File

@ -0,0 +1,27 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/validator_test_helper.h"
#include <memory>
namespace tint::resolver {
TestHelper::TestHelper()
: validator_(
std::make_unique<Validator>(this->Symbols(), this->Diagnostics())) {}
TestHelper::~TestHelper() = default;
} // namespace tint::resolver

View File

@ -0,0 +1,46 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_
#define SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_
#include <memory>
#include "gtest/gtest.h"
#include "src/tint/program_builder.h"
#include "src/tint/resolver/validator.h"
namespace tint::resolver {
/// Helper class for testing
class TestHelper : public ProgramBuilder {
public:
/// Constructor
TestHelper();
/// Destructor
~TestHelper() override;
/// @return a pointer to the Validator
Validator* v() const { return validator_.get(); }
private:
std::unique_ptr<Validator> validator_;
};
class ValidatorTest : public TestHelper, public testing::Test {};
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_