resolver: DepGraph - Traverse types and decorations.

These also need to depend on types / values.

Bug: tint:819
Bug: tint:1266
Change-Id: Ia044d7823aca845dc57a887a164e07137d913429
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/70522
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-11-23 18:40:27 +00:00 committed by Tint LUCI CQ
parent 1185d61648
commit b93ba6ead5
3 changed files with 318 additions and 52 deletions

View File

@ -26,6 +26,7 @@
#include "src/scope_stack.h" #include "src/scope_stack.h"
#include "src/sem/intrinsic.h" #include "src/sem/intrinsic.h"
#include "src/utils/defer.h" #include "src/utils/defer.h"
#include "src/utils/map.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
#include "src/utils/unique_vector.h" #include "src/utils/unique_vector.h"
@ -87,11 +88,10 @@ struct Global {
/// A map of global name to Global /// A map of global name to Global
using GlobalMap = std::unordered_map<Symbol, Global*>; using GlobalMap = std::unordered_map<Symbol, Global*>;
/// Raises an ICE that a global ast::Node declaration type was not handled by /// Raises an ICE that a global ast::Node type was not handled by this system.
/// this system. void UnhandledNode(diag::List& diagnostics, const ast::Node* node) {
void UnhandledDecl(diag::List& diagnostics, const ast::Node* node) { TINT_ICE(Resolver, diagnostics)
TINT_UNREACHABLE(Resolver, diagnostics) << "unhandled node type: " << node->TypeInfo().name;
<< "unhandled global declaration: " << node->TypeInfo().name;
} }
/// Raises an error diagnostic with the given message and source. /// Raises an error diagnostic with the given message and source.
@ -143,55 +143,59 @@ class DependencyScanner {
if (auto* str = global->node->As<ast::Struct>()) { if (auto* str = global->node->As<ast::Struct>()) {
Declare(str->name, str); Declare(str->name, str);
for (auto* member : str->members) { for (auto* member : str->members) {
ResolveTypeDependency(member->type); TraverseType(member->type);
} }
return; return;
} }
if (auto* alias = global->node->As<ast::Alias>()) { if (auto* alias = global->node->As<ast::Alias>()) {
Declare(alias->name, alias); Declare(alias->name, alias);
ResolveTypeDependency(alias->type); TraverseType(alias->type);
return; return;
} }
if (auto* func = global->node->As<ast::Function>()) { if (auto* func = global->node->As<ast::Function>()) {
Declare(func->symbol, func); Declare(func->symbol, func);
TraverseDecorations(func->decorations);
TraverseFunction(func); TraverseFunction(func);
return; return;
} }
if (auto* var = global->node->As<ast::Variable>()) { if (auto* var = global->node->As<ast::Variable>()) {
Declare(var->symbol, var); Declare(var->symbol, var);
ResolveTypeDependency(var->type); TraverseType(var->type);
if (var->constructor) { if (var->constructor) {
TraverseExpression(var->constructor); TraverseExpression(var->constructor);
} }
return; return;
} }
UnhandledDecl(diagnostics_, global->node); UnhandledNode(diagnostics_, global->node);
} }
private: private:
/// Traverses the function determining global dependencies. /// Traverses the function, performing symbol resolution and determining
/// global dependencies.
void TraverseFunction(const ast::Function* func) { void TraverseFunction(const ast::Function* func) {
scope_stack_.Push(); scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop()); TINT_DEFER(scope_stack_.Pop());
for (auto* param : func->params) { for (auto* param : func->params) {
Declare(param->symbol, param); Declare(param->symbol, param);
ResolveTypeDependency(param->type); TraverseType(param->type);
} }
if (func->body) { if (func->body) {
TraverseStatements(func->body->statements); TraverseStatements(func->body->statements);
} }
ResolveTypeDependency(func->return_type); TraverseType(func->return_type);
} }
/// Traverses the statements determining global dependencies. /// Traverses the statements, performing symbol resolution and determining
/// global dependencies.
void TraverseStatements(const ast::StatementList& stmts) { void TraverseStatements(const ast::StatementList& stmts) {
for (auto* s : stmts) { for (auto* s : stmts) {
TraverseStatement(s); TraverseStatement(s);
} }
} }
/// Traverses the statement determining global dependencies. /// Traverses the statement, performing symbol resolution and determining
/// global dependencies.
void TraverseStatement(const ast::Statement* stmt) { void TraverseStatement(const ast::Statement* stmt) {
if (stmt == nullptr) { if (stmt == nullptr) {
return; return;
@ -253,7 +257,7 @@ class DependencyScanner {
} }
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
Declare(v->variable->symbol, v->variable); Declare(v->variable->symbol, v->variable);
ResolveTypeDependency(v->variable->type); TraverseType(v->variable->type);
TraverseExpression(v->variable->constructor); TraverseExpression(v->variable->constructor);
return; return;
} }
@ -262,9 +266,7 @@ class DependencyScanner {
return; return;
} }
AddError(diagnostics_, UnhandledNode(diagnostics_, stmt);
"unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
} }
/// Adds the symbol definition to the current scope, raising an error if two /// Adds the symbol definition to the current scope, raising an error if two
@ -279,7 +281,8 @@ class DependencyScanner {
} }
} }
/// Traverses the expression determining global dependencies. /// Traverses the expression, performing symbol resolution and determining
/// global dependencies.
void TraverseExpression(const ast::Expression* root) { void TraverseExpression(const ast::Expression* root) {
if (!root) { if (!root) {
return; return;
@ -309,24 +312,101 @@ class DependencyScanner {
ResolveGlobalDependency(call->target.name, ResolveGlobalDependency(call->target.name,
call->target.name->symbol, "function", call->target.name->symbol, "function",
"calls"); "calls");
graph_.resolved_symbols.emplace(
call,
utils::Lookup(graph_.resolved_symbols, call->target.name));
} }
} }
if (call->target.type) { if (call->target.type) {
ResolveTypeDependency(call->target.type); TraverseType(call->target.type);
graph_.resolved_symbols.emplace(
call,
utils::Lookup(graph_.resolved_symbols, call->target.type));
} }
} }
return ast::TraverseAction::Descend; return ast::TraverseAction::Descend;
}); });
} }
/// Adds the type dependency to the currently processed global /// Traverses the type node, performing symbol resolution and determining
void ResolveTypeDependency(const ast::Type* ty) { /// global dependencies.
void TraverseType(const ast::Type* ty) {
if (ty == nullptr) { if (ty == nullptr) {
return; return;
} }
if (auto* arr = ty->As<ast::Array>()) {
TraverseType(arr->type);
TraverseExpression(arr->count);
return;
}
if (auto* atomic = ty->As<ast::Atomic>()) {
TraverseType(atomic->type);
return;
}
if (auto* mat = ty->As<ast::Matrix>()) {
TraverseType(mat->type);
return;
}
if (auto* ptr = ty->As<ast::Pointer>()) {
TraverseType(ptr->type);
return;
}
if (auto* tn = ty->As<ast::TypeName>()) { if (auto* tn = ty->As<ast::TypeName>()) {
ResolveGlobalDependency(tn, tn->name, "type", "references"); ResolveGlobalDependency(tn, tn->name, "type", "references");
return;
} }
if (auto* vec = ty->As<ast::Vector>()) {
TraverseType(vec->type);
return;
}
if (auto* tex = ty->As<ast::SampledTexture>()) {
TraverseType(tex->type);
return;
}
if (auto* tex = ty->As<ast::MultisampledTexture>()) {
TraverseType(tex->type);
return;
}
if (ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
ast::DepthTexture, ast::DepthMultisampledTexture,
ast::StorageTexture, ast::ExternalTexture,
ast::Sampler>()) {
return;
}
UnhandledNode(diagnostics_, ty);
}
/// Traverses the decoration list, performing symbol resolution and
/// determining global dependencies.
void TraverseDecorations(const ast::DecorationList& decos) {
for (auto* deco : decos) {
TraverseDecoration(deco);
}
}
/// Traverses the decoration, performing symbol resolution and determining
/// global dependencies.
void TraverseDecoration(const ast::Decoration* deco) {
if (auto* wg = deco->As<ast::WorkgroupDecoration>()) {
TraverseExpression(wg->x);
TraverseExpression(wg->y);
TraverseExpression(wg->z);
return;
}
if (deco->IsAnyOf<ast::BindingDecoration, ast::BuiltinDecoration,
ast::GroupDecoration, ast::InternalDecoration,
ast::InterpolateDecoration, ast::InvariantDecoration,
ast::LocationDecoration, ast::OverrideDecoration,
ast::StageDecoration, ast::StrideDecoration,
ast::StructBlockDecoration,
ast::StructMemberAlignDecoration,
ast::StructMemberOffsetDecoration,
ast::StructMemberSizeDecoration>()) {
return;
}
UnhandledNode(diagnostics_, deco);
} }
/// Adds the dependency to the currently processed global /// Adds the dependency to the currently processed global
@ -426,7 +506,7 @@ struct DependencyAnalysis {
if (auto* var = node->As<ast::Variable>()) { if (auto* var = node->As<ast::Variable>()) {
return var->symbol; return var->symbol;
} }
UnhandledDecl(diagnostics_, node); UnhandledNode(diagnostics_, node);
return {}; return {};
} }
@ -455,7 +535,7 @@ struct DependencyAnalysis {
if (auto* var = node->As<ast::Variable>()) { if (auto* var = node->As<ast::Variable>()) {
return var->is_const ? "let" : "var"; return var->is_const ? "let" : "var";
} }
UnhandledDecl(diagnostics_, node); UnhandledNode(diagnostics_, node);
return {}; return {};
} }

View File

@ -101,14 +101,28 @@ static constexpr SymbolDeclKind kFuncDeclKinds[] = {
/// kinds of symbol uses. /// kinds of symbol uses.
enum class SymbolUseKind { enum class SymbolUseKind {
GlobalVarType, GlobalVarType,
GlobalVarArrayElemType,
GlobalVarArraySizeValue,
GlobalVarVectorElemType,
GlobalVarMatrixElemType,
GlobalVarSampledTexElemType,
GlobalVarMultisampledTexElemType,
GlobalVarValue, GlobalVarValue,
GlobalLetType, GlobalLetType,
GlobalLetArrayElemType,
GlobalLetArraySizeValue,
GlobalLetVectorElemType,
GlobalLetMatrixElemType,
GlobalLetValue, GlobalLetValue,
AliasType, AliasType,
StructMemberType, StructMemberType,
CallFunction, CallFunction,
ParameterType, ParameterType,
LocalVarType, LocalVarType,
LocalVarArrayElemType,
LocalVarArraySizeValue,
LocalVarVectorElemType,
LocalVarMatrixElemType,
LocalVarValue, LocalVarValue,
LocalLetType, LocalLetType,
LocalLetValue, LocalLetValue,
@ -116,13 +130,32 @@ enum class SymbolUseKind {
NestedLocalVarValue, NestedLocalVarValue,
NestedLocalLetType, NestedLocalLetType,
NestedLocalLetValue, NestedLocalLetValue,
WorkgroupSizeValue,
}; };
static constexpr SymbolUseKind kTypeUseKinds[] = { static constexpr SymbolUseKind kTypeUseKinds[] = {
SymbolUseKind::GlobalVarType, SymbolUseKind::GlobalLetType, SymbolUseKind::GlobalVarType,
SymbolUseKind::AliasType, SymbolUseKind::StructMemberType, SymbolUseKind::GlobalVarArrayElemType,
SymbolUseKind::ParameterType, SymbolUseKind::LocalVarType, SymbolUseKind::GlobalVarArraySizeValue,
SymbolUseKind::LocalLetType, SymbolUseKind::NestedLocalVarType, SymbolUseKind::GlobalVarVectorElemType,
SymbolUseKind::GlobalVarMatrixElemType,
SymbolUseKind::GlobalVarSampledTexElemType,
SymbolUseKind::GlobalVarMultisampledTexElemType,
SymbolUseKind::GlobalLetType,
SymbolUseKind::GlobalLetArrayElemType,
SymbolUseKind::GlobalLetArraySizeValue,
SymbolUseKind::GlobalLetVectorElemType,
SymbolUseKind::GlobalLetMatrixElemType,
SymbolUseKind::AliasType,
SymbolUseKind::StructMemberType,
SymbolUseKind::ParameterType,
SymbolUseKind::LocalVarType,
SymbolUseKind::LocalVarArrayElemType,
SymbolUseKind::LocalVarArraySizeValue,
SymbolUseKind::LocalVarVectorElemType,
SymbolUseKind::LocalVarMatrixElemType,
SymbolUseKind::LocalLetType,
SymbolUseKind::NestedLocalVarType,
SymbolUseKind::NestedLocalLetType, SymbolUseKind::NestedLocalLetType,
}; };
@ -130,6 +163,7 @@ static constexpr SymbolUseKind kValueUseKinds[] = {
SymbolUseKind::GlobalVarValue, SymbolUseKind::GlobalLetValue, SymbolUseKind::GlobalVarValue, SymbolUseKind::GlobalLetValue,
SymbolUseKind::LocalVarValue, SymbolUseKind::LocalLetValue, SymbolUseKind::LocalVarValue, SymbolUseKind::LocalLetValue,
SymbolUseKind::NestedLocalVarValue, SymbolUseKind::NestedLocalLetValue, SymbolUseKind::NestedLocalVarValue, SymbolUseKind::NestedLocalLetValue,
SymbolUseKind::WorkgroupSizeValue,
}; };
static constexpr SymbolUseKind kFuncUseKinds[] = { static constexpr SymbolUseKind kFuncUseKinds[] = {
@ -172,10 +206,30 @@ std::ostream& operator<<(std::ostream& out, SymbolUseKind kind) {
return out << "global var type"; return out << "global var type";
case SymbolUseKind::GlobalVarValue: case SymbolUseKind::GlobalVarValue:
return out << "global var value"; return out << "global var value";
case SymbolUseKind::GlobalVarArrayElemType:
return out << "global var array element type";
case SymbolUseKind::GlobalVarArraySizeValue:
return out << "global var array size value";
case SymbolUseKind::GlobalVarVectorElemType:
return out << "global var vector element type";
case SymbolUseKind::GlobalVarMatrixElemType:
return out << "global var matrix element type";
case SymbolUseKind::GlobalVarSampledTexElemType:
return out << "global var sampled_texture element type";
case SymbolUseKind::GlobalVarMultisampledTexElemType:
return out << "global var multisampled_texture element type";
case SymbolUseKind::GlobalLetType: case SymbolUseKind::GlobalLetType:
return out << "global let type"; return out << "global let type";
case SymbolUseKind::GlobalLetValue: case SymbolUseKind::GlobalLetValue:
return out << "global let value"; return out << "global let value";
case SymbolUseKind::GlobalLetArrayElemType:
return out << "global let array element type";
case SymbolUseKind::GlobalLetArraySizeValue:
return out << "global let array size value";
case SymbolUseKind::GlobalLetVectorElemType:
return out << "global let vector element type";
case SymbolUseKind::GlobalLetMatrixElemType:
return out << "global let matrix element type";
case SymbolUseKind::AliasType: case SymbolUseKind::AliasType:
return out << "alias type"; return out << "alias type";
case SymbolUseKind::StructMemberType: case SymbolUseKind::StructMemberType:
@ -186,6 +240,14 @@ std::ostream& operator<<(std::ostream& out, SymbolUseKind kind) {
return out << "parameter type"; return out << "parameter type";
case SymbolUseKind::LocalVarType: case SymbolUseKind::LocalVarType:
return out << "local var type"; return out << "local var type";
case SymbolUseKind::LocalVarArrayElemType:
return out << "local var array element type";
case SymbolUseKind::LocalVarArraySizeValue:
return out << "local var array size value";
case SymbolUseKind::LocalVarVectorElemType:
return out << "local var vector element type";
case SymbolUseKind::LocalVarMatrixElemType:
return out << "local var matrix element type";
case SymbolUseKind::LocalVarValue: case SymbolUseKind::LocalVarValue:
return out << "local var value"; return out << "local var value";
case SymbolUseKind::LocalLetType: case SymbolUseKind::LocalLetType:
@ -200,6 +262,8 @@ std::ostream& operator<<(std::ostream& out, SymbolUseKind kind) {
return out << "nested local let type"; return out << "nested local let type";
case SymbolUseKind::NestedLocalLetValue: case SymbolUseKind::NestedLocalLetValue:
return out << "nested local let value"; return out << "nested local let value";
case SymbolUseKind::WorkgroupSizeValue:
return out << "workgroup size value";
} }
return out << "<unknown>"; return out << "<unknown>";
} }
@ -208,21 +272,36 @@ std::ostream& operator<<(std::ostream& out, SymbolUseKind kind) {
std::string DiagString(SymbolUseKind kind) { std::string DiagString(SymbolUseKind kind) {
switch (kind) { switch (kind) {
case SymbolUseKind::GlobalVarType: case SymbolUseKind::GlobalVarType:
case SymbolUseKind::GlobalVarArrayElemType:
case SymbolUseKind::GlobalVarVectorElemType:
case SymbolUseKind::GlobalVarMatrixElemType:
case SymbolUseKind::GlobalVarSampledTexElemType:
case SymbolUseKind::GlobalVarMultisampledTexElemType:
case SymbolUseKind::GlobalLetType: case SymbolUseKind::GlobalLetType:
case SymbolUseKind::GlobalLetArrayElemType:
case SymbolUseKind::GlobalLetVectorElemType:
case SymbolUseKind::GlobalLetMatrixElemType:
case SymbolUseKind::AliasType: case SymbolUseKind::AliasType:
case SymbolUseKind::StructMemberType: case SymbolUseKind::StructMemberType:
case SymbolUseKind::ParameterType: case SymbolUseKind::ParameterType:
case SymbolUseKind::LocalVarType: case SymbolUseKind::LocalVarType:
case SymbolUseKind::LocalVarArrayElemType:
case SymbolUseKind::LocalVarVectorElemType:
case SymbolUseKind::LocalVarMatrixElemType:
case SymbolUseKind::LocalLetType: case SymbolUseKind::LocalLetType:
case SymbolUseKind::NestedLocalVarType: case SymbolUseKind::NestedLocalVarType:
case SymbolUseKind::NestedLocalLetType: case SymbolUseKind::NestedLocalLetType:
return "type"; return "type";
case SymbolUseKind::GlobalVarValue: case SymbolUseKind::GlobalVarValue:
case SymbolUseKind::GlobalVarArraySizeValue:
case SymbolUseKind::GlobalLetValue: case SymbolUseKind::GlobalLetValue:
case SymbolUseKind::GlobalLetArraySizeValue:
case SymbolUseKind::LocalVarValue: case SymbolUseKind::LocalVarValue:
case SymbolUseKind::LocalVarArraySizeValue:
case SymbolUseKind::LocalLetValue: case SymbolUseKind::LocalLetValue:
case SymbolUseKind::NestedLocalVarValue: case SymbolUseKind::NestedLocalVarValue:
case SymbolUseKind::NestedLocalLetValue: case SymbolUseKind::NestedLocalLetValue:
case SymbolUseKind::WorkgroupSizeValue:
return "identifier"; return "identifier";
case SymbolUseKind::CallFunction: case SymbolUseKind::CallFunction:
return "function"; return "function";
@ -259,14 +338,29 @@ int ScopeDepth(SymbolUseKind kind) {
switch (kind) { switch (kind) {
case SymbolUseKind::GlobalVarType: case SymbolUseKind::GlobalVarType:
case SymbolUseKind::GlobalVarValue: case SymbolUseKind::GlobalVarValue:
case SymbolUseKind::GlobalVarArrayElemType:
case SymbolUseKind::GlobalVarArraySizeValue:
case SymbolUseKind::GlobalVarVectorElemType:
case SymbolUseKind::GlobalVarMatrixElemType:
case SymbolUseKind::GlobalVarSampledTexElemType:
case SymbolUseKind::GlobalVarMultisampledTexElemType:
case SymbolUseKind::GlobalLetType: case SymbolUseKind::GlobalLetType:
case SymbolUseKind::GlobalLetValue: case SymbolUseKind::GlobalLetValue:
case SymbolUseKind::GlobalLetArrayElemType:
case SymbolUseKind::GlobalLetArraySizeValue:
case SymbolUseKind::GlobalLetVectorElemType:
case SymbolUseKind::GlobalLetMatrixElemType:
case SymbolUseKind::AliasType: case SymbolUseKind::AliasType:
case SymbolUseKind::StructMemberType: case SymbolUseKind::StructMemberType:
case SymbolUseKind::WorkgroupSizeValue:
return 0; return 0;
case SymbolUseKind::CallFunction: case SymbolUseKind::CallFunction:
case SymbolUseKind::ParameterType: case SymbolUseKind::ParameterType:
case SymbolUseKind::LocalVarType: case SymbolUseKind::LocalVarType:
case SymbolUseKind::LocalVarArrayElemType:
case SymbolUseKind::LocalVarArraySizeValue:
case SymbolUseKind::LocalVarVectorElemType:
case SymbolUseKind::LocalVarMatrixElemType:
case SymbolUseKind::LocalVarValue: case SymbolUseKind::LocalVarValue:
case SymbolUseKind::LocalLetType: case SymbolUseKind::LocalLetType:
case SymbolUseKind::LocalLetValue: case SymbolUseKind::LocalLetValue:
@ -290,6 +384,8 @@ struct SymbolTestHelper {
std::vector<const ast::Statement*> statements; std::vector<const ast::Statement*> statements;
/// Nested function local var / let declaration statements /// Nested function local var / let declaration statements
std::vector<const ast::Statement*> nested_statements; std::vector<const ast::Statement*> nested_statements;
/// Function decorations
ast::DecorationList func_decos;
/// Constructor /// Constructor
/// @param builder the program builder /// @param builder the program builder
@ -374,6 +470,38 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind,
b.Global(b.Sym(), node, ast::StorageClass::kPrivate); b.Global(b.Sym(), node, ast::StorageClass::kPrivate);
return node; return node;
} }
case SymbolUseKind::GlobalVarArrayElemType: {
auto* node = b.ty.type_name(source, symbol);
b.Global(b.Sym(), b.ty.array(node, 4), ast::StorageClass::kPrivate);
return node;
}
case SymbolUseKind::GlobalVarArraySizeValue: {
auto* node = b.Expr(source, symbol);
b.Global(b.Sym(), b.ty.array(b.ty.i32(), node),
ast::StorageClass::kPrivate);
return node;
}
case SymbolUseKind::GlobalVarVectorElemType: {
auto* node = b.ty.type_name(source, symbol);
b.Global(b.Sym(), b.ty.vec3(node), ast::StorageClass::kPrivate);
return node;
}
case SymbolUseKind::GlobalVarMatrixElemType: {
auto* node = b.ty.type_name(source, symbol);
b.Global(b.Sym(), b.ty.mat3x4(node), ast::StorageClass::kPrivate);
return node;
}
case SymbolUseKind::GlobalVarSampledTexElemType: {
auto* node = b.ty.type_name(source, symbol);
b.Global(b.Sym(), b.ty.sampled_texture(ast::TextureDimension::k2d, node));
return node;
}
case SymbolUseKind::GlobalVarMultisampledTexElemType: {
auto* node = b.ty.type_name(source, symbol);
b.Global(b.Sym(),
b.ty.multisampled_texture(ast::TextureDimension::k2d, node));
return node;
}
case SymbolUseKind::GlobalVarValue: { case SymbolUseKind::GlobalVarValue: {
auto* node = b.Expr(source, symbol); auto* node = b.Expr(source, symbol);
b.Global(b.Sym(), b.ty.i32(), ast::StorageClass::kPrivate, node); b.Global(b.Sym(), b.ty.i32(), ast::StorageClass::kPrivate, node);
@ -384,6 +512,26 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind,
b.GlobalConst(b.Sym(), node, b.Expr(1)); b.GlobalConst(b.Sym(), node, b.Expr(1));
return node; return node;
} }
case SymbolUseKind::GlobalLetArrayElemType: {
auto* node = b.ty.type_name(source, symbol);
b.GlobalConst(b.Sym(), b.ty.array(node, 4), b.Expr(1));
return node;
}
case SymbolUseKind::GlobalLetArraySizeValue: {
auto* node = b.Expr(source, symbol);
b.GlobalConst(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1));
return node;
}
case SymbolUseKind::GlobalLetVectorElemType: {
auto* node = b.ty.type_name(source, symbol);
b.GlobalConst(b.Sym(), b.ty.vec3(node), b.Expr(1));
return node;
}
case SymbolUseKind::GlobalLetMatrixElemType: {
auto* node = b.ty.type_name(source, symbol);
b.GlobalConst(b.Sym(), b.ty.mat3x4(node), b.Expr(1));
return node;
}
case SymbolUseKind::GlobalLetValue: { case SymbolUseKind::GlobalLetValue: {
auto* node = b.Expr(source, symbol); auto* node = b.Expr(source, symbol);
b.GlobalConst(b.Sym(), b.ty.i32(), node); b.GlobalConst(b.Sym(), b.ty.i32(), node);
@ -414,6 +562,28 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind,
statements.emplace_back(b.Decl(b.Var(b.Sym(), node))); statements.emplace_back(b.Decl(b.Var(b.Sym(), node)));
return node; return node;
} }
case SymbolUseKind::LocalVarArrayElemType: {
auto* node = b.ty.type_name(source, symbol);
statements.emplace_back(
b.Decl(b.Var(b.Sym(), b.ty.array(node, 4), b.Expr(1))));
return node;
}
case SymbolUseKind::LocalVarArraySizeValue: {
auto* node = b.Expr(source, symbol);
statements.emplace_back(
b.Decl(b.Var(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1))));
return node;
}
case SymbolUseKind::LocalVarVectorElemType: {
auto* node = b.ty.type_name(source, symbol);
statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.vec3(node))));
return node;
}
case SymbolUseKind::LocalVarMatrixElemType: {
auto* node = b.ty.type_name(source, symbol);
statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.mat3x4(node))));
return node;
}
case SymbolUseKind::LocalVarValue: { case SymbolUseKind::LocalVarValue: {
auto* node = b.Expr(source, symbol); auto* node = b.Expr(source, symbol);
statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.i32(), node))); statements.emplace_back(b.Decl(b.Var(b.Sym(), b.ty.i32(), node)));
@ -450,6 +620,11 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind,
b.Decl(b.Const(b.Sym(), b.ty.i32(), node))); b.Decl(b.Const(b.Sym(), b.ty.i32(), node)));
return node; return node;
} }
case SymbolUseKind::WorkgroupSizeValue: {
auto* node = b.Expr(source, symbol);
func_decos.emplace_back(b.WorkgroupSize(1, node, 2));
return node;
}
} }
return nullptr; return nullptr;
} }
@ -460,10 +635,11 @@ void SymbolTestHelper::Build() {
statements.emplace_back(b.Block(nested_statements)); statements.emplace_back(b.Block(nested_statements));
nested_statements.clear(); nested_statements.clear();
} }
if (!parameters.empty() || !statements.empty()) { if (!parameters.empty() || !statements.empty() || !func_decos.empty()) {
b.Func("func", parameters, b.ty.void_(), statements); b.Func("func", parameters, b.ty.void_(), statements, func_decos);
parameters.clear(); parameters.clear();
statements.clear(); statements.clear();
func_decos.clear();
} }
} }
@ -998,9 +1174,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
Structure(Sym(), {Member(Sym(), T)}); Structure(Sym(), {Member(Sym(), T)});
Global(Sym(), T, V); Global(Sym(), T, V);
GlobalConst(Sym(), T, V); GlobalConst(Sym(), T, V);
Func(Sym(), // Func(Sym(), //
{Param("p", T)}, // {Param(Sym(), T)}, //
T, // Return type T, // Return type
{ {
Decl(Var(Sym(), T, V)), // Decl(Var(Sym(), T, V)), //
Decl(Const(Sym(), T, V)), // Decl(Const(Sym(), T, V)), //
@ -1027,7 +1203,27 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
Return(V), // Return(V), //
Break(), // Break(), //
Discard(), // Discard(), //
}); }); //
// Exercise type traversal
Global(Sym(), ty.atomic(T));
Global(Sym(), ty.bool_());
Global(Sym(), ty.i32());
Global(Sym(), ty.u32());
Global(Sym(), ty.f32());
Global(Sym(), ty.array(T, V, 4));
Global(Sym(), ty.vec3(T));
Global(Sym(), ty.mat3x2(T));
Global(Sym(), ty.pointer(T, ast::StorageClass::kPrivate));
Global(Sym(), ty.sampled_texture(ast::TextureDimension::k2d, T));
Global(Sym(), ty.depth_texture(ast::TextureDimension::k2d));
Global(Sym(), ty.depth_multisampled_texture(ast::TextureDimension::k2d));
Global(Sym(), ty.external_texture());
Global(Sym(), ty.multisampled_texture(ast::TextureDimension::k2d, T));
Global(Sym(), ty.storage_texture(ast::TextureDimension::k2d,
ast::ImageFormat::kR16Float,
ast::Access::kRead)); //
Global(Sym(), ty.sampler(ast::SamplerKind::kSampler));
Func(Sym(), {}, ty.void_(), {});
#undef V #undef V
#undef T #undef T
#undef F #undef F

View File

@ -104,24 +104,14 @@ TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) {
9:10 note: called by entry point 'f0')"); 9:10 note: called by entry point 'f0')");
} }
TEST_F(ResolverValidationTest, Error_WithEmptySource) { TEST_F(ResolverValidationTest, UnhandledStmt) {
auto* s = create<FakeStmt>(); EXPECT_FATAL_FAILURE(
WrapInFunction(s); {
ProgramBuilder b;
EXPECT_FALSE(r()->Resolve()); b.WrapInFunction(b.create<FakeStmt>());
Program(std::move(b));
EXPECT_EQ(r()->error(), },
"error: unknown statement type: tint::resolver::FakeStmt"); "internal compiler error: unhandled node type: tint::resolver::FakeStmt");
}
TEST_F(ResolverValidationTest, Stmt_Error_Unknown) {
auto* s = create<FakeStmt>(Source{Source::Location{2, 30}});
WrapInFunction(s);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"2:30 error: unknown statement type: tint::resolver::FakeStmt");
} }
TEST_F(ResolverValidationTest, Stmt_If_NonBool) { TEST_F(ResolverValidationTest, Stmt_If_NonBool) {