tint: Improve the output of DependencyGraph

Add ResolvedIdentifier to hold the resolved AST node, sem::BuiltinType
or type::Builtin.

Reduces duplicate builtin symbol lookups in Resolver.

Bug: tint:1810
Change-Id: Idde2b5f6fa22804b5019adc14c717bebd8342475
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/119041
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2023-02-08 15:18:43 +00:00 committed by Dawn LUCI CQ
parent 9e36723497
commit cf0e9301b2
12 changed files with 704 additions and 294 deletions

View File

@ -879,8 +879,7 @@ class ProgramBuilder {
/// @returns the type name
template <typename NAME, typename... ARGS, typename _ = DisableIfSource<NAME>>
const ast::TypeName* operator()(NAME&& name, ARGS&&... args) const {
return builder->create<ast::TypeName>(
builder->Ident(std::forward<NAME>(name), std::forward<ARGS>(args)...));
return (*this)(builder->source_, std::forward<NAME>(name), std::forward<ARGS>(args)...);
}
/// Creates a type name
@ -891,7 +890,8 @@ class ProgramBuilder {
template <typename NAME, typename... ARGS>
const ast::TypeName* operator()(const Source& source, NAME&& name, ARGS&&... args) const {
return builder->create<ast::TypeName>(
source, builder->Ident(std::forward<NAME>(name), std::forward<ARGS>(args)...));
source,
builder->Ident(source, std::forward<NAME>(name), std::forward<ARGS>(args)...));
}
/// Creates an alias type

View File

@ -210,7 +210,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsAliasUsedAsVariable) {
WrapInFunction(Decl(Var("v", Expr(Source{{56, 78}}, "mix"))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: missing '(' for builtin call)");
EXPECT_EQ(r()->error(), R"(56:78 error: missing '(' for type initializer or cast)");
}
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsAliasUsedAsType) {
@ -242,7 +242,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsStructUsedAsVariable) {
WrapInFunction(Decl(Var("v", Expr(Source{{12, 34}}, "mix"))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: missing '(' for builtin call)");
EXPECT_EQ(r()->error(), R"(12:34 error: missing '(' for type initializer or cast)");
}
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsStructUsedAsType) {

View File

@ -354,7 +354,8 @@ class DependencyScanner {
Switch(
expr,
[&](const ast::IdentifierExpression* ident) {
AddDependency(ident, ident->identifier->symbol, "identifier", "references");
AddDependency(ident->identifier, ident->identifier->symbol, "identifier",
"references");
},
[&](const ast::CallExpression* call) {
if (call->target.name) {
@ -392,7 +393,7 @@ class DependencyScanner {
TraverseType(ptr->type);
},
[&](const ast::TypeName* tn) { //
AddDependency(tn, tn->name->symbol, "type", "references");
AddDependency(tn->name, tn->name->symbol, "type", "references");
},
[&](const ast::Vector* vec) { //
TraverseType(vec->type);
@ -468,15 +469,25 @@ class DependencyScanner {
UnhandledNode(diagnostics_, attr);
}
/// Adds the dependency from `from` to `to`, erroring if `to` cannot be
/// resolved.
void AddDependency(const ast::Node* from, Symbol to, const char* use, const char* action) {
/// Adds the dependency from @p from to @p to, erroring if @p to cannot be resolved.
void AddDependency(const ast::Identifier* from,
Symbol to,
const char* use,
const char* action) {
auto* resolved = scope_stack_.Get(to);
if (!resolved) {
if (!IsBuiltin(to)) {
UnknownSymbol(to, from->source, use);
auto s = symbols_.NameFor(to);
if (auto builtin_fn = sem::ParseBuiltinType(s); builtin_fn != sem::BuiltinType::kNone) {
graph_.resolved_identifiers.Add(from, ResolvedIdentifier(builtin_fn));
return;
}
if (auto builtin_ty = type::ParseBuiltin(s); builtin_ty != type::Builtin::kUndefined) {
graph_.resolved_identifiers.Add(from, ResolvedIdentifier(builtin_ty));
return;
}
UnknownSymbol(to, from->source, use);
return;
}
if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
@ -486,17 +497,7 @@ class DependencyScanner {
}
}
graph_.resolved_symbols.Add(from, resolved);
}
/// @returns true if `name` is the name of a builtin function, or builtin type alias
bool IsBuiltin(Symbol name) const {
auto s = symbols_.NameFor(name);
if (sem::ParseBuiltinType(s) != sem::BuiltinType::kNone ||
type::ParseBuiltin(s) != type::Builtin::kUndefined) {
return true;
}
return false;
graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved));
}
/// Appends an error to the diagnostics that the given symbol cannot be
@ -529,7 +530,7 @@ struct DependencyAnalysis {
/// @returns true if analysis found no errors, otherwise false.
bool Run(const ast::Module& module) {
// Reserve container memory
graph_.resolved_symbols.Reserve(module.GlobalDeclarations().Length());
graph_.resolved_identifiers.Reserve(module.GlobalDeclarations().Length());
sorted_.Reserve(module.GlobalDeclarations().Length());
// Collect all the named globals from the AST module
@ -821,4 +822,20 @@ bool DependencyGraph::Build(const ast::Module& module,
return da.Run(module);
}
std::ostream& operator<<(std::ostream& out, const ResolvedIdentifier& ri) {
if (!ri) {
return out << "<unresolved symbol>";
}
if (auto* node = ri.Node()) {
return out << "'" << node->TypeInfo().name << "' at " << node->source;
}
if (auto builtin_fn = ri.BuiltinFunction(); builtin_fn != sem::BuiltinType::kNone) {
return out << "builtin function '" << builtin_fn << "'";
}
if (auto builtin_ty = ri.BuiltinType(); builtin_ty != type::Builtin::kUndefined) {
return out << "builtin function '" << builtin_ty << "'";
}
return out << "<unhandled ResolvedIdentifier value>";
}
} // namespace tint::resolver

View File

@ -19,10 +19,83 @@
#include "src/tint/ast/module.h"
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/sem/builtin_type.h"
#include "src/tint/type/builtin.h"
#include "src/tint/utils/hashmap.h"
namespace tint::resolver {
/// ResolvedIdentifier holds the resolution of an ast::Identifier.
/// Can hold one of:
/// - const ast::TypeDecl* (as const ast::Node*)
/// - const ast::Variable* (as const ast::Node*)
/// - const ast::Function* (as const ast::Node*)
/// - sem::BuiltinType
/// - type::Builtin
class ResolvedIdentifier {
public:
ResolvedIdentifier() = default;
/// Constructor
/// @param value the resolved identifier value
template <typename T>
ResolvedIdentifier(T value) : value_(value) {} // NOLINT(runtime/explicit)
/// @return true if the ResolvedIdentifier holds a value (successfully resolved)
operator bool() const { return !std::holds_alternative<std::monostate>(value_); }
/// @return the node pointer if the ResolvedIdentifier holds an AST node, otherwise nullptr
const ast::Node* Node() const {
if (auto n = std::get_if<const ast::Node*>(&value_)) {
return *n;
}
return nullptr;
}
/// @return the builtin function if the ResolvedIdentifier holds sem::BuiltinType, otherwise
/// sem::BuiltinType::kNone
sem::BuiltinType BuiltinFunction() const {
if (auto n = std::get_if<sem::BuiltinType>(&value_)) {
return *n;
}
return sem::BuiltinType::kNone;
}
/// @return the builtin type if the ResolvedIdentifier holds type::Builtin, otherwise
/// type::Builtin::kUndefined
type::Builtin BuiltinType() const {
if (auto n = std::get_if<type::Builtin>(&value_)) {
return *n;
}
return type::Builtin::kUndefined;
}
/// @param value the value to compare the ResolvedIdentifier to
/// @return true if the ResolvedIdentifier is equal to @p value
template <typename T>
bool operator==(const T& value) const {
if (auto n = std::get_if<T>(&value_)) {
return *n == value;
}
return false;
}
/// @param other the other value to compare to this
/// @return true if this ResolvedIdentifier and @p other are not equal
template <typename T>
bool operator!=(const T& other) const {
return !(*this == other);
}
private:
std::variant<std::monostate, const ast::Node*, sem::BuiltinType, type::Builtin> value_;
};
/// @param out the std::ostream to write to
/// @param ri the ResolvedIdentifier
/// @returns `out` so calls can be chained
std::ostream& operator<<(std::ostream& out, const ResolvedIdentifier& ri);
/// DependencyGraph holds information about module-scope declaration dependency
/// analysis and symbol resolutions.
struct DependencyGraph {
@ -48,9 +121,8 @@ struct DependencyGraph {
/// All globals in dependency-sorted order.
utils::Vector<const ast::Node*, 32> ordered_globals;
/// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
/// variable that declares the symbol.
utils::Hashmap<const ast::Node*, const ast::Node*, 64> resolved_symbols;
/// Map of ast::Identifier to a ResolvedIdentifier
utils::Hashmap<const ast::Identifier*, ResolvedIdentifier, 64> resolved_identifiers;
/// Map of ast::Variable to a type, function, or variable that is shadowed by
/// the variable key. A declaration (X) shadows another (Y) if X and Y use

View File

@ -103,8 +103,7 @@ static constexpr SymbolDeclKind kFuncDeclKinds[] = {
SymbolDeclKind::Function,
};
/// SymbolUseKind is used by parameterized tests to enumerate the different
/// kinds of symbol uses.
/// SymbolUseKind is used by parameterized tests to enumerate the different kinds of symbol uses.
enum class SymbolUseKind {
GlobalVarType,
GlobalVarArrayElemType,
@ -400,19 +399,19 @@ struct SymbolTestHelper {
/// Destructor.
~SymbolTestHelper();
/// Declares a symbol with the given kind
/// @param kind the kind of symbol declaration
/// Declares an identifier with the given kind
/// @param kind the kind of identifier declaration
/// @param symbol the symbol to use for the declaration
/// @param source the source of the declaration
/// @returns the declaration node
/// @returns the identifier node
const ast::Node* Add(SymbolDeclKind kind, Symbol symbol, Source source);
/// Declares a use of a symbol with the given kind
/// Declares a use of an identifier with the given kind
/// @param kind the kind of symbol use
/// @param symbol the declaration symbol to use
/// @param source the source of the use
/// @returns the use node
const ast::Node* Add(SymbolUseKind kind, Symbol symbol, Source source);
const ast::Identifier* Add(SymbolUseKind kind, Symbol symbol, Source source);
/// Builds a function, if any parameter or local declarations have been added
void Build();
@ -422,7 +421,7 @@ SymbolTestHelper::SymbolTestHelper(ProgramBuilder* b) : builder(b) {}
SymbolTestHelper::~SymbolTestHelper() {}
const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind, Symbol symbol, Source source) {
const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind, Symbol symbol, Source source = {}) {
auto& b = *builder;
switch (kind) {
case SymbolDeclKind::GlobalVar:
@ -464,88 +463,90 @@ const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind, Symbol symbol, Sourc
return nullptr;
}
const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind, Symbol symbol, Source source) {
const ast::Identifier* SymbolTestHelper::Add(SymbolUseKind kind,
Symbol symbol,
Source source = {}) {
auto& b = *builder;
switch (kind) {
case SymbolUseKind::GlobalVarType: {
auto* node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), node, type::AddressSpace::kPrivate);
return node;
return node->name;
}
case SymbolUseKind::GlobalVarArrayElemType: {
auto* node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.array(node, 4_i), type::AddressSpace::kPrivate);
return node;
return node->name;
}
case SymbolUseKind::GlobalVarArraySizeValue: {
auto* node = b.Expr(source, symbol);
b.GlobalVar(b.Sym(), b.ty.array(b.ty.i32(), node), type::AddressSpace::kPrivate);
return node;
return node->identifier;
}
case SymbolUseKind::GlobalVarVectorElemType: {
auto* node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.vec3(node), type::AddressSpace::kPrivate);
return node;
return node->name;
}
case SymbolUseKind::GlobalVarMatrixElemType: {
auto* node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.mat3x4(node), type::AddressSpace::kPrivate);
return node;
return node->name;
}
case SymbolUseKind::GlobalVarSampledTexElemType: {
auto* node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.sampled_texture(type::TextureDimension::k2d, node));
return node;
return node->name;
}
case SymbolUseKind::GlobalVarMultisampledTexElemType: {
auto* node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.multisampled_texture(type::TextureDimension::k2d, node));
return node;
return node->name;
}
case SymbolUseKind::GlobalVarValue: {
auto* node = b.Expr(source, symbol);
b.GlobalVar(b.Sym(), b.ty.i32(), type::AddressSpace::kPrivate, node);
return node;
return node->identifier;
}
case SymbolUseKind::GlobalConstType: {
auto* node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), node, b.Expr(1_i));
return node;
return node->name;
}
case SymbolUseKind::GlobalConstArrayElemType: {
auto* node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), b.ty.array(node, 4_i), b.Expr(1_i));
return node;
return node->name;
}
case SymbolUseKind::GlobalConstArraySizeValue: {
auto* node = b.Expr(source, symbol);
b.GlobalConst(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1_i));
return node;
return node->identifier;
}
case SymbolUseKind::GlobalConstVectorElemType: {
auto* node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), b.ty.vec3(node), b.Expr(1_i));
return node;
return node->name;
}
case SymbolUseKind::GlobalConstMatrixElemType: {
auto* node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), b.ty.mat3x4(node), b.Expr(1_i));
return node;
return node->name;
}
case SymbolUseKind::GlobalConstValue: {
auto* node = b.Expr(source, symbol);
b.GlobalConst(b.Sym(), b.ty.i32(), node);
return node;
return node->identifier;
}
case SymbolUseKind::AliasType: {
auto* node = b.ty(source, symbol);
b.Alias(b.Sym(), node);
return node;
return node->name;
}
case SymbolUseKind::StructMemberType: {
auto* node = b.ty(source, symbol);
b.Structure(b.Sym(), utils::Vector{b.Member("m", node)});
return node;
return node->name;
}
case SymbolUseKind::CallFunction: {
auto* node = b.Ident(source, symbol);
@ -555,72 +556,72 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind, Symbol symbol, Source
case SymbolUseKind::ParameterType: {
auto* node = b.ty(source, symbol);
parameters.Push(b.Param(b.Sym(), node));
return node;
return node->name;
}
case SymbolUseKind::LocalVarType: {
auto* node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), node)));
return node;
return node->name;
}
case SymbolUseKind::LocalVarArrayElemType: {
auto* node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.array(node, 4_u), b.Expr(1_i))));
return node;
return node->name;
}
case SymbolUseKind::LocalVarArraySizeValue: {
auto* node = b.Expr(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1_i))));
return node;
return node->identifier;
}
case SymbolUseKind::LocalVarVectorElemType: {
auto* node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.vec3(node))));
return node;
return node->name;
}
case SymbolUseKind::LocalVarMatrixElemType: {
auto* node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.mat3x4(node))));
return node;
return node->name;
}
case SymbolUseKind::LocalVarValue: {
auto* node = b.Expr(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.i32(), node)));
return node;
return node->identifier;
}
case SymbolUseKind::LocalLetType: {
auto* node = b.ty(source, symbol);
statements.Push(b.Decl(b.Let(b.Sym(), node, b.Expr(1_i))));
return node;
return node->name;
}
case SymbolUseKind::LocalLetValue: {
auto* node = b.Expr(source, symbol);
statements.Push(b.Decl(b.Let(b.Sym(), b.ty.i32(), node)));
return node;
return node->identifier;
}
case SymbolUseKind::NestedLocalVarType: {
auto* node = b.ty(source, symbol);
nested_statements.Push(b.Decl(b.Var(b.Sym(), node)));
return node;
return node->name;
}
case SymbolUseKind::NestedLocalVarValue: {
auto* node = b.Expr(source, symbol);
nested_statements.Push(b.Decl(b.Var(b.Sym(), b.ty.i32(), node)));
return node;
return node->identifier;
}
case SymbolUseKind::NestedLocalLetType: {
auto* node = b.ty(source, symbol);
nested_statements.Push(b.Decl(b.Let(b.Sym(), node, b.Expr(1_i))));
return node;
return node->name;
}
case SymbolUseKind::NestedLocalLetValue: {
auto* node = b.Expr(source, symbol);
nested_statements.Push(b.Decl(b.Let(b.Sym(), b.ty.i32(), node)));
return node;
return node->identifier;
}
case SymbolUseKind::WorkgroupSizeValue: {
auto* node = b.Expr(source, symbol);
func_attrs.Push(b.WorkgroupSize(1_i, node, 2_i));
return node;
return node->identifier;
}
}
return nullptr;
@ -641,7 +642,7 @@ void SymbolTestHelper::Build() {
}
////////////////////////////////////////////////////////////////////////////////
// Used-before-declarated tests
// Used-before-declared tests
////////////////////////////////////////////////////////////////////////////////
namespace used_before_decl_tests {
@ -1103,14 +1104,14 @@ TEST_F(ResolverDependencyGraphOrderedGlobalsTest, DirectiveFirst) {
} // namespace ordered_globals
////////////////////////////////////////////////////////////////////////////////
// Resolved symbols tests
// Resolve to user-declaration tests
////////////////////////////////////////////////////////////////////////////////
namespace resolved_symbols {
namespace resolve_to_user_decl {
using ResolverDependencyGraphResolvedSymbolTest =
using ResolverDependencyGraphResolveToUserDeclTest =
ResolverDependencyGraphTestWithParam<std::tuple<SymbolDeclKind, SymbolUseKind>>;
TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
TEST_P(ResolverDependencyGraphResolveToUserDeclTest, Test) {
const Symbol symbol = Sym("SYMBOL");
const auto decl_kind = std::get<0>(GetParam());
const auto use_kind = std::get<1>(GetParam());
@ -1128,41 +1129,209 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
if (expect_pass) {
// Check that the use resolves to the declaration
auto resolved_symbol = graph.resolved_symbols.Find(use);
ASSERT_TRUE(resolved_symbol);
EXPECT_EQ(*resolved_symbol, decl)
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n"
auto resolved_identifier = graph.resolved_identifiers.Find(use);
ASSERT_TRUE(resolved_identifier);
auto* resolved_node = resolved_identifier->Node();
EXPECT_EQ(resolved_node, decl)
<< "resolved: " << (resolved_node ? resolved_node->TypeInfo().name : "<null>") << "\n"
<< "decl: " << decl->TypeInfo().name;
}
}
INSTANTIATE_TEST_SUITE_P(Types,
ResolverDependencyGraphResolvedSymbolTest,
ResolverDependencyGraphResolveToUserDeclTest,
testing::Combine(testing::ValuesIn(kTypeDeclKinds),
testing::ValuesIn(kTypeUseKinds)));
INSTANTIATE_TEST_SUITE_P(Values,
ResolverDependencyGraphResolvedSymbolTest,
ResolverDependencyGraphResolveToUserDeclTest,
testing::Combine(testing::ValuesIn(kValueDeclKinds),
testing::ValuesIn(kValueUseKinds)));
INSTANTIATE_TEST_SUITE_P(Functions,
ResolverDependencyGraphResolvedSymbolTest,
ResolverDependencyGraphResolveToUserDeclTest,
testing::Combine(testing::ValuesIn(kFuncDeclKinds),
testing::ValuesIn(kFuncUseKinds)));
} // namespace resolved_symbols
} // namespace resolve_to_user_decl
////////////////////////////////////////////////////////////////////////////////
// Resolve to builtin func tests
////////////////////////////////////////////////////////////////////////////////
namespace resolve_to_builtin_func {
using ResolverDependencyGraphResolveToBuiltinFunc =
ResolverDependencyGraphTestWithParam<std::tuple<SymbolUseKind, sem::BuiltinType>>;
TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, Resolve) {
const auto use = std::get<0>(GetParam());
const auto builtin = std::get<1>(GetParam());
const auto symbol = Symbols().New(utils::ToString(builtin));
SymbolTestHelper helper(this);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->BuiltinFunction(), builtin) << *resolved;
}
TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, ShadowedByGlobalVar) {
const auto use = std::get<0>(GetParam());
const auto builtin = std::get<1>(GetParam());
const auto symbol = Symbols().New(utils::ToString(builtin));
SymbolTestHelper helper(this);
auto* decl = helper.Add(SymbolDeclKind::GlobalVar, symbol);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << *resolved;
}
TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, ShadowedByStruct) {
const auto use = std::get<0>(GetParam());
const auto builtin = std::get<1>(GetParam());
const auto symbol = Symbols().New(utils::ToString(builtin));
SymbolTestHelper helper(this);
auto* decl = helper.Add(SymbolDeclKind::Struct, symbol);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << *resolved;
}
TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, ShadowedByFunc) {
const auto use = std::get<0>(GetParam());
const auto builtin = std::get<1>(GetParam());
const auto symbol = Symbols().New(utils::ToString(builtin));
SymbolTestHelper helper(this);
auto* decl = helper.Add(SymbolDeclKind::Function, symbol);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << *resolved;
}
INSTANTIATE_TEST_SUITE_P(Types,
ResolverDependencyGraphResolveToBuiltinFunc,
testing::Combine(testing::ValuesIn(kTypeUseKinds),
testing::ValuesIn(sem::kBuiltinTypes)));
INSTANTIATE_TEST_SUITE_P(Values,
ResolverDependencyGraphResolveToBuiltinFunc,
testing::Combine(testing::ValuesIn(kValueUseKinds),
testing::ValuesIn(sem::kBuiltinTypes)));
INSTANTIATE_TEST_SUITE_P(Functions,
ResolverDependencyGraphResolveToBuiltinFunc,
testing::Combine(testing::ValuesIn(kFuncUseKinds),
testing::ValuesIn(sem::kBuiltinTypes)));
} // namespace resolve_to_builtin_func
////////////////////////////////////////////////////////////////////////////////
// Resolve to builtin type tests
////////////////////////////////////////////////////////////////////////////////
namespace resolve_to_builtin_type {
using ResolverDependencyGraphResolveToBuiltinType =
ResolverDependencyGraphTestWithParam<std::tuple<SymbolUseKind, const char*>>;
TEST_P(ResolverDependencyGraphResolveToBuiltinType, Resolve) {
const auto use = std::get<0>(GetParam());
const auto name = std::get<1>(GetParam());
const auto symbol = Symbols().New(name);
SymbolTestHelper helper(this);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->BuiltinType(), type::ParseBuiltin(name)) << *resolved;
}
TEST_P(ResolverDependencyGraphResolveToBuiltinType, ShadowedByGlobalVar) {
const auto use = std::get<0>(GetParam());
const auto name = std::get<1>(GetParam());
const auto symbol = Symbols().New(name);
SymbolTestHelper helper(this);
auto* decl = helper.Add(SymbolDeclKind::GlobalVar, symbol);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << *resolved;
}
TEST_P(ResolverDependencyGraphResolveToBuiltinType, ShadowedByStruct) {
const auto use = std::get<0>(GetParam());
const auto name = std::get<1>(GetParam());
const auto symbol = Symbols().New(name);
SymbolTestHelper helper(this);
auto* decl = helper.Add(SymbolDeclKind::Struct, symbol);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << *resolved;
}
TEST_P(ResolverDependencyGraphResolveToBuiltinType, ShadowedByFunc) {
const auto use = std::get<0>(GetParam());
const auto name = std::get<1>(GetParam());
const auto symbol = Symbols().New(name);
SymbolTestHelper helper(this);
auto* decl = helper.Add(SymbolDeclKind::Function, symbol);
auto* ident = helper.Add(use, symbol);
helper.Build();
auto resolved = Build().resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << *resolved;
}
INSTANTIATE_TEST_SUITE_P(Types,
ResolverDependencyGraphResolveToBuiltinType,
testing::Combine(testing::ValuesIn(kTypeUseKinds),
testing::ValuesIn(type::kBuiltinStrings)));
INSTANTIATE_TEST_SUITE_P(Values,
ResolverDependencyGraphResolveToBuiltinType,
testing::Combine(testing::ValuesIn(kValueUseKinds),
testing::ValuesIn(type::kBuiltinStrings)));
INSTANTIATE_TEST_SUITE_P(Functions,
ResolverDependencyGraphResolveToBuiltinType,
testing::Combine(testing::ValuesIn(kFuncUseKinds),
testing::ValuesIn(type::kBuiltinStrings)));
} // namespace resolve_to_builtin_type
////////////////////////////////////////////////////////////////////////////////
// Shadowing tests
////////////////////////////////////////////////////////////////////////////////
namespace shadowing {
using ResolverDependencyShadowTest =
using ResolverDependencyGraphShadowTest =
ResolverDependencyGraphTestWithParam<std::tuple<SymbolDeclKind, SymbolDeclKind>>;
TEST_P(ResolverDependencyShadowTest, Test) {
TEST_P(ResolverDependencyGraphShadowTest, Test) {
const Symbol symbol = Sym("SYMBOL");
const auto outer_kind = std::get<0>(GetParam());
const auto inner_kind = std::get<1>(GetParam());
@ -1185,12 +1354,12 @@ TEST_P(ResolverDependencyShadowTest, Test) {
}
INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal,
ResolverDependencyShadowTest,
ResolverDependencyGraphShadowTest,
testing::Combine(testing::ValuesIn(kGlobalDeclKinds),
testing::ValuesIn(kLocalDeclKinds)));
INSTANTIATE_TEST_SUITE_P(NestedLocalShadowLocal,
ResolverDependencyShadowTest,
ResolverDependencyGraphShadowTest,
testing::Combine(testing::Values(SymbolDeclKind::Parameter,
SymbolDeclKind::LocalVar,
SymbolDeclKind::LocalLet),
@ -1204,6 +1373,18 @@ INSTANTIATE_TEST_SUITE_P(NestedLocalShadowLocal,
////////////////////////////////////////////////////////////////////////////////
namespace ast_traversal {
static const ast::Identifier* IdentifierOf(const ast::IdentifierExpression* expr) {
return expr->identifier;
}
static const ast::Identifier* IdentifierOf(const ast::TypeName* ty) {
return ty->name;
}
static const ast::Identifier* IdentifierOf(const ast::Identifier* ident) {
return ident;
}
using ResolverDependencyGraphTraversalTest = ResolverDependencyGraphTest;
TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
@ -1217,7 +1398,7 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
struct SymbolUse {
const ast::Node* decl = nullptr;
const ast::Node* use = nullptr;
const ast::Identifier* use = nullptr;
std::string where;
};
@ -1225,7 +1406,8 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
auto add_use = [&](const ast::Node* decl, auto* use, int line, const char* kind) {
symbol_uses.Push(
SymbolUse{decl, use, std::string(__FILE__) + ":" + std::to_string(line) + ": " + kind});
SymbolUse{decl, IdentifierOf(use),
std::string(__FILE__) + ":" + std::to_string(line) + ": " + kind});
return use;
};
#define V add_use(value_decl, Expr(value_sym), __LINE__, "V()")
@ -1310,9 +1492,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
auto graph = Build();
for (auto use : symbol_uses) {
auto resolved_symbol = graph.resolved_symbols.Find(use.use);
ASSERT_TRUE(resolved_symbol) << use.where;
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
auto resolved_identifier = graph.resolved_identifiers.Find(use.use);
ASSERT_TRUE(resolved_identifier) << use.where;
EXPECT_EQ(*resolved_identifier, use.decl) << use.where;
}
}

View File

@ -105,7 +105,7 @@ Resolver::Resolver(ProgramBuilder* builder)
diagnostics_(builder->Diagnostics()),
const_eval_(*builder),
intrinsic_table_(IntrinsicTable::Create(*builder)),
sem_(builder, dependencies_),
sem_(builder),
validator_(builder,
sem_,
enabled_extensions_,
@ -332,41 +332,45 @@ type::Type* Resolver::Type(const ast::Type* ty) {
TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) << "TODO(crbug.com/tint/1810)";
}
auto* resolved = sem_.ResolvedSymbol(t);
if (resolved == nullptr) {
if (IsBuiltin(t->name->symbol)) {
auto name = builder_->Symbols().NameFor(t->name->symbol);
AddError("cannot use builtin '" + name + "' as type", ty->source);
return nullptr;
}
return BuiltinType(t->name->symbol, t->source);
auto resolved = dependencies_.resolved_identifiers.Get(t->name);
if (!resolved) {
TINT_ICE(Resolver, builder_->Diagnostics()) << "identifier was not resolved";
return nullptr;
}
return Switch(
resolved, //
[&](type::Type* type) { return type; },
[&](sem::Variable* var) {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
AddError("cannot use variable '" + name + "' as type", ty->source);
AddNote("'" + name + "' declared here", var->Declaration()->source);
return nullptr;
},
[&](sem::Function* func) {
auto name = builder_->Symbols().NameFor(func->Declaration()->symbol);
AddError("cannot use function '" + name + "' as type", ty->source);
AddNote("'" + name + "' declared here", func->Declaration()->source);
return nullptr;
},
[&](Default) -> type::Type* {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Unhandled resolved type '"
<< (resolved ? resolved->TypeInfo().name : "<null>")
<< "' resolved from ast::Type '" << ty->TypeInfo().name << "'";
return nullptr;
});
},
[&](Default) {
if (auto* ast_node = resolved->Node()) {
auto* resolved_node = sem_.Get(ast_node);
return Switch(
resolved_node, //
[&](type::Type* type) { return type; },
[&](sem::Variable* variable) {
auto name = builder_->Symbols().NameFor(variable->Declaration()->symbol);
AddError("cannot use variable '" + name + "' as type", ty->source);
AddNote("'" + name + "' declared here", variable->Declaration()->source);
return nullptr;
},
[&](sem::Function* func) {
auto name = builder_->Symbols().NameFor(func->Declaration()->symbol);
AddError("cannot use function '" + name + "' as type", ty->source);
AddNote("'" + name + "' declared here", func->Declaration()->source);
return nullptr;
});
}
if (auto builtin_ty = resolved->BuiltinType();
builtin_ty != type::Builtin::kUndefined) {
return BuiltinType(builtin_ty, t->name);
}
if (auto builtin_fn = resolved->BuiltinFunction();
builtin_fn != sem::BuiltinType::kNone) {
auto name = builder_->Symbols().NameFor(t->name->symbol);
AddError("cannot use builtin '" + name + "' as type", ty->source);
return nullptr;
}
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Unhandled type: '" << ty->TypeInfo().name << "'";
<< "unhandled resolved identifier: " << *resolved;
return nullptr;
});
@ -947,7 +951,8 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
return nullptr;
}
// Track the pipeline-overridable constants that are transitively referenced by this variable.
// Track the pipeline-overridable constants that are transitively referenced by this
// variable.
for (auto* var : transitively_referenced_overrides) {
builder_->Sem().AddTransitivelyReferencedOverride(sem, var);
}
@ -1176,8 +1181,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
"abstract-integer, i32 or u32";
for (size_t i = 0; i < 3; i++) {
// Each argument to this attribute can either be a literal, an identifier for a module-scope
// constants, a const-expression, or nullptr if not specified.
// Each argument to this attribute can either be a literal, an identifier for a
// module-scope constants, a const-expression, or nullptr if not specified.
auto* value = values[i];
if (!value) {
break;
@ -1604,8 +1609,8 @@ sem::ValueExpression* Resolver::Expression(const ast::Expression* root) {
return sem_expr;
}
// If we just processed the lhs of a constexpr logical binary expression, mark the rhs for
// short-circuiting.
// If we just processed the lhs of a constexpr logical binary expression, mark the rhs
// for short-circuiting.
if (sem_expr->ConstantValue()) {
if (auto binary = logical_binary_lhs_to_parent_.Find(expr)) {
const bool lhs_is_true = sem_expr->ConstantValue()->ValueAs<bool>();
@ -1678,7 +1683,8 @@ bool Resolver::AliasAnalysis(const sem::Call* call) {
auto& target_info = alias_analysis_infos_[target];
auto& caller_info = alias_analysis_infos_[current_function_];
// Track the set of root identifiers that are read and written by arguments passed in this call.
// Track the set of root identifiers that are read and written by arguments passed in this
// call.
std::unordered_map<const sem::Variable*, const sem::ValueExpression*> arg_reads;
std::unordered_map<const sem::Variable*, const sem::ValueExpression*> arg_writes;
for (size_t i = 0; i < args.Length(); i++) {
@ -1715,8 +1721,8 @@ bool Resolver::AliasAnalysis(const sem::Call* call) {
},
[&](const sem::Parameter* param) { caller_info.parameter_writes.insert(param); });
} else if (target_info.parameter_reads.count(target->Parameters()[i])) {
// Arguments that are read from can alias with arguments or module-scope variables that
// are written to.
// Arguments that are read from can alias with arguments or module-scope variables
// that are written to.
if (arg_writes.count(root)) {
return make_error(arg, {arg_writes.at(root), Alias::Argument, "write"});
}
@ -2101,8 +2107,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
// Constant evaluation failed.
// Can happen for expressions that will fail validation (later).
// Use the kRuntime EvaluationStage, as kConstant will trigger an assertion in
// the sem::ValueExpression initializer, which checks that kConstant is paired with
// a constant value.
// the sem::ValueExpression initializer, which checks that kConstant is paired
// with a constant value.
stage = sem::EvaluationStage::kRuntime;
}
}
@ -2310,35 +2316,44 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
// conversion.
auto* ident = expr->target.name;
Mark(ident);
if (auto* resolved = sem_.ResolvedSymbol<type::Type>(ident)) {
// A type initializer or conversions.
// Note: Unlike the code path where we're resolving the call target from an
// ast::Type, all types must already have the element type explicitly specified,
// so there's no need to infer element types.
return ty_init_or_conv(resolved);
auto resolved = dependencies_.resolved_identifiers.Get(ident);
if (!resolved) {
TINT_ICE(Resolver, builder_->Diagnostics()) << "identifier was not resolved";
return nullptr;
}
auto* resolved = sem_.ResolvedSymbol<sem::Node>(ident);
call = Switch<sem::Call*>(
resolved, //
[&](sem::Function* func) { return FunctionCall(expr, func, args, arg_behaviors); },
[&](sem::Variable* var) {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
AddError("cannot call variable '" + name + "'", ident->source);
AddNote("'" + name + "' declared here", var->Declaration()->source);
return nullptr;
},
[&](Default) -> sem::Call* {
auto name = builder_->Symbols().NameFor(ident->symbol);
if (auto builtin_type = sem::ParseBuiltinType(name);
builtin_type != sem::BuiltinType::kNone) {
return BuiltinCall(expr, builtin_type, args);
}
if (auto* alias = BuiltinType(ident->symbol, ident->source)) {
return ty_init_or_conv(alias);
}
return nullptr;
});
if (auto* ast_node = resolved->Node()) {
auto* resolved_node = sem_.Get(ast_node);
return Switch(
resolved_node, //
[&](const type::Type* ty) {
// A type initializer or conversions.
// Note: Unlike the code path where we're resolving the call target from an
// ast::Type, all types must already have the element type explicitly
// specified, so there's no need to infer element types.
return ty_init_or_conv(ty);
},
[&](sem::Function* func) { return FunctionCall(expr, func, args, arg_behaviors); },
[&](sem::Variable* var) {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
AddError("cannot call variable '" + name + "'", ident->source);
AddNote("'" + name + "' declared here", var->Declaration()->source);
return nullptr;
});
}
if (auto builtin_fn = resolved->BuiltinFunction(); builtin_fn != sem::BuiltinType::kNone) {
return BuiltinCall(expr, builtin_fn, args);
}
if (auto builtin_ty = resolved->BuiltinType(); builtin_ty != type::Builtin::kUndefined) {
auto* ty = BuiltinType(builtin_ty, expr->target.name);
return ty ? ty_init_or_conv(ty) : nullptr;
}
TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled resolved identifier: " << *resolved;
return nullptr;
}
if (!call) {
@ -2438,13 +2453,12 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
return call;
}
type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const {
auto name = builder_->Symbols().NameFor(sym);
type::Type* Resolver::BuiltinType(type::Builtin builtin_ty, const ast::Identifier* ident) const {
auto& b = *builder_;
auto vec_f32 = [&](uint32_t n) { return b.create<type::Vector>(b.create<type::F32>(), n); };
auto vec_f16 = [&](uint32_t n) { return b.create<type::Vector>(b.create<type::F16>(), n); };
switch (type::ParseBuiltin(name)) {
switch (builtin_ty) {
case type::Builtin::kMat2X2F:
return b.create<type::Matrix>(vec_f32(2u), 2u);
case type::Builtin::kMat2X3F:
@ -2464,32 +2478,41 @@ type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const {
case type::Builtin::kMat4X4F:
return b.create<type::Matrix>(vec_f32(4u), 4u);
case type::Builtin::kMat2X2H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(2u), 2u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(2u), 2u)
: nullptr;
case type::Builtin::kMat2X3H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(3u), 2u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(3u), 2u)
: nullptr;
case type::Builtin::kMat2X4H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(4u), 2u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(4u), 2u)
: nullptr;
case type::Builtin::kMat3X2H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(2u), 3u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(2u), 3u)
: nullptr;
case type::Builtin::kMat3X3H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(3u), 3u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(3u), 3u)
: nullptr;
case type::Builtin::kMat3X4H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(4u), 3u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(4u), 3u)
: nullptr;
case type::Builtin::kMat4X2H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(2u), 4u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(2u), 4u)
: nullptr;
case type::Builtin::kMat4X3H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(3u), 4u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(3u), 4u)
: nullptr;
case type::Builtin::kMat4X4H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(4u), 4u)
: nullptr;
return validator_.CheckF16Enabled(ident->source)
? b.create<type::Matrix>(vec_f16(4u), 4u)
: nullptr;
case type::Builtin::kVec2F:
return vec_f32(2u);
case type::Builtin::kVec3F:
@ -2497,11 +2520,11 @@ type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const {
case type::Builtin::kVec4F:
return vec_f32(4u);
case type::Builtin::kVec2H:
return validator_.CheckF16Enabled(source) ? vec_f16(2u) : nullptr;
return validator_.CheckF16Enabled(ident->source) ? vec_f16(2u) : nullptr;
case type::Builtin::kVec3H:
return validator_.CheckF16Enabled(source) ? vec_f16(3u) : nullptr;
return validator_.CheckF16Enabled(ident->source) ? vec_f16(3u) : nullptr;
case type::Builtin::kVec4H:
return validator_.CheckF16Enabled(source) ? vec_f16(4u) : nullptr;
return validator_.CheckF16Enabled(ident->source) ? vec_f16(4u) : nullptr;
case type::Builtin::kVec2I:
return b.create<type::Vector>(b.create<type::I32>(), 2u);
case type::Builtin::kVec3I:
@ -2518,7 +2541,8 @@ type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const {
break;
}
TINT_ICE(Resolver, diagnostics_) << source << " unhandled builtin type '" << name << "'";
auto name = builder_->Symbols().NameFor(ident->symbol);
TINT_ICE(Resolver, diagnostics_) << ident->source << " unhandled builtin type '" << name << "'";
return nullptr;
}
@ -2685,97 +2709,110 @@ sem::ValueExpression* Resolver::Literal(const ast::LiteralExpression* literal) {
sem::ValueExpression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
Mark(expr->identifier);
auto symbol = expr->identifier->symbol;
auto* sem_resolved = sem_.ResolvedSymbol<sem::Node>(expr);
if (auto* variable = As<sem::Variable>(sem_resolved)) {
auto* user = builder_->create<sem::VariableUser>(expr, current_statement_, variable);
if (current_statement_) {
// If identifier is part of a loop continuing block, make sure it
// doesn't refer to a variable that is bypassed by a continue statement
// in the loop's body block.
if (auto* continuing_block =
current_statement_->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
auto* loop_block = continuing_block->FindFirstParent<sem::LoopBlockStatement>();
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
builder_->Symbols().NameFor(symbol) + "'",
loop_block->FirstContinue()->source);
AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
"' declared here",
decl->variable->Declaration()->source);
AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
"' referenced in continuing block here",
expr->source);
return nullptr;
auto resolved = dependencies_.resolved_identifiers.Get(expr->identifier);
if (!resolved) {
TINT_ICE(Resolver, builder_->Diagnostics()) << "identifier was not resolved";
return nullptr;
}
if (auto* ast_node = resolved->Node()) {
auto* resolved_node = sem_.Get(ast_node);
return Switch(
resolved_node, //
[&](sem::Variable* variable) -> sem::VariableUser* {
auto symbol = expr->identifier->symbol;
auto* user =
builder_->create<sem::VariableUser>(expr, current_statement_, variable);
if (current_statement_) {
// If identifier is part of a loop continuing block, make sure it
// doesn't refer to a variable that is bypassed by a continue statement
// in the loop's body block.
if (auto* continuing_block =
current_statement_
->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
auto* loop_block =
continuing_block->FindFirstParent<sem::LoopBlockStatement>();
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
builder_->Symbols().NameFor(symbol) + "'",
loop_block->FirstContinue()->source);
AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
"' declared here",
decl->variable->Declaration()->source);
AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
"' referenced in continuing block here",
expr->source);
return nullptr;
}
}
}
}
}
}
}
auto* global = variable->As<sem::GlobalVariable>();
if (current_function_) {
if (global) {
current_function_->AddDirectlyReferencedGlobal(global);
auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global);
if (refs) {
for (auto* var : *refs) {
current_function_->AddTransitivelyReferencedGlobal(var);
auto* global = variable->As<sem::GlobalVariable>();
if (current_function_) {
if (global) {
current_function_->AddDirectlyReferencedGlobal(global);
auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global);
if (refs) {
for (auto* var : *refs) {
current_function_->AddTransitivelyReferencedGlobal(var);
}
}
}
}
}
} else if (variable->Declaration()->Is<ast::Override>()) {
if (resolved_overrides_) {
// Track the reference to this pipeline-overridable constant and any other
// pipeline-overridable constants that it references.
resolved_overrides_->Add(global);
auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global);
if (refs) {
for (auto* var : *refs) {
resolved_overrides_->Add(var);
} else if (variable->Declaration()->Is<ast::Override>()) {
if (resolved_overrides_) {
// Track the reference to this pipeline-overridable constant and any other
// pipeline-overridable constants that it references.
resolved_overrides_->Add(global);
auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global);
if (refs) {
for (auto* var : *refs) {
resolved_overrides_->Add(var);
}
}
}
} else if (variable->Declaration()->Is<ast::Var>()) {
// Use of a module-scope 'var' outside of a function.
// Note: The spec is currently vague around the rules here. See
// https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when
// resolved.
std::string desc = "var '" + builder_->Symbols().NameFor(symbol) + "' ";
AddError(desc + "cannot be referenced at module-scope", expr->source);
AddNote(desc + "declared here", variable->Declaration()->source);
return nullptr;
}
}
} else if (variable->Declaration()->Is<ast::Var>()) {
// Use of a module-scope 'var' outside of a function.
// Note: The spec is currently vague around the rules here. See
// https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when resolved.
std::string desc = "var '" + builder_->Symbols().NameFor(symbol) + "' ";
AddError(desc + "cannot be referenced at module-scope", expr->source);
AddNote(desc + "declared here", variable->Declaration()->source);
return nullptr;
}
variable->AddUser(user);
return user;
variable->AddUser(user);
return user;
},
[&](const type::Type*) {
AddError("missing '(' for type initializer or cast", expr->source.End());
return nullptr;
},
[&](const sem::Function*) {
AddError("missing '(' for function call", expr->source.End());
return nullptr;
});
}
if (Is<sem::Function>(sem_resolved)) {
AddError("missing '(' for function call", expr->source.End());
return nullptr;
}
if (IsBuiltin(symbol)) {
AddError("missing '(' for builtin call", expr->source.End());
return nullptr;
}
if (sem_.ResolvedSymbol<type::Type>(expr) ||
type::ParseBuiltin(builder_->Symbols().NameFor(symbol)) != type::Builtin::kUndefined) {
if (resolved->BuiltinType() != type::Builtin::kUndefined) {
AddError("missing '(' for type initializer or cast", expr->source.End());
return nullptr;
}
// The dependency graph should have errored on this unresolved identifier before reaching here.
TINT_ICE(Resolver, diagnostics_)
<< expr->source << " unresolved identifier:\n"
<< "resolved: " << (sem_resolved ? sem_resolved->TypeInfo().name : "<null>") << "\n"
<< "name: " << builder_->Symbols().NameFor(symbol);
if (resolved->BuiltinFunction() != sem::BuiltinType::kNone) {
AddError("missing '(' for builtin call", expr->source.End());
return nullptr;
}
TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled resolved identifier: " << *resolved;
return nullptr;
}
@ -2977,8 +3014,8 @@ sem::ValueExpression* Resolver::Binary(const ast::BinaryExpression* expr) {
return nullptr;
}
} else {
// The arguments have constant values, but the operator cannot be const-evaluated. This
// can only be evaluated at runtime.
// The arguments have constant values, but the operator cannot be const-evaluated.
// This can only be evaluated at runtime.
stage = sem::EvaluationStage::kRuntime;
}
}
@ -3966,9 +4003,4 @@ void Resolver::AddNote(const std::string& msg, const Source& source) const {
diagnostics_.add_note(diag::System::Resolver, msg, source);
}
bool Resolver::IsBuiltin(Symbol symbol) const {
std::string name = builder_->Symbols().NameFor(symbol);
return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone;
}
} // namespace tint::resolver

View File

@ -428,12 +428,9 @@ class Resolver {
/// Adds the given note message to the diagnostics
void AddNote(const std::string& msg, const Source& source) const;
/// @returns true if the symbol is the name of a builtin function.
bool IsBuiltin(Symbol) const;
/// @returns the builtin type for the symbol @p symbol at @p source
/// @returns the type::Type for the builtin type @p builtin_ty with the identifier @p ident
/// @note: Will raise an ICE if @p symbol is not a builtin type.
type::Type* BuiltinType(Symbol symbol, const Source& source) const;
type::Type* BuiltinType(type::Builtin builtin_ty, const ast::Identifier* ident) const;
// ArrayInitializerSig represents a unique array initializer signature.
// It is a tuple of the array type, number of arguments provided and earliest evaluation stage.

View File

@ -18,8 +18,7 @@
namespace tint::resolver {
SemHelper::SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies)
: builder_(builder), dependencies_(dependencies) {}
SemHelper::SemHelper(ProgramBuilder* builder) : builder_(builder) {}
SemHelper::~SemHelper() = default;

View File

@ -29,8 +29,7 @@ class SemHelper {
public:
/// Constructor
/// @param builder the program builder
/// @param dependencies the program dependency graph
explicit SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies);
explicit SemHelper(ProgramBuilder* builder);
~SemHelper();
/// Get is a helper for obtaining the semantic node for the given AST node.
@ -76,18 +75,6 @@ class SemHelper {
}
}
/// @returns the resolved symbol (function, type or variable) for the given ast::Identifier or
/// ast::TypeName cast to the given semantic type.
/// @param node the node to retrieve
template <typename SEM = sem::Info::InferFromAST>
sem::Info::GetResultType<SEM, ast::Node>* ResolvedSymbol(const ast::Node* node) const {
if (auto resolved = dependencies_.resolved_symbols.Find(node)) {
auto* sem = builder_->Sem().Get<SEM>(*resolved);
return const_cast<sem::Info::GetResultType<SEM, ast::Node>*>(sem);
}
return nullptr;
}
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
type::Type* TypeOf(const ast::Expression* expr) const;
@ -104,7 +91,6 @@ class SemHelper {
private:
ProgramBuilder* builder_;
DependencyGraph& dependencies_;
};
} // namespace tint::resolver

View File

@ -2344,8 +2344,8 @@ bool Validator::Assignment(const ast::Statement* a, const type::Type* rhs_ty) co
// https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
auto const* lhs_ty = sem_.TypeOf(lhs);
if (auto* variable = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* v = variable->Declaration();
if (auto* var_user = sem_.Get<sem::VariableUser>(lhs)) {
auto* v = var_user->Variable()->Declaration();
const char* err = Switch(
v, //
[&](const ast::Parameter*) { return "cannot assign to function parameter"; },
@ -2392,8 +2392,8 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
if (auto* variable = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* v = variable->Declaration();
if (auto* var_user = sem_.Get<sem::VariableUser>(lhs)) {
auto* v = var_user->Variable()->Declaration();
const char* err = Switch(
v, //
[&](const ast::Parameter*) { return "cannot modify function parameter"; },

View File

@ -161,6 +161,124 @@ const char* str(BuiltinType i);
/// matches the name in the WGSL spec.
std::ostream& operator<<(std::ostream& out, BuiltinType i);
/// All builtin types
constexpr BuiltinType kBuiltinTypes[] = {
BuiltinType::kAbs,
BuiltinType::kAcos,
BuiltinType::kAcosh,
BuiltinType::kAll,
BuiltinType::kAny,
BuiltinType::kArrayLength,
BuiltinType::kAsin,
BuiltinType::kAsinh,
BuiltinType::kAtan,
BuiltinType::kAtan2,
BuiltinType::kAtanh,
BuiltinType::kCeil,
BuiltinType::kClamp,
BuiltinType::kCos,
BuiltinType::kCosh,
BuiltinType::kCountLeadingZeros,
BuiltinType::kCountOneBits,
BuiltinType::kCountTrailingZeros,
BuiltinType::kCross,
BuiltinType::kDegrees,
BuiltinType::kDeterminant,
BuiltinType::kDistance,
BuiltinType::kDot,
BuiltinType::kDot4I8Packed,
BuiltinType::kDot4U8Packed,
BuiltinType::kDpdx,
BuiltinType::kDpdxCoarse,
BuiltinType::kDpdxFine,
BuiltinType::kDpdy,
BuiltinType::kDpdyCoarse,
BuiltinType::kDpdyFine,
BuiltinType::kExp,
BuiltinType::kExp2,
BuiltinType::kExtractBits,
BuiltinType::kFaceForward,
BuiltinType::kFirstLeadingBit,
BuiltinType::kFirstTrailingBit,
BuiltinType::kFloor,
BuiltinType::kFma,
BuiltinType::kFract,
BuiltinType::kFrexp,
BuiltinType::kFwidth,
BuiltinType::kFwidthCoarse,
BuiltinType::kFwidthFine,
BuiltinType::kInsertBits,
BuiltinType::kInverseSqrt,
BuiltinType::kLdexp,
BuiltinType::kLength,
BuiltinType::kLog,
BuiltinType::kLog2,
BuiltinType::kMax,
BuiltinType::kMin,
BuiltinType::kMix,
BuiltinType::kModf,
BuiltinType::kNormalize,
BuiltinType::kPack2X16Float,
BuiltinType::kPack2X16Snorm,
BuiltinType::kPack2X16Unorm,
BuiltinType::kPack4X8Snorm,
BuiltinType::kPack4X8Unorm,
BuiltinType::kPow,
BuiltinType::kQuantizeToF16,
BuiltinType::kRadians,
BuiltinType::kReflect,
BuiltinType::kRefract,
BuiltinType::kReverseBits,
BuiltinType::kRound,
BuiltinType::kSaturate,
BuiltinType::kSelect,
BuiltinType::kSign,
BuiltinType::kSin,
BuiltinType::kSinh,
BuiltinType::kSmoothstep,
BuiltinType::kSqrt,
BuiltinType::kStep,
BuiltinType::kStorageBarrier,
BuiltinType::kTan,
BuiltinType::kTanh,
BuiltinType::kTranspose,
BuiltinType::kTrunc,
BuiltinType::kUnpack2X16Float,
BuiltinType::kUnpack2X16Snorm,
BuiltinType::kUnpack2X16Unorm,
BuiltinType::kUnpack4X8Snorm,
BuiltinType::kUnpack4X8Unorm,
BuiltinType::kWorkgroupBarrier,
BuiltinType::kWorkgroupUniformLoad,
BuiltinType::kTextureDimensions,
BuiltinType::kTextureGather,
BuiltinType::kTextureGatherCompare,
BuiltinType::kTextureNumLayers,
BuiltinType::kTextureNumLevels,
BuiltinType::kTextureNumSamples,
BuiltinType::kTextureSample,
BuiltinType::kTextureSampleBias,
BuiltinType::kTextureSampleCompare,
BuiltinType::kTextureSampleCompareLevel,
BuiltinType::kTextureSampleGrad,
BuiltinType::kTextureSampleLevel,
BuiltinType::kTextureSampleBaseClampToEdge,
BuiltinType::kTextureStore,
BuiltinType::kTextureLoad,
BuiltinType::kAtomicLoad,
BuiltinType::kAtomicStore,
BuiltinType::kAtomicAdd,
BuiltinType::kAtomicSub,
BuiltinType::kAtomicMax,
BuiltinType::kAtomicMin,
BuiltinType::kAtomicAnd,
BuiltinType::kAtomicOr,
BuiltinType::kAtomicXor,
BuiltinType::kAtomicExchange,
BuiltinType::kAtomicCompareExchangeWeak,
BuiltinType::kTintMaterialize,
};
} // namespace tint::sem
#endif // SRC_TINT_SEM_BUILTIN_TYPE_H_

View File

@ -41,6 +41,13 @@ const char* str(BuiltinType i);
/// matches the name in the WGSL spec.
std::ostream& operator<<(std::ostream& out, BuiltinType i);
/// All builtin types
constexpr BuiltinType kBuiltinTypes[] = {
{{- range Sem.Builtins }}
BuiltinType::k{{PascalCase .Name}},
{{- end }}
};
} // namespace tint::sem
#endif // SRC_TINT_SEM_BUILTIN_TYPE_H_