From a9156ff091b99f2c1b434bc78a47078bf706a39f Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 5 Nov 2021 16:51:38 +0000 Subject: [PATCH] Rework Resolver so that we construct semantic types in a single pass. The semantic nodes cannot be fully immutable, as they contain cyclic references. Remove Resolver::CreateSemanticNodes(), and instead construct and mutate the semantic nodes in the single traversal pass. Give up on trying to maintain the 'authored' type names (aliased names). These are a nightmare to maintain, and provided limited use. Significantly simplfies the Resolver, and allows us to generate more semantic to semantic references, reducing sem -> ast -> sem hops. Note: This change introduces constant value propagation across constant variables. This is unlocked by the earlier construction of the sem::Variable. Change-Id: I592092fdc47fe24d30e512952511c9ab7c16d7a1 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68406 Kokoro: Kokoro Commit-Queue: Ben Clayton Reviewed-by: Antonio Maiorano --- fuzzers/tint_ast_fuzzer/util.h | 10 +- src/inspector/inspector.cc | 17 +- src/intrinsic_table.cc | 4 +- src/resolver/array_accessor_test.cc | 5 +- src/resolver/assignment_validation_test.cc | 4 +- src/resolver/compound_statement_test.cc | 16 +- src/resolver/function_validation_test.cc | 2 +- src/resolver/ptr_ref_test.cc | 15 +- src/resolver/ptr_ref_validation_test.cc | 2 +- src/resolver/resolver.cc | 2175 ++++++++--------- src/resolver/resolver.h | 207 +- src/resolver/resolver_constants.cc | 79 +- src/resolver/resolver_test.cc | 14 +- .../storage_class_layout_validation_test.cc | 6 +- src/resolver/storage_class_validation_test.cc | 9 +- .../type_constructor_validation_test.cc | 13 +- src/resolver/validation_test.cc | 7 +- src/resolver/var_let_validation_test.cc | 6 +- src/sem/block_statement.cc | 17 +- src/sem/block_statement.h | 16 +- src/sem/call.cc | 9 +- src/sem/call.h | 17 +- src/sem/constant.cc | 2 + src/sem/constant.h | 5 + src/sem/expression.h | 5 +- src/sem/for_loop_statement.cc | 5 +- src/sem/for_loop_statement.h | 4 +- src/sem/function.cc | 24 +- src/sem/function.h | 99 +- src/sem/if_statement.cc | 10 +- src/sem/if_statement.h | 9 +- src/sem/info.h | 15 +- src/sem/loop_statement.cc | 14 +- src/sem/loop_statement.h | 8 +- src/sem/statement.cc | 17 +- src/sem/statement.h | 19 +- src/sem/switch_statement.cc | 14 +- src/sem/switch_statement.h | 8 +- src/sem/variable.cc | 35 +- src/sem/variable.h | 43 +- .../module_scope_var_to_entry_point_param.cc | 6 +- src/transform/robustness_test.cc | 4 +- test/bug/tint/1121.wgsl.expected.spvasm | 2 +- 43 files changed, 1448 insertions(+), 1550 deletions(-) diff --git a/fuzzers/tint_ast_fuzzer/util.h b/fuzzers/tint_ast_fuzzer/util.h index e82ebcc9c6..0b6f4496d8 100644 --- a/fuzzers/tint_ast_fuzzer/util.h +++ b/fuzzers/tint_ast_fuzzer/util.h @@ -22,6 +22,7 @@ #include "src/castable.h" #include "src/program.h" #include "src/sem/block_statement.h" +#include "src/sem/function.h" #include "src/sem/statement.h" #include "src/sem/variable.h" @@ -74,16 +75,15 @@ std::vector GetAllVarsInScope( } // Process function parameters. - for (const auto* param : curr_stmt->Function()->params) { - const auto* sem_var = program.Sem().Get(param); - if (pred(sem_var)) { - result.push_back(sem_var); + for (const auto* param : curr_stmt->Function()->Parameters()) { + if (pred(param)) { + result.push_back(param); } } // Global variables do not belong to any ast::BlockStatement. for (const auto* global_decl : program.AST().GlobalDeclarations()) { - if (global_decl == curr_stmt->Function()) { + if (global_decl == curr_stmt->Function()->Declaration()) { // The same situation as in the previous loop. The current function has // been reached. If there are any variables declared below, they won't be // visible in this function. Thus, exit the loop. diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index 299c5fb8d3..b602c0aa70 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -827,11 +827,11 @@ void Inspector::GenerateSamplerTargets() { } auto* call_func = call->Stmt()->Function(); - std::vector entry_points; - if (call_func->IsEntryPoint()) { - entry_points = {call_func->symbol}; + std::vector entry_points; + if (call_func->Declaration()->IsEntryPoint()) { + entry_points = {call_func}; } else { - entry_points = sem.Get(call_func)->AncestorEntryPoints(); + entry_points = call_func->AncestorEntryPoints(); } if (entry_points.empty()) { @@ -854,8 +854,9 @@ void Inspector::GenerateSamplerTargets() { sampler->Declaration()->BindingPoint().group->value, sampler->Declaration()->BindingPoint().binding->value}; - for (auto entry_point : entry_points) { - const auto& ep_name = program_->Symbols().NameFor(entry_point); + for (auto* entry_point : entry_points) { + const auto& ep_name = + program_->Symbols().NameFor(entry_point->Declaration()->symbol); (*sampler_targets_)[ep_name].add( {sampler_binding_point, texture_binding_point}); } @@ -911,8 +912,8 @@ void Inspector::GetOriginatingResources( // is not called. Ignore. return; } - for (auto* call_expr : func->CallSites()) { - callsites.add(call_expr); + for (auto* call : func->CallSites()) { + callsites.add(call->Declaration()); } // Need to evaluate each function call with the group of // expressions, so move on to the next expression. diff --git a/src/intrinsic_table.cc b/src/intrinsic_table.cc index d74bb8db83..c7f6d08a32 100644 --- a/src/intrinsic_table.cc +++ b/src/intrinsic_table.cc @@ -1070,8 +1070,8 @@ const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type, ast::StorageClass::kNone, ast::Access::kUndefined, p.usage)); } return builder.create( - intrinsic.type, intrinsic.return_type, - std::move(params), intrinsic.supported_stages, intrinsic.is_deprecated); + intrinsic.type, intrinsic.return_type, std::move(params), + intrinsic.supported_stages, intrinsic.is_deprecated); }); } diff --git a/src/resolver/array_accessor_test.cc b/src/resolver/array_accessor_test.cc index a565d0cda5..728c119cc0 100644 --- a/src/resolver/array_accessor_test.cc +++ b/src/resolver/array_accessor_test.cc @@ -289,8 +289,9 @@ TEST_F(ResolverArrayAccessorTest, EXpr_Deref_FuncBadParent) { Func("func", {p}, ty.f32(), {Decl(idx), Decl(x), Return(x)}); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: cannot index type 'ptr>'"); + EXPECT_EQ( + r()->error(), + "12:34 error: cannot index type 'ptr, read_write>'"); } TEST_F(ResolverArrayAccessorTest, Exr_Deref_BadParent) { diff --git a/src/resolver/assignment_validation_test.cc b/src/resolver/assignment_validation_test.cc index f8d93904b3..1981c876cc 100644 --- a/src/resolver/assignment_validation_test.cc +++ b/src/resolver/assignment_validation_test.cc @@ -102,7 +102,7 @@ TEST_F(ResolverAssignmentValidationTest, ASSERT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: cannot assign 'array' to 'array'"); + "12:34 error: cannot assign 'array' to 'array'"); } TEST_F(ResolverAssignmentValidationTest, @@ -332,7 +332,7 @@ TEST_F(ResolverAssignmentValidationTest, AssignToPhony_DynamicArray_Fail) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( r()->error(), - "12:34 error: cannot assign 'ref, read>' to '_'. " + "12:34 error: cannot assign 'array' to '_'. " "'_' can only be assigned a constructible, pointer, texture or sampler " "type"); } diff --git a/src/resolver/compound_statement_test.cc b/src/resolver/compound_statement_test.cc index a13cac6de7..2ace78e4b0 100644 --- a/src/resolver/compound_statement_test.cc +++ b/src/resolver/compound_statement_test.cc @@ -43,7 +43,7 @@ TEST_F(ResolverCompoundStatementTest, FunctionBlock) { ASSERT_TRUE(s->Block()->Is()); EXPECT_EQ(s->Block(), s->FindFirstParent()); EXPECT_EQ(s->Block(), s->FindFirstParent()); - EXPECT_EQ(s->Block()->As()->Function(), f); + EXPECT_EQ(s->Function()->Declaration(), f); EXPECT_EQ(s->Block()->Parent(), nullptr); } @@ -74,8 +74,7 @@ TEST_F(ResolverCompoundStatementTest, Block) { EXPECT_EQ(s->Block()->Parent(), s->FindFirstParent()); ASSERT_TRUE(s->Block()->Parent()->Is()); - EXPECT_EQ( - s->Block()->Parent()->As()->Function(), f); + EXPECT_EQ(s->Function()->Declaration(), f); EXPECT_EQ(s->Block()->Parent()->Parent(), nullptr); } } @@ -118,7 +117,7 @@ TEST_F(ResolverCompoundStatementTest, Loop) { EXPECT_TRUE( Is(s->Parent()->Parent()->Parent())); - EXPECT_EQ(s->FindFirstParent()->Function(), f); + EXPECT_EQ(s->Function()->Declaration(), f); EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(), nullptr); } @@ -144,7 +143,7 @@ TEST_F(ResolverCompoundStatementTest, Loop) { s->FindFirstParent()); EXPECT_TRUE(Is( s->Parent()->Parent()->Parent()->Parent())); - EXPECT_EQ(s->FindFirstParent()->Function(), f); + EXPECT_EQ(s->Function()->Declaration(), f); EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent()->Parent(), nullptr); } @@ -213,12 +212,7 @@ TEST_F(ResolverCompoundStatementTest, ForLoop) { Is(s->Block()->Parent()->Parent())); EXPECT_EQ(s->Block()->Parent()->Parent(), s->FindFirstParent()); - EXPECT_EQ(s->Block() - ->Parent() - ->Parent() - ->As() - ->Function(), - f); + EXPECT_EQ(s->Function()->Declaration(), f); EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr); } } diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index dc71caeb40..ec44227cff 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -388,7 +388,7 @@ TEST_F(ResolverFunctionValidationTest, EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: return statement type must match its function return " - "type, returned 'u32', expected 'myf32'"); + "type, returned 'u32', expected 'f32'"); } TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) { diff --git a/src/resolver/ptr_ref_test.cc b/src/resolver/ptr_ref_test.cc index f8425495a4..4810537c93 100644 --- a/src/resolver/ptr_ref_test.cc +++ b/src/resolver/ptr_ref_test.cc @@ -98,11 +98,16 @@ TEST_F(ResolverPtrRefTest, DefaultPtrStorageClass) { EXPECT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_TRUE(TypeOf(function_ptr)->Is()); - ASSERT_TRUE(TypeOf(private_ptr)->Is()); - ASSERT_TRUE(TypeOf(workgroup_ptr)->Is()); - ASSERT_TRUE(TypeOf(uniform_ptr)->Is()); - ASSERT_TRUE(TypeOf(storage_ptr)->Is()); + ASSERT_TRUE(TypeOf(function_ptr)->Is()) + << "function_ptr is " << TypeOf(function_ptr)->TypeInfo().name; + ASSERT_TRUE(TypeOf(private_ptr)->Is()) + << "private_ptr is " << TypeOf(private_ptr)->TypeInfo().name; + ASSERT_TRUE(TypeOf(workgroup_ptr)->Is()) + << "workgroup_ptr is " << TypeOf(workgroup_ptr)->TypeInfo().name; + ASSERT_TRUE(TypeOf(uniform_ptr)->Is()) + << "uniform_ptr is " << TypeOf(uniform_ptr)->TypeInfo().name; + ASSERT_TRUE(TypeOf(storage_ptr)->Is()) + << "storage_ptr is " << TypeOf(storage_ptr)->TypeInfo().name; EXPECT_EQ(TypeOf(function_ptr)->As()->Access(), ast::Access::kReadWrite); diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc index 06daeb5424..87886bb8fa 100644 --- a/src/resolver/ptr_ref_validation_test.cc +++ b/src/resolver/ptr_ref_validation_test.cc @@ -167,7 +167,7 @@ TEST_F(ResolverPtrRefValidationTest, InferredPtrAccessMismatch) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: cannot initialize let of type " - "'ptr' with value of type " + "'ptr' with value of type " "'ptr'"); } diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 9664f005df..7808b8f9c3 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -161,22 +161,6 @@ Resolver::Resolver(ProgramBuilder* builder) Resolver::~Resolver() = default; -void Resolver::set_referenced_from_function_if_needed(VariableInfo* var, - bool local) { - if (current_function_ == nullptr) { - return; - } - - if (var->kind != VariableKind::kGlobal) { - return; - } - - current_function_->referenced_module_vars.add(var); - if (local) { - current_function_->local_referenced_module_vars.add(var); - } -} - bool Resolver::Resolve() { if (builder_->Diagnostics().contains_errors()) { return false; @@ -190,23 +174,19 @@ bool Resolver::Resolve() { return false; } - // Even if resolving failed, create all the semantic nodes for information we - // did generate. - CreateSemanticNodes(); - return result; } // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section bool Resolver::IsPlain(const sem::Type* type) const { - return type->is_scalar() || type->Is() || - type->Is() || type->Is() || - type->Is() || type->Is(); + return type->is_scalar() || + type->IsAnyOf(); } // https://gpuweb.github.io/gpuweb/wgsl.html#storable-types bool Resolver::IsStorable(const sem::Type* type) const { - return IsPlain(type) || type->Is() || type->Is(); + return IsPlain(type) || type->IsAnyOf(); } // https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types @@ -443,53 +423,36 @@ bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) { return true; } -Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var, - VariableKind kind, - uint32_t index /* = 0 */) { - if (variable_to_info_.count(var)) { - TINT_ICE(Resolver, diagnostics_) - << "Variable " << builder_->Symbols().NameFor(var->symbol) - << " already resolved"; - return nullptr; - } - - std::string type_name; - const sem::Type* storage_type = nullptr; +sem::Variable* Resolver::Variable(const ast::Variable* var, + VariableKind kind, + uint32_t index /* = 0 */) { + const sem::Type* storage_ty = nullptr; // If the variable has a declared type, resolve it. if (auto* ty = var->type) { - type_name = ty->FriendlyName(builder_->Symbols()); - storage_type = Type(ty); - if (!storage_type) { + storage_ty = Type(ty); + if (!storage_ty) { return nullptr; } } - std::string rhs_type_name; - const sem::Type* rhs_type = nullptr; + const sem::Expression* rhs = nullptr; // Does the variable have a constructor? - if (auto* ctor = var->constructor) { - if (!Expression(var->constructor)) { - return nullptr; - } - - // Fetch the constructor's type - rhs_type_name = TypeNameOf(ctor); - rhs_type = TypeOf(ctor); - if (!rhs_type) { + if (var->constructor) { + rhs = Expression(var->constructor); + if (!rhs) { return nullptr; } // If the variable has no declared type, infer it from the RHS - if (!storage_type) { + if (!storage_ty) { if (!var->is_const && kind == VariableKind::kGlobal) { AddError("global var declaration must specify a type", var->source); return nullptr; } - type_name = rhs_type_name; - storage_type = rhs_type->UnwrapRef(); // Implicit load of RHS + storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS } } else if (var->is_const && kind != VariableKind::kParameter && !ast::HasDecoration(var->decorations)) { @@ -504,7 +467,7 @@ Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var, return nullptr; } - if (!storage_type) { + if (!storage_ty) { TINT_ICE(Resolver, diagnostics_) << "failed to determine storage type for variable '" + builder_->Symbols().NameFor(var->symbol) + "'\n" @@ -517,7 +480,7 @@ Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var, // No declared storage class. Infer from usage / type. if (kind == VariableKind::kLocal) { storage_class = ast::StorageClass::kFunction; - } else if (storage_type->UnwrapRef()->is_handle()) { + } else if (storage_ty->UnwrapRef()->is_handle()) { // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables // If the store type is a texture type or a sampler type, then the // variable declaration must not have a storage class decoration. The @@ -526,31 +489,97 @@ Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var, } } + if (kind == VariableKind::kLocal && !var->is_const && + storage_class != ast::StorageClass::kFunction && + IsValidationEnabled(var->decorations, + ast::DisabledValidation::kIgnoreStorageClass)) { + AddError("function variable has a non-function storage class", var->source); + return nullptr; + } + auto access = var->declared_access; if (access == ast::Access::kUndefined) { access = DefaultAccessForStorageClass(storage_class); } - auto* type = storage_type; + auto* var_ty = storage_ty; if (!var->is_const) { // Variable declaration. Unlike `let`, `var` has storage. // Variables are always of a reference type to the declared storage type. - type = - builder_->create(storage_type, storage_class, access); + var_ty = + builder_->create(storage_ty, storage_class, access); } - if (rhs_type && - !ValidateVariableConstructor(var, storage_class, storage_type, type_name, - rhs_type, rhs_type_name)) { + if (rhs && !ValidateVariableConstructor(var, storage_class, storage_ty, + rhs->Type())) { return nullptr; } - auto* info = - variable_infos_.Create(var, const_cast(type), type_name, - storage_class, access, kind, index); - variable_to_info_.emplace(var, info); + if (!ApplyStorageClassUsageToType( + storage_class, const_cast(var_ty), var->source)) { + AddNote( + std::string("while instantiating ") + + ((kind == VariableKind::kParameter) ? "parameter " : "variable ") + + builder_->Symbols().NameFor(var->symbol), + var->source); + return nullptr; + } - return info; + if (kind == VariableKind::kParameter) { + if (auto* ptr = var_ty->As()) { + // For MSL, we push module-scope variables into the entry point as pointer + // parameters, so we also need to handle their store type. + if (!ApplyStorageClassUsageToType( + ptr->StorageClass(), const_cast(ptr->StoreType()), + var->source)) { + AddNote("while instantiating parameter " + + builder_->Symbols().NameFor(var->symbol), + var->source); + return nullptr; + } + } + } + + switch (kind) { + case VariableKind::kGlobal: { + sem::BindingPoint binding_point; + if (auto bp = var->BindingPoint()) { + binding_point = {bp.group->value, bp.binding->value}; + } + + auto* global = builder_->create( + var, var_ty, storage_class, access, + (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{}, + binding_point); + + if (auto* override = + ast::GetDecoration(var->decorations)) { + if (override->has_value) { + global->SetConstantId(static_cast(override->value)); + } + } + + builder_->Sem().Add(var, global); + return global; + } + case VariableKind::kLocal: { + auto* local = builder_->create( + var, var_ty, storage_class, access, + (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{}); + builder_->Sem().Add(var, local); + return local; + } + case VariableKind::kParameter: { + auto* param = builder_->create(var, index, var_ty, + storage_class, access); + builder_->Sem().Add(var, param); + return param; + } + } + + TINT_UNREACHABLE(Resolver, diagnostics_) + << "unhandled VariableKind " << static_cast(kind); + return nullptr; } ast::Access Resolver::DefaultAccessForStorageClass( @@ -603,23 +632,23 @@ void Resolver::AllocateOverridableConstantIds() { next_constant_id = constant_id + 1; } - variable_to_info_[var]->constant_id = constant_id; + auto* sem = Sem(var); + const_cast(sem)->SetConstantId(constant_id); } } bool Resolver::ValidateVariableConstructor(const ast::Variable* var, ast::StorageClass storage_class, - const sem::Type* storage_type, - const std::string& type_name, - const sem::Type* rhs_type, - const std::string& rhs_type_name) { - auto* value_type = rhs_type->UnwrapRef(); // Implicit load of RHS + const sem::Type* storage_ty, + const sem::Type* rhs_ty) { + auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS // Value type has to match storage type - if (storage_type != value_type) { + if (storage_ty != value_type) { std::string decl = var->is_const ? "let" : "var"; - AddError("cannot initialize " + decl + " of type '" + type_name + - "' with value of type '" + rhs_type_name + "'", + AddError("cannot initialize " + decl + " of type '" + + TypeNameOf(storage_ty) + "' with value of type '" + + TypeNameOf(rhs_ty) + "'", var->source); return false; } @@ -652,17 +681,18 @@ bool Resolver::GlobalVariable(const ast::Variable* var) { return false; } - auto* info = Variable(var, VariableKind::kGlobal); - if (!info) { + auto* sem = Variable(var, VariableKind::kGlobal); + if (!sem) { return false; } - variable_stack_.Set(var->symbol, info); + variable_stack_.Set(var->symbol, sem); - if (!var->is_const && info->storage_class == ast::StorageClass::kNone) { + auto storage_class = sem->StorageClass(); + if (!var->is_const && storage_class == ast::StorageClass::kNone) { AddError("global variables must have a storage class", var->source); return false; } - if (var->is_const && !(info->storage_class == ast::StorageClass::kNone)) { + if (var->is_const && storage_class != ast::StorageClass::kNone) { AddError("global constants shouldn't have a storage class", var->source); return false; } @@ -673,7 +703,7 @@ bool Resolver::GlobalVariable(const ast::Variable* var) { if (auto* override_deco = deco->As()) { // Track the constant IDs that are specified in the shader. if (override_deco->has_value) { - constant_ids_.emplace(override_deco->value, info); + constant_ids_.emplace(override_deco->value, sem); } } } @@ -682,26 +712,13 @@ bool Resolver::GlobalVariable(const ast::Variable* var) { return false; } - if (auto bp = var->BindingPoint()) { - info->binding_point = {bp.group->value, bp.binding->value}; - } - - if (!ValidateGlobalVariable(info)) { - return false; - } - - if (!ApplyStorageClassUsageToType( - info->storage_class, const_cast(info->type->UnwrapRef()), - var->source)) { - AddNote("while instantiating variable " + - builder_->Symbols().NameFor(var->symbol), - var->source); + if (!ValidateGlobalVariable(sem)) { return false; } // TODO(bclayton): Call this at the end of resolve on all uniform and storage // referenced structs - if (!ValidateStorageClassLayout(info)) { + if (!ValidateStorageClassLayout(sem)) { return false; } @@ -735,7 +752,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Struct* str, }; auto type_name_of = [this](const sem::StructMember* sm) { - return sm->Declaration()->type->FriendlyName(builder_->Symbols()); + return TypeNameOf(sm->Type()); }; // TODO(amaiorano): Output struct and member decorations so that this output @@ -779,8 +796,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Struct* str, << size << ") */ " << s << ";\n"; }; - print_struct_begin_line(st->Align(), st->Size(), - st->FriendlyName(builder_->Symbols())); + print_struct_begin_line(st->Align(), st->Size(), TypeNameOf(st)); for (size_t i = 0; i < st->Members().size(); ++i) { auto* const m = st->Members()[i]; @@ -911,10 +927,10 @@ bool Resolver::ValidateStorageClassLayout(const sem::Struct* str, return true; } -bool Resolver::ValidateStorageClassLayout(const VariableInfo* info) { - if (auto* str = info->type->UnwrapRef()->As()) { - if (!ValidateStorageClassLayout(str, info->storage_class)) { - AddNote("see declaration of variable", info->declaration->source); +bool Resolver::ValidateStorageClassLayout(const sem::Variable* var) { + if (auto* str = var->Type()->UnwrapRef()->As()) { + if (!ValidateStorageClassLayout(str, var->StorageClass())) { + AddNote("see declaration of variable", var->Declaration()->source); return false; } } @@ -922,24 +938,25 @@ bool Resolver::ValidateStorageClassLayout(const VariableInfo* info) { return true; } -bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { - if (!ValidateNoDuplicateDecorations(info->declaration->decorations)) { +bool Resolver::ValidateGlobalVariable(const sem::Variable* var) { + auto* decl = var->Declaration(); + if (!ValidateNoDuplicateDecorations(decl->decorations)) { return false; } - for (auto* deco : info->declaration->decorations) { - if (info->declaration->is_const) { + for (auto* deco : decl->decorations) { + if (decl->is_const) { if (auto* override_deco = deco->As()) { if (override_deco->has_value) { uint32_t id = override_deco->value; - auto itr = constant_ids_.find(id); - if (itr != constant_ids_.end() && itr->second != info) { + auto it = constant_ids_.find(id); + if (it != constant_ids_.end() && it->second != var) { AddError("pipeline constant IDs must be unique", deco->source); AddNote("a pipeline constant with an ID of " + std::to_string(id) + " was previously declared " "here:", ast::GetDecoration( - itr->second->declaration->decorations) + it->second->Declaration()->decorations) ->source); return false; } @@ -958,8 +975,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { deco->IsAnyOf(); bool has_io_storage_class = - info->storage_class == ast::StorageClass::kInput || - info->storage_class == ast::StorageClass::kOutput; + var->StorageClass() == ast::StorageClass::kInput || + var->StorageClass() == ast::StorageClass::kOutput; if (!(deco->IsAnyOf()) && (!is_shader_io_decoration || !has_io_storage_class)) { @@ -969,8 +986,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { } } - auto binding_point = info->declaration->BindingPoint(); - switch (info->storage_class) { + auto binding_point = decl->BindingPoint(); + switch (var->StorageClass()) { case ast::StorageClass::kUniform: case ast::StorageClass::kStorage: case ast::StorageClass::kUniformConstant: { @@ -981,7 +998,7 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { AddError( "resource variables require [[group]] and [[binding]] " "decorations", - info->declaration->source); + decl->source); return false; } break; @@ -993,7 +1010,7 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { AddError( "non-resource variables must not have [[group]] or [[binding]] " "decorations", - info->declaration->source); + decl->source); return false; } } @@ -1001,28 +1018,28 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration // The access mode always has a default, and except for variables in the // storage storage class, must not be written. - if (info->storage_class != ast::StorageClass::kStorage && - info->declaration->declared_access != ast::Access::kUndefined) { + if (var->StorageClass() != ast::StorageClass::kStorage && + decl->declared_access != ast::Access::kUndefined) { AddError( "only variables in storage class may declare an access mode", - info->declaration->source); + decl->source); return false; } - switch (info->storage_class) { + switch (var->StorageClass()) { case ast::StorageClass::kStorage: { // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables // A variable in the storage storage class is a storage buffer variable. // Its store type must be a host-shareable structure type with block // attribute, satisfying the storage class constraints. - auto* str = info->type->UnwrapRef()->As(); + auto* str = var->Type()->UnwrapRef()->As(); if (!str) { AddError( "variables declared in the storage class must be of a " "structure type", - info->declaration->source); + decl->source); return false; } @@ -1031,9 +1048,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { "structure used as a storage buffer must be declared with the " "[[block]] decoration", str->Declaration()->source); - if (info->declaration->source.range.begin.line) { - AddNote("structure used as storage buffer here", - info->declaration->source); + if (decl->source.range.begin.line) { + AddNote("structure used as storage buffer here", decl->source); } return false; } @@ -1044,12 +1060,12 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { // A variable in the uniform storage class is a uniform buffer variable. // Its store type must be a host-shareable structure type with block // attribute, satisfying the storage class constraints. - auto* str = info->type->UnwrapRef()->As(); + auto* str = var->Type()->UnwrapRef()->As(); if (!str) { AddError( "variables declared in the storage class must be of a " "structure type", - info->declaration->source); + decl->source); return false; } @@ -1058,9 +1074,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { "structure used as a uniform buffer must be declared with the " "[[block]] decoration", str->Declaration()->source); - if (info->declaration->source.range.begin.line) { - AddNote("structure used as uniform buffer here", - info->declaration->source); + if (decl->source.range.begin.line) { + AddNote("structure used as uniform buffer here", decl->source); } return false; } @@ -1071,7 +1086,7 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { AddError( "structure containing a runtime sized array " "cannot be used as a uniform buffer", - info->declaration->source); + decl->source); AddNote("structure is declared here", str->Declaration()->source); return false; } @@ -1084,24 +1099,24 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) { break; } - if (!info->declaration->is_const) { - if (!ValidateAtomicVariable(info)) { + if (!decl->is_const) { + if (!ValidateAtomicVariable(var)) { return false; } } - return ValidateVariable(info); + return ValidateVariable(var); } // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // Atomic types may only be instantiated by variables in the workgroup storage // class or by storage buffer variables with a read_write access mode. -bool Resolver::ValidateAtomicVariable(const VariableInfo* info) { - auto sc = info->storage_class; - auto access = info->access; - auto* type = info->type->UnwrapRef(); - auto source = info->declaration->type ? info->declaration->type->source - : info->declaration->source; +bool Resolver::ValidateAtomicVariable(const sem::Variable* var) { + auto sc = var->StorageClass(); + auto* decl = var->Declaration(); + auto access = var->Access(); + auto* type = var->Type()->UnwrapRef(); + auto source = decl->type ? decl->type->source : decl->source; if (type->Is()) { if (sc != ast::StorageClass::kWorkgroup) { @@ -1118,10 +1133,9 @@ bool Resolver::ValidateAtomicVariable(const VariableInfo* info) { AddError( "atomic variables must have or storage class", source); - AddNote("atomic sub-type of '" + - type->FriendlyName(builder_->Symbols()) + - "' is declared here", - found->second); + AddNote( + "atomic sub-type of '" + TypeNameOf(type) + "' is declared here", + found->second); return false; } else if (sc == ast::StorageClass::kStorage && access != ast::Access::kReadWrite) { @@ -1129,10 +1143,9 @@ bool Resolver::ValidateAtomicVariable(const VariableInfo* info) { "atomic variables in storage class must have read_write " "access mode", source); - AddNote("atomic sub-type of '" + - type->FriendlyName(builder_->Symbols()) + - "' is declared here", - found->second); + AddNote( + "atomic sub-type of '" + TypeNameOf(type) + "' is declared here", + found->second); return false; } } @@ -1141,75 +1154,85 @@ bool Resolver::ValidateAtomicVariable(const VariableInfo* info) { return true; } -bool Resolver::ValidateVariable(const VariableInfo* info) { - auto* var = info->declaration; - auto* storage_type = info->type->UnwrapRef(); +bool Resolver::ValidateVariable(const sem::Variable* var) { + auto* decl = var->Declaration(); + auto* storage_ty = var->Type()->UnwrapRef(); - if (!var->is_const && !IsStorable(storage_type)) { - AddError(storage_type->FriendlyName(builder_->Symbols()) + - " cannot be used as the type of a var", - var->source); + if (!decl->is_const && !IsStorable(storage_ty)) { + AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a var", + decl->source); return false; } - if (var->is_const && info->kind != VariableKind::kParameter && - !(storage_type->IsConstructible() || storage_type->Is())) { - AddError(storage_type->FriendlyName(builder_->Symbols()) + - " cannot be used as the type of a let", - var->source); + if (decl->is_const && !var->Is() && + !(storage_ty->IsConstructible() || storage_ty->Is())) { + AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a let", + decl->source); return false; } - if (auto* r = storage_type->As()) { + if (auto* r = storage_ty->As()) { if (r->IsRuntimeSized()) { AddError("runtime arrays may only appear as the last member of a struct", - var->source); + decl->source); return false; } } - if (auto* r = storage_type->As()) { + if (auto* r = storage_ty->As()) { if (r->dim() != ast::TextureDimension::k2d) { - AddError("only 2d multisampled textures are supported", var->source); + AddError("only 2d multisampled textures are supported", decl->source); return false; } if (!r->type()->UnwrapRef()->is_numeric_scalar()) { AddError("texture_multisampled_2d: type must be f32, i32 or u32", - var->source); + decl->source); return false; } } - if (storage_type->is_handle() && - var->declared_storage_class != ast::StorageClass::kNone) { + if (var->Is() && !decl->is_const && + IsValidationEnabled(decl->decorations, + ast::DisabledValidation::kIgnoreStorageClass)) { + if (!var->Type()->UnwrapRef()->IsConstructible()) { + AddError("function variable must have a constructible type", + decl->type ? decl->type->source : decl->source); + return false; + } + } + + if (storage_ty->is_handle() && + decl->declared_storage_class != ast::StorageClass::kNone) { // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables // If the store type is a texture type or a sampler type, then the // variable declaration must not have a storage class decoration. The // storage class will always be handle. - AddError("variables of type '" + info->type_name + + AddError("variables of type '" + TypeNameOf(storage_ty) + "' must not have a storage class", - var->source); + decl->source); return false; } - if (IsValidationEnabled(var->decorations, + if (IsValidationEnabled(decl->decorations, ast::DisabledValidation::kIgnoreStorageClass) && - (var->declared_storage_class == ast::StorageClass::kInput || - var->declared_storage_class == ast::StorageClass::kOutput)) { - AddError("invalid use of input/output storage class", var->source); + (decl->declared_storage_class == ast::StorageClass::kInput || + decl->declared_storage_class == ast::StorageClass::kOutput)) { + AddError("invalid use of input/output storage class", decl->source); return false; } return true; } bool Resolver::ValidateFunctionParameter(const ast::Function* func, - const VariableInfo* info) { - if (!ValidateVariable(info)) { + const sem::Variable* var) { + if (!ValidateVariable(var)) { return false; } - for (auto* deco : info->declaration->decorations) { + auto* decl = var->Declaration(); + + for (auto* deco : decl->decorations) { if (!func->IsEntryPoint() && !deco->Is()) { AddError( "decoration is not valid for non-entry point function parameters", @@ -1220,10 +1243,10 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, ast::InterpolateDecoration, ast::InternalDecoration>() && (IsValidationEnabled( - info->declaration->decorations, + decl->decorations, ast::DisabledValidation::kEntryPointParameter) && IsValidationEnabled( - info->declaration->decorations, + decl->decorations, ast::DisabledValidation:: kIgnoreConstructibleFunctionParameter))) { AddError("decoration is not valid for function parameters", deco->source); @@ -1231,34 +1254,35 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, } } - if (auto* ref = info->type->As()) { + if (auto* ref = var->Type()->As()) { auto sc = ref->StorageClass(); if (!(sc == ast::StorageClass::kFunction || sc == ast::StorageClass::kPrivate || sc == ast::StorageClass::kWorkgroup) && - IsValidationEnabled(info->declaration->decorations, + IsValidationEnabled(decl->decorations, ast::DisabledValidation::kIgnoreStorageClass)) { std::stringstream ss; ss << "function parameter of pointer type cannot be in '" << sc << "' storage class"; - AddError(ss.str(), info->declaration->source); + AddError(ss.str(), decl->source); return false; } } - if (IsPlain(info->type)) { - if (!info->type->IsConstructible() && + if (IsPlain(var->Type())) { + if (!var->Type()->IsConstructible() && IsValidationEnabled( - info->declaration->decorations, + decl->decorations, ast::DisabledValidation::kIgnoreConstructibleFunctionParameter)) { AddError("store type of function parameter must be a constructible type", - info->declaration->source); + decl->source); return false; } - } else if (!info->type->IsAnyOf()) { - AddError("store type of function parameter cannot be " + - info->type->FriendlyName(builder_->Symbols()), - info->declaration->source); + } else if (!var->Type() + ->IsAnyOf()) { + AddError( + "store type of function parameter cannot be " + TypeNameOf(var->Type()), + decl->source); return false; } @@ -1266,11 +1290,11 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, } bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, - const sem::Type* storage_type, + const sem::Type* storage_ty, const bool is_input) { - auto* type = storage_type->UnwrapRef(); + auto* type = storage_ty->UnwrapRef(); const auto stage = current_function_ - ? current_function_->declaration->PipelineStage() + ? current_function_->Declaration()->PipelineStage() : ast::PipelineStage::kNone; std::stringstream stage_name; stage_name << stage; @@ -1388,8 +1412,8 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, bool Resolver::ValidateInterpolateDecoration( const ast::InterpolateDecoration* deco, - const sem::Type* storage_type) { - auto* type = storage_type->UnwrapRef(); + const sem::Type* storage_ty) { + auto* type = storage_ty->UnwrapRef(); if (type->is_integer_scalar_or_vector() && deco->type != ast::InterpolationType::kFlat) { @@ -1409,18 +1433,18 @@ bool Resolver::ValidateInterpolateDecoration( return true; } -bool Resolver::ValidateFunction(const ast::Function* func, - const FunctionInfo* info) { - if (!ValidateNoDuplicateDefinition(func->symbol, func->source, +bool Resolver::ValidateFunction(const sem::Function* func) { + auto* decl = func->Declaration(); + if (!ValidateNoDuplicateDefinition(decl->symbol, decl->source, /* check_global_scope_only */ true)) { return false; } auto workgroup_deco_count = 0; - for (auto* deco : func->decorations) { + for (auto* deco : decl->decorations) { if (deco->Is()) { workgroup_deco_count++; - if (func->PipelineStage() != ast::PipelineStage::kCompute) { + if (decl->PipelineStage() != ast::PipelineStage::kCompute) { AddError( "the workgroup_size attribute is only valid for compute stages", deco->source); @@ -1433,41 +1457,41 @@ bool Resolver::ValidateFunction(const ast::Function* func, } } - if (func->params.size() > 255) { - AddError("functions may declare at most 255 parameters", func->source); + if (decl->params.size() > 255) { + AddError("functions may declare at most 255 parameters", decl->source); return false; } - for (auto* param : func->params) { - if (!ValidateFunctionParameter(func, variable_to_info_.at(param))) { + for (size_t i = 0; i < decl->params.size(); i++) { + if (!ValidateFunctionParameter(decl, func->Parameters()[i])) { return false; } } - if (!info->return_type->Is()) { - if (!info->return_type->IsConstructible()) { + if (!func->ReturnType()->Is()) { + if (!func->ReturnType()->IsConstructible()) { AddError("function return type must be a constructible type", - func->return_type->source); + decl->return_type->source); return false; } - if (func->body) { - if (!func->body->Last() || - !func->body->Last()->Is()) { + if (decl->body) { + if (!decl->body->Last() || + !decl->body->Last()->Is()) { AddError("non-void function must end with a return statement", - func->source); + decl->source); return false; } } else if (IsValidationEnabled( - func->decorations, + decl->decorations, ast::DisabledValidation::kFunctionHasNoBody)) { TINT_ICE(Resolver, diagnostics_) - << "Function " << builder_->Symbols().NameFor(func->symbol) + << "Function " << builder_->Symbols().NameFor(decl->symbol) << " has no body"; } - for (auto* deco : func->return_type_decorations) { - if (!func->IsEntryPoint()) { + for (auto* deco : decl->return_type_decorations) { + if (!decl->IsEntryPoint()) { AddError( "decoration is not valid for non-entry point function return types", deco->source); @@ -1476,9 +1500,9 @@ bool Resolver::ValidateFunction(const ast::Function* func, if (!deco->IsAnyOf() && - (IsValidationEnabled(info->declaration->decorations, + (IsValidationEnabled(decl->decorations, ast::DisabledValidation::kEntryPointParameter) && - IsValidationEnabled(info->declaration->decorations, + IsValidationEnabled(decl->decorations, ast::DisabledValidation:: kIgnoreConstructibleFunctionParameter))) { AddError("decoration is not valid for entry point return types", @@ -1488,8 +1512,8 @@ bool Resolver::ValidateFunction(const ast::Function* func, } } - if (func->IsEntryPoint()) { - if (!ValidateEntryPoint(func, info)) { + if (decl->IsEntryPoint()) { + if (!ValidateEntryPoint(func)) { return false; } } @@ -1497,12 +1521,13 @@ bool Resolver::ValidateFunction(const ast::Function* func, return true; } -bool Resolver::ValidateEntryPoint(const ast::Function* func, - const FunctionInfo* info) { +bool Resolver::ValidateEntryPoint(const sem::Function* func) { + auto* decl = func->Declaration(); + // Use a lambda to validate the entry point decorations for a type. // Persistent state is used to track which builtins and locations have // already been seen, in order to catch conflicts. - // TODO(jrprice): This state could be stored in FunctionInfo instead, and + // TODO(jrprice): This state could be stored in sem::Function instead, and // then passed to sem::Function since it would be useful there too. std::unordered_set builtins; std::unordered_set locations; @@ -1514,7 +1539,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, // Inner lambda that is applied to a type and all of its members. auto validate_entry_point_decorations_inner = [&](const ast::DecorationList& decos, - sem::Type* ty, + const sem::Type* ty, Source source, ParamOrRetType param_or_ret, bool is_struct_member) { @@ -1539,7 +1564,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, " attribute appears multiple times as pipeline " + (param_or_ret == ParamOrRetType::kParameter ? "input" : "output"), - func->source); + decl->source); return false; } @@ -1564,14 +1589,14 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, return false; } } else if (auto* interpolate = deco->As()) { - if (func->PipelineStage() == ast::PipelineStage::kCompute) { + if (decl->PipelineStage() == ast::PipelineStage::kCompute) { is_invalid_compute_shader_decoration = true; } else if (!ValidateInterpolateDecoration(interpolate, ty)) { return false; } interpolate_attribute = interpolate; } else if (auto* invariant = deco->As()) { - if (func->PipelineStage() == ast::PipelineStage::kCompute) { + if (decl->PipelineStage() == ast::PipelineStage::kCompute) { is_invalid_compute_shader_decoration = true; } invariant_attribute = invariant; @@ -1609,14 +1634,14 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, if (ty->is_integer_scalar_or_vector() && !interpolate_attribute) { // TODO(crbug.com/tint/1224): Make these errors once downstream // usages have caught up (no sooner than M99). - if (func->PipelineStage() == ast::PipelineStage::kVertex && + if (decl->PipelineStage() == ast::PipelineStage::kVertex && param_or_ret == ParamOrRetType::kReturnType) { AddWarning( "integral user-defined vertex outputs must have a flat " "interpolation attribute", source); } - if (func->PipelineStage() == ast::PipelineStage::kFragment && + if (decl->PipelineStage() == ast::PipelineStage::kFragment && param_or_ret == ParamOrRetType::kParameter) { AddWarning( "integral user-defined fragment inputs must have a flat " @@ -1648,7 +1673,8 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, // Outer lambda for validating the entry point decorations for a type. auto validate_entry_point_decorations = [&](const ast::DecorationList& decos, - sem::Type* ty, Source source, + const sem::Type* ty, + Source source, ParamOrRetType param_or_ret) { if (!validate_entry_point_decorations_inner(decos, ty, source, param_or_ret, /*is_struct_member*/ false)) { @@ -1662,8 +1688,8 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, member->Declaration()->source, param_or_ret, /*is_struct_member*/ true)) { AddNote("while analysing entry point '" + - builder_->Symbols().NameFor(func->symbol) + "'", - func->source); + builder_->Symbols().NameFor(decl->symbol) + "'", + decl->source); return false; } } @@ -1672,10 +1698,11 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, return true; }; - for (auto* param : info->parameters) { - if (!validate_entry_point_decorations( - param->declaration->decorations, param->type, - param->declaration->source, ParamOrRetType::kParameter)) { + for (auto* param : func->Parameters()) { + auto* param_decl = param->Declaration(); + if (!validate_entry_point_decorations(param_decl->decorations, + param->Type(), param_decl->source, + ParamOrRetType::kParameter)) { return false; } } @@ -1686,21 +1713,21 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, builtins.clear(); locations.clear(); - if (!info->return_type->Is()) { - if (!validate_entry_point_decorations(func->return_type_decorations, - info->return_type, func->source, + if (!func->ReturnType()->Is()) { + if (!validate_entry_point_decorations(decl->return_type_decorations, + func->ReturnType(), decl->source, ParamOrRetType::kReturnType)) { return false; } } - if (func->PipelineStage() == ast::PipelineStage::kVertex && + if (decl->PipelineStage() == ast::PipelineStage::kVertex && builtins.count(ast::Builtin::kPosition) == 0) { // Check module-scope variables, as the SPIR-V sanitizer generates these. bool found = false; - for (auto* var : info->referenced_module_vars) { + for (auto* global : func->TransitivelyReferencedGlobals()) { if (auto* builtin = ast::GetDecoration( - var->declaration->decorations)) { + global->Declaration()->decorations)) { if (builtin->builtin == ast::Builtin::kPosition) { found = true; break; @@ -1711,31 +1738,32 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, AddError( "a vertex shader must include the 'position' builtin in its return " "type", - func->source); + decl->source); return false; } } - if (func->PipelineStage() == ast::PipelineStage::kCompute) { - if (!ast::HasDecoration(func->decorations)) { + if (decl->PipelineStage() == ast::PipelineStage::kCompute) { + if (!ast::HasDecoration(decl->decorations)) { AddError( "a compute shader must include 'workgroup_size' in its " "attributes", - func->source); + decl->source); return false; } } // Validate there are no resource variable binding collisions std::unordered_map binding_points; - for (auto* var_info : info->referenced_module_vars) { - if (!var_info->declaration->BindingPoint()) { + for (auto* var : func->TransitivelyReferencedGlobals()) { + auto* var_decl = var->Declaration(); + if (!var_decl->BindingPoint()) { continue; } - auto bp = var_info->binding_point; - auto res = binding_points.emplace(bp, var_info->declaration); + auto bp = var->BindingPoint(); + auto res = binding_points.emplace(bp, var_decl); if (!res.second && - IsValidationEnabled(var_info->declaration->decorations, + IsValidationEnabled(decl->decorations, ast::DisabledValidation::kBindingPointCollision) && IsValidationEnabled(res.first->second->decorations, ast::DisabledValidation::kBindingPointCollision)) { @@ -1744,13 +1772,13 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, // variables in the resource interface of a given shader must not have // the same group and binding values, when considered as a pair of // values. - auto func_name = builder_->Symbols().NameFor(info->declaration->symbol); + auto func_name = builder_->Symbols().NameFor(decl->symbol); AddError("entry point '" + func_name + "' references multiple variables that use the " "same resource binding [[group(" + std::to_string(bp.group) + "), binding(" + std::to_string(bp.binding) + ")]]", - var_info->declaration->source); + var_decl->source); AddNote("first resource binding usage declared here", res.first->second->source); return false; @@ -1760,19 +1788,16 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, return true; } -bool Resolver::Function(const ast::Function* func) { - auto* info = function_infos_.Create(func); - - if (func->IsEntryPoint()) { - entry_points_.emplace_back(info); - } - - TINT_SCOPED_ASSIGNMENT(current_function_, info); - +sem::Function* Resolver::Function(const ast::Function* decl) { variable_stack_.Push(); + TINT_DEFER(variable_stack_.Pop()); + uint32_t parameter_index = 0; std::unordered_map parameter_names; - for (auto* param : func->params) { + std::vector parameters; + + // Resolve all the parameters + for (auto* param : decl->params) { Mark(param); { // Check the parameter name is unique for the function @@ -1781,48 +1806,29 @@ bool Resolver::Function(const ast::Function* func) { auto name = builder_->Symbols().NameFor(param->symbol); AddError("redefinition of parameter '" + name + "'", param->source); AddNote("previous definition is here", emplaced.first->second); - return false; + return nullptr; } } - auto* param_info = - Variable(param, VariableKind::kParameter, parameter_index++); - if (!param_info) { - return false; + auto* var = As( + Variable(param, VariableKind::kParameter, parameter_index++)); + if (!var) { + return nullptr; } for (auto* deco : param->decorations) { Mark(deco); } if (!ValidateNoDuplicateDecorations(param->decorations)) { - return false; + return nullptr; } - variable_stack_.Set(param->symbol, param_info); - info->parameters.emplace_back(param_info); + variable_stack_.Set(param->symbol, var); + parameters.emplace_back(var); - if (!ApplyStorageClassUsageToType(param->declared_storage_class, - param_info->type, param->source)) { - AddNote("while instantiating parameter " + - builder_->Symbols().NameFor(param->symbol), - param->source); - return false; - } - if (auto* ptr = param_info->type->As()) { - // For MSL, we push module-scope variables into the entry point as pointer - // parameters, so we also need to handle their store type. - if (!ApplyStorageClassUsageToType( - ptr->StorageClass(), const_cast(ptr->StoreType()), - param->source)) { - AddNote("while instantiating parameter " + - builder_->Symbols().NameFor(param->symbol), - param->source); - return false; - } - } - - if (auto* str = param_info->type->As()) { - switch (func->PipelineStage()) { + auto* var_ty = const_cast(var->Type()); + if (auto* str = var_ty->As()) { + switch (decl->PipelineStage()) { case ast::PipelineStage::kVertex: str->AddUsage(sem::PipelineStageUsage::kVertexInput); break; @@ -1838,28 +1844,27 @@ bool Resolver::Function(const ast::Function* func) { } } - if (auto* ty = func->return_type) { - info->return_type = Type(ty); - info->return_type_name = ty->FriendlyName(builder_->Symbols()); - if (!info->return_type) { - return false; + // Resolve the return type + sem::Type* return_type = nullptr; + if (auto* ty = decl->return_type) { + return_type = Type(ty); + if (!return_type) { + return nullptr; } } else { - info->return_type = builder_->create(); - info->return_type_name = - info->return_type->FriendlyName(builder_->Symbols()); + return_type = builder_->create(); } - if (auto* str = info->return_type->As()) { + if (auto* str = return_type->As()) { if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str, - func->source)) { + decl->source)) { AddNote("while instantiating return type for " + - builder_->Symbols().NameFor(func->symbol), - func->source); - return false; + builder_->Symbols().NameFor(decl->symbol), + decl->source); + return nullptr; } - switch (func->PipelineStage()) { + switch (decl->PipelineStage()) { case ast::PipelineStage::kVertex: str->AddUsage(sem::PipelineStageUsage::kVertexOutput); break; @@ -1874,139 +1879,165 @@ bool Resolver::Function(const ast::Function* func) { } } - if (func->body) { - Mark(func->body); + sem::WorkgroupSize ws{}; + if (!WorkgroupSizeFor(decl, ws)) { + return nullptr; + } + + auto* func = + builder_->create(decl, return_type, parameters, ws); + builder_->Sem().Add(decl, func); + + if (decl->IsEntryPoint()) { + entry_points_.emplace_back(func); + } + + TINT_SCOPED_ASSIGNMENT(current_function_, func); + + if (decl->body) { + Mark(decl->body); if (current_compound_statement_) { TINT_ICE(Resolver, diagnostics_) << "Resolver::Function() called with a current compound statement"; - return false; + return nullptr; } auto* sem_block = builder_->create(func); - builder_->Sem().Add(func->body, sem_block); - if (!Scope(sem_block, [&] { return Statements(func->body->statements); })) { - return false; - } - } - variable_stack_.Pop(); - - for (auto* deco : func->decorations) { - Mark(deco); - } - if (!ValidateNoDuplicateDecorations(func->decorations)) { - return false; - } - - for (auto* deco : func->return_type_decorations) { - Mark(deco); - } - if (!ValidateNoDuplicateDecorations(func->return_type_decorations)) { - return false; - } - - // Set work-group size defaults. - for (int i = 0; i < 3; i++) { - info->workgroup_size[i].value = 1; - info->workgroup_size[i].overridable_const = nullptr; - } - - if (auto* workgroup = - ast::GetDecoration(func->decorations)) { - auto values = workgroup->Values(); - auto any_i32 = false; - auto any_u32 = false; - for (int i = 0; i < 3; i++) { - // Each argument to this decoration can either be a literal, an - // identifier for a module-scope constants, or nullptr if not specified. - - auto* expr = values[i]; - if (!expr) { - // Not specified, just use the default. - continue; - } - - if (!Expression(expr)) { - return false; - } - - constexpr const char* kErrBadType = - "workgroup_size argument must be either literal or module-scope " - "constant of type i32 or u32"; - constexpr const char* kErrInconsistentType = - "workgroup_size arguments must be of the same type, either i32 " - "or u32"; - - auto* ty = TypeOf(expr); - bool is_i32 = ty->UnwrapRef()->Is(); - bool is_u32 = ty->UnwrapRef()->Is(); - if (!is_i32 && !is_u32) { - AddError(kErrBadType, expr->source); - return false; - } - - any_i32 = any_i32 || is_i32; - any_u32 = any_u32 || is_u32; - if (any_i32 && any_u32) { - AddError(kErrInconsistentType, expr->source); - return false; - } - - if (auto* ident = expr->As()) { - // We have an identifier of a module-scope constant. - VariableInfo* var = variable_stack_.Get(ident->symbol); - if (!var || !(var->declaration->is_const)) { - AddError(kErrBadType, expr->source); - return false; - } - - // Capture the constant if an [[override]] attribute is present. - if (ast::HasDecoration( - var->declaration->decorations)) { - info->workgroup_size[i].overridable_const = var->declaration; - } - - expr = var->declaration->constructor; - if (!expr) { - // No constructor means this value must be overriden by the user. - info->workgroup_size[i].value = 0; - continue; - } - } else if (!expr->Is()) { - AddError( - "workgroup_size argument must be either a literal or a " - "module-scope constant", - values[i]->source); - return false; - } - - auto val = ConstantValueOf(expr); - if (!val) { - TINT_ICE(Resolver, diagnostics_) - << "could not resolve constant workgroup_size constant value"; - continue; - } - // Validate and set the default value for this dimension. - if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) { - AddError("workgroup_size argument must be at least 1", - values[i]->source); - return false; - } - - info->workgroup_size[i].value = - is_i32 ? static_cast(val.Elements()[0].i32) - : val.Elements()[0].u32; + builder_->Sem().Add(decl->body, sem_block); + if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) { + return nullptr; } } - if (!ValidateFunction(func, info)) { - return false; + for (auto* deco : decl->decorations) { + Mark(deco); + } + if (!ValidateNoDuplicateDecorations(decl->decorations)) { + return nullptr; + } + + for (auto* deco : decl->return_type_decorations) { + Mark(deco); + } + if (!ValidateNoDuplicateDecorations(decl->return_type_decorations)) { + return nullptr; + } + + if (!ValidateFunction(func)) { + return nullptr; } // Register the function information _after_ processing the statements. This // allows us to catch a function calling itself when determining the call // information as this function doesn't exist until it's finished. - symbol_to_function_[func->symbol] = info; - function_to_info_.emplace(func, info); + symbol_to_function_[decl->symbol] = func; + // If this is an entry point, mark all transitively called functions as being + // used by this entry point. + if (decl->IsEntryPoint()) { + for (auto* f : func->TransitivelyCalledFunctions()) { + const_cast(f)->AddAncestorEntryPoint(func); + } + } + + return func; +} + +bool Resolver::WorkgroupSizeFor(const ast::Function* func, + sem::WorkgroupSize& ws) { + // Set work-group size defaults. + for (int i = 0; i < 3; i++) { + ws[i].value = 1; + ws[i].overridable_const = nullptr; + } + + auto* deco = ast::GetDecoration(func->decorations); + if (!deco) { + return true; + } + + auto values = deco->Values(); + auto any_i32 = false; + auto any_u32 = false; + for (int i = 0; i < 3; i++) { + // Each argument to this decoration can either be a literal, an + // identifier for a module-scope constants, or nullptr if not specified. + + auto* expr = values[i]; + if (!expr) { + // Not specified, just use the default. + continue; + } + + auto* expr_sem = Expression(expr); + if (!expr_sem) { + return false; + } + + constexpr const char* kErrBadType = + "workgroup_size argument must be either literal or module-scope " + "constant of type i32 or u32"; + constexpr const char* kErrInconsistentType = + "workgroup_size arguments must be of the same type, either i32 " + "or u32"; + + auto* ty = TypeOf(expr); + bool is_i32 = ty->UnwrapRef()->Is(); + bool is_u32 = ty->UnwrapRef()->Is(); + if (!is_i32 && !is_u32) { + AddError(kErrBadType, expr->source); + return false; + } + + any_i32 = any_i32 || is_i32; + any_u32 = any_u32 || is_u32; + if (any_i32 && any_u32) { + AddError(kErrInconsistentType, expr->source); + return false; + } + + if (auto* ident = expr->As()) { + // We have an identifier of a module-scope constant. + auto* var = variable_stack_.Get(ident->symbol); + if (!var || !var->Declaration()->is_const) { + AddError(kErrBadType, expr->source); + return false; + } + + auto* decl = var->Declaration(); + // Capture the constant if an [[override]] attribute is present. + if (ast::HasDecoration(decl->decorations)) { + ws[i].overridable_const = decl; + } + + expr = decl->constructor; + if (!expr) { + // No constructor means this value must be overriden by the user. + ws[i].value = 0; + continue; + } + } else if (!expr->Is()) { + AddError( + "workgroup_size argument must be either a literal or a " + "module-scope constant", + values[i]->source); + return false; + } + + auto val = expr_sem->ConstantValue(); + if (!val) { + TINT_ICE(Resolver, diagnostics_) + << "could not resolve constant workgroup_size constant value"; + continue; + } + // Validate and set the default value for this dimension. + if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) { + AddError("workgroup_size argument must be at least 1", values[i]->source); + return false; + } + + ws[i].value = is_i32 ? static_cast(val.Elements()[0].i32) + : val.Elements()[0].u32; + } return true; } @@ -2080,8 +2111,8 @@ bool Resolver::Statement(const ast::Statement* stmt) { } // Non-Compound statements - sem::Statement* sem_statement = - builder_->create(stmt, current_compound_statement_); + sem::Statement* sem_statement = builder_->create( + stmt, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem_statement); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement); if (auto* a = stmt->As()) { @@ -2100,9 +2131,6 @@ bool Resolver::Statement(const ast::Statement* stmt) { if (!Expression(c->expr)) { return false; } - if (!ValidateCallStatement(c)) { - return false; - } return true; } if (auto* c = stmt->As()) { @@ -2158,7 +2186,7 @@ bool Resolver::Statement(const ast::Statement* stmt) { bool Resolver::CaseStatement(const ast::CaseStatement* stmt) { auto* sem = builder_->create( - stmt->body, current_compound_statement_); + stmt->body, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem); builder_->Sem().Add(stmt->body, sem); Mark(stmt->body); @@ -2169,8 +2197,8 @@ bool Resolver::CaseStatement(const ast::CaseStatement* stmt) { } bool Resolver::IfStatement(const ast::IfStatement* stmt) { - auto* sem = - builder_->create(stmt, current_compound_statement_); + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem); return Scope(sem, [&] { if (!Expression(stmt->condition)) { @@ -2179,15 +2207,15 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) { auto* cond_type = TypeOf(stmt->condition)->UnwrapRef(); if (!cond_type->Is()) { - AddError("if statement condition must be bool, got " + - cond_type->FriendlyName(builder_->Symbols()), - stmt->condition->source); + AddError( + "if statement condition must be bool, got " + TypeNameOf(cond_type), + stmt->condition->source); return false; } Mark(stmt->body); auto* body = builder_->create( - stmt->body, current_compound_statement_); + stmt->body, current_compound_statement_, current_function_); builder_->Sem().Add(stmt->body, body); if (!Scope(body, [&] { return Statements(stmt->body->statements); })) { return false; @@ -2204,8 +2232,8 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) { } bool Resolver::ElseStatement(const ast::ElseStatement* stmt) { - auto* sem = - builder_->create(stmt, current_compound_statement_); + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem); return Scope(sem, [&] { if (auto* cond = stmt->condition) { @@ -2216,7 +2244,7 @@ bool Resolver::ElseStatement(const ast::ElseStatement* stmt) { auto* else_cond_type = TypeOf(cond)->UnwrapRef(); if (!else_cond_type->Is()) { AddError("else statement condition must be bool, got " + - else_cond_type->FriendlyName(builder_->Symbols()), + TypeNameOf(else_cond_type), cond->source); return false; } @@ -2224,7 +2252,7 @@ bool Resolver::ElseStatement(const ast::ElseStatement* stmt) { Mark(stmt->body); auto* body = builder_->create( - stmt->body, current_compound_statement_); + stmt->body, current_compound_statement_, current_function_); builder_->Sem().Add(stmt->body, body); return Scope(body, [&] { return Statements(stmt->body->statements); }); }); @@ -2232,20 +2260,21 @@ bool Resolver::ElseStatement(const ast::ElseStatement* stmt) { bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { auto* sem = builder_->create( - stmt->As(), current_compound_statement_); + stmt->As(), current_compound_statement_, + current_function_); builder_->Sem().Add(stmt, sem); return Scope(sem, [&] { return Statements(stmt->statements); }); } bool Resolver::LoopStatement(const ast::LoopStatement* stmt) { - auto* sem = - builder_->create(stmt, current_compound_statement_); + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem); return Scope(sem, [&] { Mark(stmt->body); auto* body = builder_->create( - stmt->body, current_compound_statement_); + stmt->body, current_compound_statement_, current_function_); builder_->Sem().Add(stmt->body, body); return Scope(body, [&] { if (!Statements(stmt->body->statements)) { @@ -2256,7 +2285,8 @@ bool Resolver::LoopStatement(const ast::LoopStatement* stmt) { if (!stmt->continuing->Empty()) { auto* continuing = builder_->create( - stmt->continuing, current_compound_statement_); + stmt->continuing, current_compound_statement_, + current_function_); builder_->Sem().Add(stmt->continuing, continuing); if (!Scope(continuing, [&] { return Statements(stmt->continuing->statements); @@ -2272,7 +2302,7 @@ bool Resolver::LoopStatement(const ast::LoopStatement* stmt) { bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { auto* sem = builder_->create( - stmt, current_compound_statement_); + stmt, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem); return Scope(sem, [&] { if (auto* initializer = stmt->initializer) { @@ -2287,10 +2317,10 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { return false; } - if (!TypeOf(condition)->UnwrapRef()->Is()) { - AddError( - "for-loop condition must be bool, got " + TypeNameOf(condition), - condition->source); + auto* cond_ty = TypeOf(condition)->UnwrapRef(); + if (!cond_ty->Is()) { + AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty), + condition->source); return false; } } @@ -2305,13 +2335,13 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { Mark(stmt->body); auto* body = builder_->create( - stmt->body, current_compound_statement_); + stmt->body, current_compound_statement_, current_function_); builder_->Sem().Add(stmt->body, body); return Scope(body, [&] { return Statements(stmt->body->statements); }); }); } -bool Resolver::Expression(const ast::Expression* root) { +sem::Expression* Resolver::Expression(const ast::Expression* root) { std::vector sorted; if (!ast::TraverseExpressions( root, diagnostics_, [&](const ast::Expression* expr) { @@ -2319,145 +2349,241 @@ bool Resolver::Expression(const ast::Expression* root) { sorted.emplace_back(expr); return ast::TraverseAction::Descend; })) { - return false; + return nullptr; } for (auto* expr : utils::Reverse(sorted)) { - bool ok = false; + sem::Expression* sem_expr = nullptr; if (auto* array = expr->As()) { - ok = ArrayAccessor(array); + sem_expr = ArrayAccessor(array); } else if (auto* bin_op = expr->As()) { - ok = Binary(bin_op); + sem_expr = Binary(bin_op); } else if (auto* bitcast = expr->As()) { - ok = Bitcast(bitcast); + sem_expr = Bitcast(bitcast); } else if (auto* call = expr->As()) { - ok = Call(call); + sem_expr = Call(call); } else if (auto* ctor = expr->As()) { - ok = Constructor(ctor); + sem_expr = Constructor(ctor); } else if (auto* ident = expr->As()) { - ok = Identifier(ident); + sem_expr = Identifier(ident); } else if (auto* member = expr->As()) { - ok = MemberAccessor(member); + sem_expr = MemberAccessor(member); } else if (auto* unary = expr->As()) { - ok = UnaryOp(unary); + sem_expr = UnaryOp(unary); } else if (expr->Is()) { - ok = true; // No-op + sem_expr = builder_->create( + expr, builder_->create(), current_statement_, + sem::Constant{}); } else { TINT_ICE(Resolver, diagnostics_) << "unhandled expression type: " << expr->TypeInfo().name; - return false; + return nullptr; } - if (!ok) { - return false; + if (!sem_expr) { + return nullptr; + } + builder_->Sem().Add(expr, sem_expr); + if (expr == root) { + return sem_expr; } } - return true; + TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node"; + return nullptr; } -bool Resolver::ArrayAccessor(const ast::ArrayAccessorExpression* expr) { +sem::Expression* Resolver::ArrayAccessor( + const ast::ArrayAccessorExpression* expr) { auto* idx = expr->index; - auto* res = TypeOf(expr->array); - auto* parent_type = res->UnwrapRef(); - const sem::Type* ret = nullptr; - if (auto* arr = parent_type->As()) { - ret = arr->ElemType(); - } else if (auto* vec = parent_type->As()) { - ret = vec->type(); - } else if (auto* mat = parent_type->As()) { - ret = builder_->create(mat->type(), mat->rows()); + auto* parent_raw_ty = TypeOf(expr->array); + auto* parent_ty = parent_raw_ty->UnwrapRef(); + const sem::Type* ty = nullptr; + if (auto* arr = parent_ty->As()) { + ty = arr->ElemType(); + } else if (auto* vec = parent_ty->As()) { + ty = vec->type(); + } else if (auto* mat = parent_ty->As()) { + ty = builder_->create(mat->type(), mat->rows()); } else { - AddError("cannot index type '" + TypeNameOf(expr->array) + "'", - expr->source); - return false; + AddError("cannot index type '" + TypeNameOf(parent_ty) + "'", expr->source); + return nullptr; } - if (!TypeOf(idx)->UnwrapRef()->IsAnyOf()) { + auto* idx_ty = TypeOf(idx)->UnwrapRef(); + if (!idx_ty->IsAnyOf()) { AddError("index must be of type 'i32' or 'u32', found: '" + - TypeNameOf(idx) + "'", + TypeNameOf(idx_ty) + "'", idx->source); - return false; + return nullptr; } - if (parent_type->Is() || parent_type->Is()) { - if (!res->Is()) { + if (parent_ty->IsAnyOf()) { + if (!parent_raw_ty->Is()) { // TODO(bclayton): expand this to allow any const_expr expression // https://github.com/gpuweb/gpuweb/issues/1272 auto* scalar = idx->As(); if (!scalar || !scalar->literal->As()) { AddError("index must be signed or unsigned integer literal", idx->source); - return false; + return nullptr; } } } // If we're extracting from a reference, we return a reference. - if (auto* ref = res->As()) { - ret = builder_->create(ret, ref->StorageClass(), - ref->Access()); + if (auto* ref = parent_raw_ty->As()) { + ty = builder_->create(ty, ref->StorageClass(), + ref->Access()); } - SetExprInfo(expr, ret); - return true; + auto val = EvaluateConstantValue(expr, ty); + return builder_->create(expr, ty, current_statement_, val); } -bool Resolver::Bitcast(const ast::BitcastExpression* expr) { +sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { auto* ty = Type(expr->type); if (!ty) { - return false; + return nullptr; } if (ty->Is()) { AddError("cannot cast to a pointer", expr->source); - return false; + return nullptr; } - SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols())); - return true; + + auto val = EvaluateConstantValue(expr, ty); + return builder_->create(expr, ty, current_statement_, val); } -bool Resolver::Call(const ast::CallExpression* call) { - Mark(call->func); - auto* ident = call->func; +sem::Expression* Resolver::Call(const ast::CallExpression* expr) { + auto* ident = expr->func; + Mark(ident); auto name = builder_->Symbols().NameFor(ident->symbol); auto intrinsic_type = sem::ParseIntrinsicType(name); - if (intrinsic_type != IntrinsicType::kNone) { - if (!IntrinsicCall(call, intrinsic_type)) { - return false; + auto* call = (intrinsic_type != IntrinsicType::kNone) + ? IntrinsicCall(expr, intrinsic_type) + : FunctionCall(expr); + + current_function_->AddDirectCall(call); + return call; +} + +sem::Call* Resolver::IntrinsicCall(const ast::CallExpression* expr, + sem::IntrinsicType intrinsic_type) { + std::vector args(expr->args.size()); + std::vector arg_tys(expr->args.size()); + for (size_t i = 0; i < expr->args.size(); i++) { + auto* arg = Sem(expr->args[i]); + if (!arg) { + return nullptr; } - } else { - if (!FunctionCall(call)) { - return false; + args[i] = arg; + arg_tys[i] = arg->Type(); + } + + auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys), + expr->source); + if (!intrinsic) { + return nullptr; + } + + if (intrinsic->IsDeprecated()) { + AddWarning("use of deprecated intrinsic", expr->source); + } + + auto* call = builder_->create(expr, intrinsic, std::move(args), + current_statement_); + + current_function_->AddDirectlyCalledIntrinsic(intrinsic); + + if (IsTextureIntrinsic(intrinsic_type) && + !ValidateTextureIntrinsicFunction(call)) { + return nullptr; + } + + if (!ValidateCall(call)) { + return nullptr; + } + + return call; +} + +sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr) { + auto* ident = expr->func; + auto name = builder_->Symbols().NameFor(ident->symbol); + + auto target_it = symbol_to_function_.find(ident->symbol); + if (target_it == symbol_to_function_.end()) { + if (current_function_ && + current_function_->Declaration()->symbol == ident->symbol) { + AddError("recursion is not permitted. '" + name + + "' attempted to call itself.", + expr->source); + } else { + AddError("unable to find called function: " + name, expr->source); + } + return nullptr; + } + auto* target = target_it->second; + + std::vector args(expr->args.size()); + for (size_t i = 0; i < expr->args.size(); i++) { + auto* arg = Sem(expr->args[i]); + if (!arg) { + return nullptr; + } + args[i] = arg; + } + + auto* call = builder_->create(expr, target, std::move(args), + current_statement_); + + if (current_function_) { + target->AddCallSite(call); + + // Note: Requires called functions to be resolved first. + // This is currently guaranteed as functions must be declared before + // use. + current_function_->AddTransitivelyCalledFunction(target); + for (auto* transitive_call : target->TransitivelyCalledFunctions()) { + current_function_->AddTransitivelyCalledFunction(transitive_call); + } + + // We inherit any referenced variables from the callee. + for (auto* var : target->TransitivelyReferencedGlobals()) { + current_function_->AddTransitivelyReferencedGlobal(var); } } - return ValidateCall(call); + if (!ValidateFunctionCall(call)) { + return nullptr; + } + + if (!ValidateCall(call)) { + return nullptr; + } + + return call; } -bool Resolver::ValidateCall(const ast::CallExpression* call) { - if (TypeOf(call)->Is()) { +bool Resolver::ValidateCall(const sem::Call* call) { + if (call->Type()->Is()) { bool is_call_statement = false; - if (current_statement_) { - if (auto* call_stmt = - As(current_statement_->Declaration())) { - if (call_stmt->expr == call) { - is_call_statement = true; - } + if (auto* call_stmt = As(call->Stmt()->Declaration())) { + if (call_stmt->expr == call->Declaration()) { + is_call_statement = true; } } if (!is_call_statement) { // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr // If the called function does not return a value, a function call // statement should be used instead. - auto* ident = call->func; + auto* ident = call->Declaration()->func; auto name = builder_->Symbols().NameFor(ident->symbol); - // A function call is made to either a user declared function or an - // intrinsic. function_calls_ only maps CallExpression to user declared - // functions - bool is_function = function_calls_.count(call) != 0; + bool is_function = call->Target()->Is(); AddError((is_function ? "function" : "intrinsic") + std::string(" '") + name + "' does not return a value", - call->source); + call->Declaration()->source); return false; } } @@ -2465,47 +2591,8 @@ bool Resolver::ValidateCall(const ast::CallExpression* call) { return true; } -bool Resolver::ValidateCallStatement(const ast::CallStatement*) { - return true; -} - -bool Resolver::IntrinsicCall(const ast::CallExpression* call, - sem::IntrinsicType intrinsic_type) { - std::vector arg_tys; - arg_tys.reserve(call->args.size()); - for (auto* expr : call->args) { - arg_tys.emplace_back(TypeOf(expr)); - } - - auto* result = - intrinsic_table_->Lookup(intrinsic_type, arg_tys, call->source); - if (!result) { - return false; - } - - if (result->IsDeprecated()) { - AddWarning("use of deprecated intrinsic", call->source); - } - - auto* out = builder_->create(call, result, current_statement_); - builder_->Sem().Add(call, out); - SetExprInfo(call, result->ReturnType()); - - current_function_->intrinsic_calls.emplace_back( - IntrinsicCallInfo{call, result}); - - if (IsTextureIntrinsic(intrinsic_type) && - !ValidateTextureIntrinsicFunction(call, out)) { - return false; - } - - return true; -} - -bool Resolver::ValidateTextureIntrinsicFunction( - const ast::CallExpression* ast_call, - const sem::Call* sem_call) { - auto* intrinsic = sem_call->Target()->As(); +bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) { + auto* intrinsic = call->Target()->As(); if (!intrinsic) { return false; } @@ -2513,146 +2600,111 @@ bool Resolver::ValidateTextureIntrinsicFunction( auto& signature = intrinsic->Signature(); auto index = signature.IndexOf(sem::ParameterUsage::kOffset); if (index > -1) { - auto* param = ast_call->args[index]; - if (param->Is()) { - auto values = ConstantValueOf(param); - if (!values.IsValid()) { - AddError( - "'" + func_name + "' offset parameter must be a const_expression", - param->source); - return false; - } + auto* arg = call->Arguments()[index]; + if (auto values = arg->ConstantValue()) { + // Assert that the constant values are of the expected type. if (!values.Type()->Is() || !values.ElementType()->Is()) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve '" + func_name + "' offset parameter type"; return false; } - for (auto offset : values.Elements()) { - auto offset_value = offset.i32; - if (offset_value < -8 || offset_value > 7) { - AddError("each offset component of '" + func_name + - "' must be at least -8 and at most 7. " - "found: '" + - std::to_string(offset_value) + "'", - param->source); - return false; + + // Currently const_expr is restricted to literals and type constructors. + // Check that that's all we have for the offset parameter. + bool is_const_expr = true; + ast::TraverseExpressions( + arg->Declaration(), diagnostics_, [&](const ast::Expression* e) { + if (e->IsAnyOf()) { + return ast::TraverseAction::Descend; + } + is_const_expr = false; + return ast::TraverseAction::Stop; + }); + if (is_const_expr) { + for (auto offset : values.Elements()) { + auto offset_value = offset.i32; + if (offset_value < -8 || offset_value > 7) { + AddError("each offset component of '" + func_name + + "' must be at least -8 and at most 7. " + "found: '" + + std::to_string(offset_value) + "'", + arg->Declaration()->source); + return false; + } } + return true; } - } else { - AddError( - "'" + func_name + "' offset parameter must be a const_expression", - param->source); - return false; } - } - return true; -} - -bool Resolver::FunctionCall(const ast::CallExpression* call) { - auto* ident = call->func; - auto name = builder_->Symbols().NameFor(ident->symbol); - - auto callee_func_it = symbol_to_function_.find(ident->symbol); - if (callee_func_it == symbol_to_function_.end()) { - if (current_function_ && - current_function_->declaration->symbol == ident->symbol) { - AddError("recursion is not permitted. '" + name + - "' attempted to call itself.", - call->source); - } else { - AddError("unable to find called function: " + name, call->source); - } - return false; - } - auto* callee_func = callee_func_it->second; - - if (current_function_) { - callee_func->callsites.push_back(call); - - // Note: Requires called functions to be resolved first. - // This is currently guaranteed as functions must be declared before - // use. - current_function_->transitive_calls.add(callee_func); - for (auto* transitive_call : callee_func->transitive_calls) { - current_function_->transitive_calls.add(transitive_call); - } - - // We inherit any referenced variables from the callee. - for (auto* var : callee_func->referenced_module_vars) { - set_referenced_from_function_if_needed(var, false); - } - } - - function_calls_.emplace(call, - FunctionCallInfo{callee_func, current_statement_}); - SetExprInfo(call, callee_func->return_type, callee_func->return_type_name); - - if (!ValidateFunctionCall(call, callee_func)) { + AddError("'" + func_name + "' offset parameter must be a const_expression", + arg->Declaration()->source); return false; } return true; } -bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, - const FunctionInfo* target) { - auto* ident = call->func; +bool Resolver::ValidateFunctionCall(const sem::Call* call) { + auto* decl = call->Declaration(); + auto* ident = decl->func; + auto* target = call->Target()->As(); auto name = builder_->Symbols().NameFor(ident->symbol); - if (target->declaration->IsEntryPoint()) { + if (target->Declaration()->IsEntryPoint()) { // https://www.w3.org/TR/WGSL/#function-restriction // An entry point must never be the target of a function call. AddError("entry point functions cannot be the target of a function call", - call->source); + decl->source); return false; } - if (call->args.size() != target->parameters.size()) { - bool more = call->args.size() > target->parameters.size(); + if (decl->args.size() != target->Parameters().size()) { + bool more = decl->args.size() > target->Parameters().size(); AddError("too " + (more ? std::string("many") : std::string("few")) + " arguments in call to '" + name + "', expected " + - std::to_string(target->parameters.size()) + ", got " + - std::to_string(call->args.size()), - call->source); + std::to_string(target->Parameters().size()) + ", got " + + std::to_string(call->Arguments().size()), + decl->source); return false; } - for (size_t i = 0; i < call->args.size(); ++i) { - const VariableInfo* param = target->parameters[i]; - const ast::Expression* arg_expr = call->args[i]; + for (size_t i = 0; i < call->Arguments().size(); ++i) { + const sem::Variable* param = target->Parameters()[i]; + const ast::Expression* arg_expr = decl->args[i]; + auto* param_type = param->Type(); auto* arg_type = TypeOf(arg_expr)->UnwrapRef(); - if (param->type != arg_type) { + if (param_type != arg_type) { AddError("type mismatch for argument " + std::to_string(i + 1) + " in call to '" + name + "', expected '" + - param->type->FriendlyName(builder_->Symbols()) + "', got '" + - arg_type->FriendlyName(builder_->Symbols()) + "'", + TypeNameOf(param_type) + "', got '" + TypeNameOf(arg_type) + + "'", arg_expr->source); return false; } - if (param->declaration->type->Is()) { + if (param_type->Is()) { auto is_valid = false; if (auto* ident_expr = arg_expr->As()) { - VariableInfo* var = variable_stack_.Get(ident_expr->symbol); + auto* var = variable_stack_.Get(ident_expr->symbol); if (!var) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; return false; } - if (var->kind == VariableKind::kParameter) { + if (var->Is()) { is_valid = true; } } else if (auto* unary = arg_expr->As()) { if (unary->op == ast::UnaryOp::kAddressOf) { if (auto* ident_unary = unary->expr->As()) { - VariableInfo* var = variable_stack_.Get(ident_unary->symbol); + auto* var = variable_stack_.Get(ident_unary->symbol); if (!var) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; return false; } - if (var->declaration->is_const) { + if (var->Declaration()->is_const) { TINT_ICE(Resolver, diagnostics_) << "Resolver::FunctionCall() encountered an address-of " "expression of a constant identifier expression"; @@ -2665,7 +2717,7 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, if (!is_valid && IsValidationEnabled( - param->declaration->decorations, + param->Declaration()->decorations, ast::DisabledValidation::kIgnoreInvalidPointerArgument)) { AddError( "expected an address-of expression of a variable identifier " @@ -2678,52 +2730,52 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, return true; } -bool Resolver::Constructor(const ast::ConstructorExpression* expr) { +sem::Expression* Resolver::Constructor(const ast::ConstructorExpression* expr) { if (auto* type_ctor = expr->As()) { - auto* type = Type(type_ctor->type); - if (!type) { - return false; + auto* ty = Type(type_ctor->type); + if (!ty) { + return nullptr; } - auto type_name = type_ctor->type->FriendlyName(builder_->Symbols()); - // Now that the argument types have been determined, make sure that they // obey the constructor type rules laid out in // https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr. bool ok = true; - if (auto* vec_type = type->As()) { - ok = ValidateVectorConstructor(type_ctor, vec_type, type_name); - } else if (auto* mat_type = type->As()) { - ok = ValidateMatrixConstructor(type_ctor, mat_type, type_name); - } else if (type->is_scalar()) { - ok = ValidateScalarConstructor(type_ctor, type, type_name); - } else if (auto* arr_type = type->As()) { + if (auto* vec_type = ty->As()) { + ok = ValidateVectorConstructor(type_ctor, vec_type); + } else if (auto* mat_type = ty->As()) { + ok = ValidateMatrixConstructor(type_ctor, mat_type); + } else if (ty->is_scalar()) { + ok = ValidateScalarConstructor(type_ctor, ty); + } else if (auto* arr_type = ty->As()) { ok = ValidateArrayConstructor(type_ctor, arr_type); - } else if (auto* struct_type = type->As()) { + } else if (auto* struct_type = ty->As()) { ok = ValidateStructureConstructor(type_ctor, struct_type); } else { AddError("type is not constructible", type_ctor->source); - return false; + return nullptr; } if (!ok) { - return false; + return nullptr; } - SetExprInfo(expr, type, type_name); - return true; + + auto val = EvaluateConstantValue(expr, ty); + return builder_->create(expr, ty, current_statement_, val); } if (auto* scalar_ctor = expr->As()) { Mark(scalar_ctor->literal); - auto* type = TypeOf(scalar_ctor->literal); - if (!type) { - return false; + auto* ty = TypeOf(scalar_ctor->literal); + if (!ty) { + return nullptr; } - SetExprInfo(expr, type); - return true; + + auto val = EvaluateConstantValue(expr, ty); + return builder_->create(expr, ty, current_statement_, val); } TINT_ICE(Resolver, diagnostics_) << "unexpected constructor expression type"; - return false; + return nullptr; } bool Resolver::ValidateStructureConstructor( @@ -2746,12 +2798,13 @@ bool Resolver::ValidateStructureConstructor( } for (auto* member : struct_type->Members()) { auto* value = ctor->values[member->Index()]; - if (member->Type() != TypeOf(value)->UnwrapRef()) { + auto* value_ty = TypeOf(value); + if (member->Type() != value_ty->UnwrapRef()) { AddError( "type in struct constructor does not match struct member type: " "expected '" + - member->Type()->FriendlyName(builder_->Symbols()) + - "', found '" + TypeNameOf(value) + "'", + TypeNameOf(member->Type()) + "', found '" + + TypeNameOf(value_ty) + "'", value->source); return false; } @@ -2764,15 +2817,14 @@ bool Resolver::ValidateArrayConstructor( const ast::TypeConstructorExpression* ctor, const sem::Array* array_type) { auto& values = ctor->values; - auto* elem_type = array_type->ElemType(); + auto* elem_ty = array_type->ElemType(); for (auto* value : values) { - auto* value_type = TypeOf(value)->UnwrapRef(); - if (value_type != elem_type) { + auto* value_ty = TypeOf(value)->UnwrapRef(); + if (value_ty != elem_ty) { AddError( "type in array constructor does not match array type: " "expected '" + - elem_type->FriendlyName(builder_->Symbols()) + "', found '" + - TypeNameOf(value) + "'", + TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'", value->source); return false; } @@ -2781,7 +2833,7 @@ bool Resolver::ValidateArrayConstructor( if (array_type->IsRuntimeSized()) { AddError("cannot init a runtime-sized array", ctor->source); return false; - } else if (!elem_type->IsConstructible()) { + } else if (!elem_ty->IsConstructible()) { AddError("array constructor has non-constructible element type", ctor->type->As()->type->source); return false; @@ -2804,36 +2856,34 @@ bool Resolver::ValidateArrayConstructor( bool Resolver::ValidateVectorConstructor( const ast::TypeConstructorExpression* ctor, - const sem::Vector* vec_type, - const std::string& type_name) { + const sem::Vector* vec_type) { auto& values = ctor->values; - auto* elem_type = vec_type->type(); + auto* elem_ty = vec_type->type(); size_t value_cardinality_sum = 0; for (auto* value : values) { - auto* value_type = TypeOf(value)->UnwrapRef(); - if (value_type->is_scalar()) { - if (elem_type != value_type) { + auto* value_ty = TypeOf(value)->UnwrapRef(); + if (value_ty->is_scalar()) { + if (elem_ty != value_ty) { AddError( "type in vector constructor does not match vector type: " "expected '" + - elem_type->FriendlyName(builder_->Symbols()) + "', found '" + - TypeNameOf(value) + "'", + TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'", value->source); return false; } value_cardinality_sum++; - } else if (auto* value_vec = value_type->As()) { - auto* value_elem_type = value_vec->type(); + } else if (auto* value_vec = value_ty->As()) { + auto* value_elem_ty = value_vec->type(); // A mismatch of vector type parameter T is only an error if multiple // arguments are present. A single argument constructor constitutes a // type conversion expression. - if (elem_type != value_elem_type && values.size() > 1u) { + if (elem_ty != value_elem_ty && values.size() > 1u) { AddError( "type in vector constructor does not match vector type: " "expected '" + - elem_type->FriendlyName(builder_->Symbols()) + "', found '" + - value_elem_type->FriendlyName(builder_->Symbols()) + "'", + TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_elem_ty) + + "'", value->source); return false; } @@ -2842,7 +2892,7 @@ bool Resolver::ValidateVectorConstructor( } else { // A vector constructor can only accept vectors and scalars. AddError("expected vector or scalar type in vector constructor; found: " + - value_type->FriendlyName(builder_->Symbols()), + TypeNameOf(value_ty), value->source); return false; } @@ -2858,7 +2908,7 @@ bool Resolver::ValidateVectorConstructor( } const Source& values_start = values[0]->source; const Source& values_end = values[values.size() - 1]->source; - AddError("attempted to construct '" + type_name + "' with " + + AddError("attempted to construct '" + TypeNameOf(vec_type) + "' with " + std::to_string(value_cardinality_sum) + " component(s)", Source::Combine(values_start, values_end)); return false; @@ -2885,27 +2935,27 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) { bool Resolver::ValidateMatrixConstructor( const ast::TypeConstructorExpression* ctor, - const sem::Matrix* matrix_type, - const std::string& type_name) { + const sem::Matrix* matrix_ty) { auto& values = ctor->values; // Zero Value expression if (values.empty()) { return true; } - if (!ValidateMatrix(matrix_type, ctor->source)) { + if (!ValidateMatrix(matrix_ty, ctor->source)) { return false; } - auto* elem_type = matrix_type->type(); - auto num_elements = matrix_type->columns() * matrix_type->rows(); + auto* elem_type = matrix_ty->type(); + auto num_elements = matrix_ty->columns() * matrix_ty->rows(); // Print a generic error for an invalid matrix constructor, showing the // available overloads. auto print_error = [&]() { const Source& values_start = values[0]->source; const Source& values_end = values[values.size() - 1]->source; - auto elem_type_name = elem_type->FriendlyName(builder_->Symbols()); + auto type_name = TypeNameOf(matrix_ty); + auto elem_type_name = TypeNameOf(elem_type); std::stringstream ss; ss << "invalid constructor for " + type_name << std::endl << std::endl; ss << "3 candidates available:" << std::endl; @@ -2914,11 +2964,11 @@ bool Resolver::ValidateMatrixConstructor( << elem_type_name << ")" << " // " << std::to_string(num_elements) << " arguments" << std::endl; ss << " " << type_name << "("; - for (uint32_t c = 0; c < matrix_type->columns(); c++) { + for (uint32_t c = 0; c < matrix_ty->columns(); c++) { if (c > 0) { ss << ", "; } - ss << VectorPretty(matrix_type->rows(), elem_type); + ss << VectorPretty(matrix_ty->rows(), elem_type); } ss << ")" << std::endl; AddError(ss.str(), Source::Combine(values_start, values_end)); @@ -2927,10 +2977,10 @@ bool Resolver::ValidateMatrixConstructor( const sem::Type* expected_arg_type = nullptr; if (num_elements == values.size()) { // Column-major construction from scalar elements. - expected_arg_type = matrix_type->type(); - } else if (matrix_type->columns() == values.size()) { + expected_arg_type = matrix_ty->type(); + } else if (matrix_ty->columns() == values.size()) { // Column-by-column construction from vectors. - expected_arg_type = matrix_type->ColumnType(); + expected_arg_type = matrix_ty->ColumnType(); } else { print_error(); return false; @@ -2948,8 +2998,7 @@ bool Resolver::ValidateMatrixConstructor( bool Resolver::ValidateScalarConstructor( const ast::TypeConstructorExpression* ctor, - const sem::Type* type, - const std::string& type_name) { + const sem::Type* ty) { if (ctor->values.size() == 0) { return true; } @@ -2962,20 +3011,20 @@ bool Resolver::ValidateScalarConstructor( // Validate constructor auto* value = ctor->values[0]; - auto* value_type = TypeOf(value)->UnwrapRef(); + auto* value_ty = TypeOf(value)->UnwrapRef(); using Bool = sem::Bool; using I32 = sem::I32; using U32 = sem::U32; using F32 = sem::F32; - const bool is_valid = (type->Is() && value_type->is_scalar()) || - (type->Is() && value_type->is_scalar()) || - (type->Is() && value_type->is_scalar()) || - (type->Is() && value_type->is_scalar()); + const bool is_valid = (ty->Is() && value_ty->is_scalar()) || + (ty->Is() && value_ty->is_scalar()) || + (ty->Is() && value_ty->is_scalar()) || + (ty->Is() && value_ty->is_scalar()); if (!is_valid) { - AddError("cannot construct '" + type_name + "' with a value of type '" + - TypeNameOf(value) + "'", + AddError("cannot construct '" + TypeNameOf(ty) + + "' with a value of type '" + TypeNameOf(value_ty) + "'", ctor->source); return false; @@ -2984,13 +3033,11 @@ bool Resolver::ValidateScalarConstructor( return true; } -bool Resolver::Identifier(const ast::IdentifierExpression* expr) { +sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { auto symbol = expr->symbol; - if (VariableInfo* var = variable_stack_.Get(symbol)) { - SetExprInfo(expr, var->type, var->type_name); - - var->users.push_back(expr); - set_referenced_from_function_if_needed(var, true); + if (auto* var = variable_stack_.Get(symbol)) { + auto* user = + builder_->create(expr, current_statement_, var); if (current_statement_) { // If identifier is part of a loop continuing block, make sure it @@ -3021,40 +3068,47 @@ bool Resolver::Identifier(const ast::IdentifierExpression* expr) { AddNote("identifier '" + builder_->Symbols().NameFor(symbol) + "' referenced in continuing block here", expr->source); - return false; + return nullptr; } } } } } - return true; + if (current_function_) { + if (auto* global = var->As()) { + current_function_->AddDirectlyReferencedGlobal(global); + } + } + + var->AddUser(user); + return user; } - auto iter = symbol_to_function_.find(symbol); - if (iter != symbol_to_function_.end()) { + if (symbol_to_function_.count(symbol)) { AddError("missing '(' for function call", expr->source.End()); - return false; + return nullptr; } std::string name = builder_->Symbols().NameFor(symbol); if (sem::ParseIntrinsicType(name) != IntrinsicType::kNone) { AddError("missing '(' for intrinsic call", expr->source.End()); - return false; + return nullptr; } AddError("identifier must be declared before use: " + name, expr->source); - return false; + return nullptr; } -bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) { +sem::Expression* Resolver::MemberAccessor( + const ast::MemberAccessorExpression* expr) { auto* structure = TypeOf(expr->structure); - auto* storage_type = structure->UnwrapRef(); + auto* storage_ty = structure->UnwrapRef(); const sem::Type* ret = nullptr; std::vector swizzle; - if (auto* str = storage_type->As()) { + if (auto* str = storage_ty->As()) { Mark(expr->member); auto symbol = expr->member->symbol; @@ -3071,7 +3125,7 @@ bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) { AddError( "struct member " + builder_->Symbols().NameFor(symbol) + " not found", expr->source); - return false; + return nullptr; } // If we're extracting from a reference, we return a reference. @@ -3080,9 +3134,11 @@ bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) { ref->Access()); } - builder_->Sem().Add(expr, builder_->create( - expr, ret, current_statement_, member)); - } else if (auto* vec = storage_type->As()) { + return builder_->create( + expr, ret, current_statement_, member); + } + + if (auto* vec = storage_ty->As()) { Mark(expr->member); std::string s = builder_->Symbols().NameFor(expr->member->symbol); auto size = s.size(); @@ -3109,18 +3165,18 @@ bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) { default: AddError("invalid vector swizzle character", expr->member->source.Begin() + swizzle.size()); - return false; + return nullptr; } if (swizzle.back() >= vec->Width()) { AddError("invalid vector swizzle member", expr->member->source); - return false; + return nullptr; } } if (size < 1 || size > 4) { AddError("invalid vector swizzle size", expr->member->source); - return false; + return nullptr; } // All characters are valid, check if they're being mixed @@ -3134,7 +3190,7 @@ bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) { !std::all_of(s.begin(), s.end(), is_xyzw)) { AddError("invalid mixing of vector swizzle characters rgba with xyzw", expr->member->source); - return false; + return nullptr; } if (size == 1) { @@ -3151,23 +3207,18 @@ bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) { ret = builder_->create(vec->type(), static_cast(size)); } - builder_->Sem().Add( - expr, builder_->create(expr, ret, current_statement_, - std::move(swizzle))); - } else { - AddError( - "invalid member accessor expression. Expected vector or struct, got '" + - TypeNameOf(expr->structure) + "'", - expr->structure->source); - return false; + return builder_->create(expr, ret, current_statement_, + std::move(swizzle)); } - SetExprInfo(expr, ret); - - return true; + AddError( + "invalid member accessor expression. Expected vector or struct, got '" + + TypeNameOf(storage_ty) + "'", + expr->structure->source); + return nullptr; } -bool Resolver::Binary(const ast::BinaryExpression* expr) { +sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { using Bool = sem::Bool; using F32 = sem::F32; using I32 = sem::I32; @@ -3175,12 +3226,12 @@ bool Resolver::Binary(const ast::BinaryExpression* expr) { using Matrix = sem::Matrix; using Vector = sem::Vector; - auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef(); - auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef(); + auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef(); + auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef(); - auto* lhs_vec = lhs_type->As(); + auto* lhs_vec = lhs_ty->As(); auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; - auto* rhs_vec = rhs_type->As(); + auto* rhs_vec = rhs_ty->As(); auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr; const bool matching_vec_elem_types = @@ -3188,70 +3239,66 @@ bool Resolver::Binary(const ast::BinaryExpression* expr) { (lhs_vec_elem_type == rhs_vec_elem_type) && (lhs_vec->Width() == rhs_vec->Width()); - const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type); + const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty); + + auto build = [&](const sem::Type* ty) { + auto val = EvaluateConstantValue(expr, ty); + return builder_->create(expr, ty, current_statement_, val); + }; // Binary logical expressions if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { - if (matching_types && lhs_type->Is()) { - SetExprInfo(expr, lhs_type); - return true; + if (matching_types && lhs_ty->Is()) { + return build(lhs_ty); } } if (expr->IsOr() || expr->IsAnd()) { - if (matching_types && lhs_type->Is()) { - SetExprInfo(expr, lhs_type); - return true; + if (matching_types && lhs_ty->Is()) { + return build(lhs_ty); } if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is()) { - SetExprInfo(expr, lhs_type); - return true; + return build(lhs_ty); } } // Arithmetic expressions if (expr->IsArithmetic()) { // Binary arithmetic expressions over scalars - if (matching_types && lhs_type->is_numeric_scalar()) { - SetExprInfo(expr, lhs_type); - return true; + if (matching_types && lhs_ty->is_numeric_scalar()) { + return build(lhs_ty); } // Binary arithmetic expressions over vectors if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->is_numeric_scalar()) { - SetExprInfo(expr, lhs_type); - return true; + return build(lhs_ty); } // Binary arithmetic expressions with mixed scalar and vector operands - if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_type)) { + if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty)) { if (expr->IsModulo()) { - if (rhs_type->is_integer_scalar()) { - SetExprInfo(expr, lhs_type); - return true; + if (rhs_ty->is_integer_scalar()) { + return build(lhs_ty); } - } else if (rhs_type->is_numeric_scalar()) { - SetExprInfo(expr, lhs_type); - return true; + } else if (rhs_ty->is_numeric_scalar()) { + return build(lhs_ty); } } - if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_type)) { + if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty)) { if (expr->IsModulo()) { - if (lhs_type->is_integer_scalar()) { - SetExprInfo(expr, rhs_type); - return true; + if (lhs_ty->is_integer_scalar()) { + return build(rhs_ty); } - } else if (lhs_type->is_numeric_scalar()) { - SetExprInfo(expr, rhs_type); - return true; + } else if (lhs_ty->is_numeric_scalar()) { + return build(rhs_ty); } } } // Matrix arithmetic - auto* lhs_mat = lhs_type->As(); + auto* lhs_mat = lhs_ty->As(); auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr; - auto* rhs_mat = rhs_type->As(); + auto* rhs_mat = rhs_ty->As(); auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; // Addition and subtraction of float matrices if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type && @@ -3259,49 +3306,42 @@ bool Resolver::Binary(const ast::BinaryExpression* expr) { rhs_mat_elem_type->Is() && (lhs_mat->columns() == rhs_mat->columns()) && (lhs_mat->rows() == rhs_mat->rows())) { - SetExprInfo(expr, rhs_type); - return true; + return build(rhs_ty); } if (expr->IsMultiply()) { // Multiplication of a matrix and a scalar - if (lhs_type->Is() && rhs_mat_elem_type && + if (lhs_ty->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is()) { - SetExprInfo(expr, rhs_type); - return true; + return build(rhs_ty); } if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && - rhs_type->Is()) { - SetExprInfo(expr, lhs_type); - return true; + rhs_ty->Is()) { + return build(lhs_ty); } // Vector times matrix if (lhs_vec_elem_type && lhs_vec_elem_type->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is() && (lhs_vec->Width() == rhs_mat->rows())) { - SetExprInfo(expr, builder_->create(lhs_vec->type(), - rhs_mat->columns())); - return true; + return build( + builder_->create(lhs_vec->type(), rhs_mat->columns())); } // Matrix times vector if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && rhs_vec_elem_type && rhs_vec_elem_type->Is() && (lhs_mat->columns() == rhs_vec->Width())) { - SetExprInfo(expr, builder_->create(rhs_vec->type(), - lhs_mat->rows())); - return true; + return build( + builder_->create(rhs_vec->type(), lhs_mat->rows())); } // Matrix times matrix if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is() && (lhs_mat->columns() == rhs_mat->rows())) { - SetExprInfo(expr, builder_->create( - builder_->create(lhs_mat_elem_type, - lhs_mat->rows()), - rhs_mat->columns())); - return true; + return build(builder_->create( + builder_->create(lhs_mat_elem_type, lhs_mat->rows()), + rhs_mat->columns())); } } @@ -3309,15 +3349,13 @@ bool Resolver::Binary(const ast::BinaryExpression* expr) { if (expr->IsComparison()) { if (matching_types) { // Special case for bools: only == and != - if (lhs_type->Is() && (expr->IsEqual() || expr->IsNotEqual())) { - SetExprInfo(expr, builder_->create()); - return true; + if (lhs_ty->Is() && (expr->IsEqual() || expr->IsNotEqual())) { + return build(builder_->create()); } // For the rest, we can compare i32, u32, and f32 - if (lhs_type->IsAnyOf()) { - SetExprInfo(expr, builder_->create()); - return true; + if (lhs_ty->IsAnyOf()) { + return build(builder_->create()); } } @@ -3325,24 +3363,21 @@ bool Resolver::Binary(const ast::BinaryExpression* expr) { if (matching_vec_elem_types) { if (lhs_vec_elem_type->Is() && (expr->IsEqual() || expr->IsNotEqual())) { - SetExprInfo(expr, builder_->create( - builder_->create(), lhs_vec->Width())); - return true; + return build(builder_->create( + builder_->create(), lhs_vec->Width())); } if (lhs_vec_elem_type->is_numeric_scalar()) { - SetExprInfo(expr, builder_->create( - builder_->create(), lhs_vec->Width())); - return true; + return build(builder_->create( + builder_->create(), lhs_vec->Width())); } } } // Binary bitwise operations if (expr->IsBitwise()) { - if (matching_types && lhs_type->is_integer_scalar_or_vector()) { - SetExprInfo(expr, lhs_type); - return true; + if (matching_types && lhs_ty->is_integer_scalar_or_vector()) { + return build(lhs_ty); } } @@ -3352,79 +3387,72 @@ bool Resolver::Binary(const ast::BinaryExpression* expr) { // differences in computation rules (i.e. right shift can be arithmetic or // logical depending on lhs type). - if (lhs_type->IsAnyOf() && rhs_type->Is()) { - SetExprInfo(expr, lhs_type); - return true; + if (lhs_ty->IsAnyOf() && rhs_ty->Is()) { + return build(lhs_ty); } if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf() && rhs_vec_elem_type && rhs_vec_elem_type->Is()) { - SetExprInfo(expr, lhs_type); - return true; + return build(lhs_ty); } } AddError("Binary expression operand types are invalid for this operation: " + - lhs_type->FriendlyName(builder_->Symbols()) + " " + - FriendlyName(expr->op) + " " + - rhs_type->FriendlyName(builder_->Symbols()), + TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + + TypeNameOf(rhs_ty), expr->source); - return false; + return nullptr; } -bool Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { - auto* expr_type = TypeOf(unary->expr); - if (!expr_type) { - return false; +sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { + auto* expr_ty = TypeOf(unary->expr); + if (!expr_ty) { + return nullptr; } - std::string type_name; - const sem::Type* type = nullptr; + const sem::Type* ty = nullptr; switch (unary->op) { case ast::UnaryOp::kNot: // Result type matches the deref'd inner type. - type_name = TypeNameOf(unary->expr); - type = expr_type->UnwrapRef(); - if (!type->Is() && !type->is_bool_vector()) { - AddError("cannot logical negate expression of type '" + - TypeNameOf(unary->expr), - unary->expr->source); - return false; + ty = expr_ty->UnwrapRef(); + if (!ty->Is() && !ty->is_bool_vector()) { + AddError( + "cannot logical negate expression of type '" + TypeNameOf(expr_ty), + unary->expr->source); + return nullptr; } break; case ast::UnaryOp::kComplement: // Result type matches the deref'd inner type. - type_name = TypeNameOf(unary->expr); - type = expr_type->UnwrapRef(); - if (!type->is_integer_scalar_or_vector()) { + ty = expr_ty->UnwrapRef(); + if (!ty->is_integer_scalar_or_vector()) { AddError("cannot bitwise complement expression of type '" + - TypeNameOf(unary->expr), + TypeNameOf(expr_ty), unary->expr->source); - return false; + return nullptr; } break; case ast::UnaryOp::kNegation: // Result type matches the deref'd inner type. - type_name = TypeNameOf(unary->expr); - type = expr_type->UnwrapRef(); - if (!(type->IsAnyOf() || - type->is_signed_integer_vector() || type->is_float_vector())) { - AddError("cannot negate expression of type '" + TypeNameOf(unary->expr), + ty = expr_ty->UnwrapRef(); + if (!(ty->IsAnyOf() || + ty->is_signed_integer_vector() || ty->is_float_vector())) { + AddError("cannot negate expression of type '" + TypeNameOf(expr_ty), unary->expr->source); - return false; + return nullptr; } break; case ast::UnaryOp::kAddressOf: - if (auto* ref = expr_type->As()) { + if (auto* ref = expr_ty->As()) { if (ref->StoreType()->UnwrapRef()->is_handle()) { AddError( "cannot take the address of expression in handle storage class", unary->expr->source); - return false; + return nullptr; } auto* array = unary->expr->As(); @@ -3434,48 +3462,48 @@ bool Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { TypeOf(member->structure)->UnwrapRef()->Is())) { AddError("cannot take the address of a vector component", unary->expr->source); - return false; + return nullptr; } - type = builder_->create( - ref->StoreType(), ref->StorageClass(), ref->Access()); + ty = builder_->create(ref->StoreType(), + ref->StorageClass(), ref->Access()); } else { AddError("cannot take the address of expression", unary->expr->source); - return false; + return nullptr; } break; case ast::UnaryOp::kIndirection: - if (auto* ptr = expr_type->As()) { - type = builder_->create( + if (auto* ptr = expr_ty->As()) { + ty = builder_->create( ptr->StoreType(), ptr->StorageClass(), ptr->Access()); } else { AddError("cannot dereference expression of type '" + - TypeNameOf(unary->expr) + "'", + TypeNameOf(expr_ty) + "'", unary->expr->source); - return false; + return nullptr; } break; } - SetExprInfo(unary, type); - return true; + auto val = EvaluateConstantValue(unary, ty); + return builder_->create(unary, ty, current_statement_, val); } bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { - const ast::Variable* var = stmt->variable; - Mark(var); + Mark(stmt->variable); - if (!ValidateNoDuplicateDefinition(var->symbol, var->source)) { + if (!ValidateNoDuplicateDefinition(stmt->variable->symbol, + stmt->variable->source)) { return false; } - auto* info = Variable(var, VariableKind::kLocal); - if (!info) { + auto* var = Variable(stmt->variable, VariableKind::kLocal); + if (!var) { return false; } - for (auto* deco : var->decorations) { + for (auto* deco : stmt->variable->decorations) { Mark(deco); if (!deco->Is()) { AddError("decorations are not valid on local variables", deco->source); @@ -3483,38 +3511,12 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { } } - variable_stack_.Set(var->symbol, info); + variable_stack_.Set(stmt->variable->symbol, var); if (current_block_) { // Not all statements are inside a block - current_block_->AddDecl(var); + current_block_->AddDecl(stmt->variable); } - if (!ValidateVariable(info)) { - return false; - } - - if (!var->is_const && - IsValidationEnabled(var->decorations, - ast::DisabledValidation::kIgnoreStorageClass)) { - if (!info->type->UnwrapRef()->IsConstructible()) { - AddError("function variable must have a constructible type", - var->type ? var->type->source : var->source); - return false; - } - if (info->storage_class != ast::StorageClass::kFunction) { - if (info->storage_class != ast::StorageClass::kNone) { - AddError("function variable has a non-function storage class", - stmt->source); - return false; - } - info->storage_class = ast::StorageClass::kFunction; - } - } - - if (!ApplyStorageClassUsageToType(info->storage_class, info->type, - var->source)) { - AddNote("while instantiating variable " + - builder_->Symbols().NameFor(var->symbol), - var->source); + if (!ValidateVariable(var)) { return false; } @@ -3563,19 +3565,16 @@ bool Resolver::ValidateTypeDecl(const ast::TypeDecl* named_type) const { } sem::Type* Resolver::TypeOf(const ast::Expression* expr) { - auto it = expr_info_.find(expr); - if (it != expr_info_.end()) { - return const_cast(it->second.type); - } - return nullptr; + auto* sem = Sem(expr); + return sem ? const_cast(sem->Type()) : nullptr; } -std::string Resolver::TypeNameOf(const ast::Expression* expr) { - auto it = expr_info_.find(expr); - if (it != expr_info_.end()) { - return it->second.type_name; - } - return ""; +std::string Resolver::TypeNameOf(const sem::Type* ty) { + return RawTypeNameOf(ty->UnwrapRef()); +} + +std::string Resolver::RawTypeNameOf(const sem::Type* ty) { + return ty->FriendlyName(builder_->Symbols()); } sem::Type* Resolver::TypeOf(const ast::Literal* lit) { @@ -3596,56 +3595,37 @@ sem::Type* Resolver::TypeOf(const ast::Literal* lit) { return nullptr; } -void Resolver::SetExprInfo(const ast::Expression* expr, - const sem::Type* type, - std::string type_name) { - if (expr_info_.count(expr)) { - TINT_ICE(Resolver, diagnostics_) - << "SetExprInfo() called twice for the same expression"; - } - if (type_name.empty()) { - type_name = type->FriendlyName(builder_->Symbols()); - } - auto constant_value = EvaluateConstantValue(expr, type); - expr_info_.emplace( - expr, ExpressionInfo{type, std::move(type_name), current_statement_, - std::move(constant_value)}); -} - bool Resolver::ValidatePipelineStages() { - auto check_workgroup_storage = [&](FunctionInfo* func, - FunctionInfo* entry_point) { - auto stage = entry_point->declaration->PipelineStage(); + auto check_workgroup_storage = [&](const sem::Function* func, + const sem::Function* entry_point) { + auto stage = entry_point->Declaration()->PipelineStage(); if (stage != ast::PipelineStage::kCompute) { - for (auto* var : func->local_referenced_module_vars) { - if (var->storage_class == ast::StorageClass::kWorkgroup) { + for (auto* var : func->DirectlyReferencedGlobals()) { + if (var->StorageClass() == ast::StorageClass::kWorkgroup) { std::stringstream stage_name; stage_name << stage; - for (auto* user : var->users) { - auto it = expr_info_.find(user->As()); - if (it != expr_info_.end()) { - if (func->declaration->symbol == - it->second.statement->Function()->symbol) { - AddError("workgroup memory cannot be used by " + - stage_name.str() + " pipeline stage", - user->source); - break; - } + for (auto* user : var->Users()) { + if (func == user->Stmt()->Function()) { + AddError("workgroup memory cannot be used by " + + stage_name.str() + " pipeline stage", + user->Declaration()->source); + break; } } - AddNote("variable is declared here", var->declaration->source); + AddNote("variable is declared here", var->Declaration()->source); if (func != entry_point) { - TraverseCallChain(entry_point, func, [&](FunctionInfo* f) { - AddNote("called by function '" + - builder_->Symbols().NameFor(f->declaration->symbol) + - "'", - f->declaration->source); + TraverseCallChain(entry_point, func, [&](const sem::Function* f) { + AddNote( + "called by function '" + + builder_->Symbols().NameFor(f->Declaration()->symbol) + + "'", + f->Declaration()->source); }); AddNote("called by entry point '" + builder_->Symbols().NameFor( - entry_point->declaration->symbol) + + entry_point->Declaration()->symbol) + "'", - entry_point->declaration->source); + entry_point->Declaration()->source); } return false; } @@ -3658,33 +3638,35 @@ bool Resolver::ValidatePipelineStages() { if (!check_workgroup_storage(entry_point, entry_point)) { return false; } - for (auto* func : entry_point->transitive_calls) { + for (auto* func : entry_point->TransitivelyCalledFunctions()) { if (!check_workgroup_storage(func, entry_point)) { return false; } } } - auto check_intrinsic_calls = [&](FunctionInfo* func, - FunctionInfo* entry_point) { - auto stage = entry_point->declaration->PipelineStage(); - for (auto& call : func->intrinsic_calls) { - if (!call.intrinsic->SupportedStages().Contains(stage)) { + auto check_intrinsic_calls = [&](const sem::Function* func, + const sem::Function* entry_point) { + auto stage = entry_point->Declaration()->PipelineStage(); + for (auto* intrinsic : func->DirectlyCalledIntrinsics()) { + if (!intrinsic->SupportedStages().Contains(stage)) { + auto* call = func->FindDirectCallTo(intrinsic); std::stringstream err; err << "built-in cannot be used by " << stage << " pipeline stage"; - AddError(err.str(), call.call->source); + AddError(err.str(), call ? call->Declaration()->source + : func->Declaration()->source); if (func != entry_point) { - TraverseCallChain(entry_point, func, [&](FunctionInfo* f) { + TraverseCallChain(entry_point, func, [&](const sem::Function* f) { AddNote("called by function '" + - builder_->Symbols().NameFor(f->declaration->symbol) + + builder_->Symbols().NameFor(f->Declaration()->symbol) + "'", - f->declaration->source); + f->Declaration()->source); }); AddNote("called by entry point '" + builder_->Symbols().NameFor( - entry_point->declaration->symbol) + + entry_point->Declaration()->symbol) + "'", - entry_point->declaration->source); + entry_point->Declaration()->source); } return false; } @@ -3696,7 +3678,7 @@ bool Resolver::ValidatePipelineStages() { if (!check_intrinsic_calls(entry_point, entry_point)) { return false; } - for (auto* func : entry_point->transitive_calls) { + for (auto* func : entry_point->TransitivelyCalledFunctions()) { if (!check_intrinsic_calls(func, entry_point)) { return false; } @@ -3706,15 +3688,15 @@ bool Resolver::ValidatePipelineStages() { } template -void Resolver::TraverseCallChain(FunctionInfo* from, - FunctionInfo* to, +void Resolver::TraverseCallChain(const sem::Function* from, + const sem::Function* to, CALLBACK&& callback) const { - for (auto* f : from->transitive_calls) { + for (auto* f : from->TransitivelyCalledFunctions()) { if (f == to) { callback(f); return; } - if (f->transitive_calls.contains(to)) { + if (f->TransitivelyCalledFunctions().contains(to)) { TraverseCallChain(f, to, callback); callback(f); return; @@ -3724,127 +3706,6 @@ void Resolver::TraverseCallChain(FunctionInfo* from, << "TraverseCallChain() 'from' does not transitively call 'to'"; } -void Resolver::CreateSemanticNodes() const { - auto& sem = builder_->Sem(); - - // Collate all the 'ancestor_entry_points' - this is a map of function - // symbol to all the entry points that transitively call the function. - std::unordered_map> ancestor_entry_points; - for (auto* entry_point : entry_points_) { - for (auto* call : entry_point->transitive_calls) { - auto& vec = ancestor_entry_points[call->declaration->symbol]; - vec.emplace_back(entry_point->declaration->symbol); - } - } - - // Create semantic nodes for all ast::Variables - std::unordered_map sem_params; - for (auto it : variable_to_info_) { - auto* var = it.first; - auto* info = it.second; - - sem::Variable* sem_var = nullptr; - - if (ast::HasDecoration(var->decorations)) { - // Create a pipeline overridable constant. - sem_var = builder_->create(var, info->type, - info->constant_id); - } else { - switch (info->kind) { - case VariableKind::kGlobal: - sem_var = builder_->create( - var, info->type, info->storage_class, info->access, - info->binding_point); - break; - case VariableKind::kLocal: - sem_var = builder_->create( - var, info->type, info->storage_class, info->access); - break; - case VariableKind::kParameter: { - auto* param = builder_->create( - var, info->index, info->type, info->storage_class, info->access); - sem_var = param; - sem_params.emplace(var, param); - break; - } - } - } - - std::vector users; - for (auto* user : info->users) { - // Create semantic node for the identifier expression if necessary - auto* sem_expr = sem.Get(user); - if (sem_expr == nullptr) { - auto& expr_info = expr_info_.at(user); - auto* type = expr_info.type; - auto* stmt = expr_info.statement; - auto* sem_user = builder_->create( - user, type, stmt, sem_var, expr_info.constant_value); - sem_var->AddUser(sem_user); - sem.Add(user, sem_user); - } else { - auto* sem_user = sem_expr->As(); - if (!sem_user) { - TINT_ICE(Resolver, diagnostics_) << "expected sem::VariableUser, got " - << sem_expr->TypeInfo().name; - } - sem_var->AddUser(sem_user); - } - } - sem.Add(var, sem_var); - } - - auto remap_vars = [&sem](const std::vector& in) { - std::vector out; - out.reserve(in.size()); - for (auto* info : in) { - out.emplace_back(sem.Get(info->declaration)); - } - return out; - }; - - // Create semantic nodes for all ast::Functions - std::unordered_map func_info_to_sem_func; - for (auto it : function_to_info_) { - auto* func = it.first; - auto* info = it.second; - - std::vector parameters; - parameters.reserve(info->parameters.size()); - for (auto* p : info->parameters) { - parameters.emplace_back(sem_params.at(p->declaration)); - } - - auto* sem_func = builder_->create( - info->declaration, info->return_type, parameters, - remap_vars(info->referenced_module_vars), - remap_vars(info->local_referenced_module_vars), info->callsites, - ancestor_entry_points[func->symbol], info->workgroup_size); - func_info_to_sem_func.emplace(info, sem_func); - sem.Add(func, sem_func); - } - - // Create semantic nodes for all ast::CallExpressions - for (auto it : function_calls_) { - auto* call = it.first; - auto info = it.second; - auto* sem_func = func_info_to_sem_func.at(info.function); - sem.Add(call, builder_->create(call, sem_func, info.statement)); - } - - // Create semantic nodes for all remaining expression types - for (auto it : expr_info_) { - auto* expr = it.first; - auto& info = it.second; - if (sem.Get(expr)) { - // Expression has already been assigned a semantic node - continue; - } - sem.Add(expr, builder_->create( - expr, info.type, info.statement, info.constant_value)); - } -} - sem::Array* Resolver::Array(const ast::Array* arr) { auto source = arr->source; @@ -3854,7 +3715,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { } if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize() - AddError(elem_type->FriendlyName(builder_->Symbols()) + + AddError(TypeNameOf(elem_type) + " cannot be used as an element type of an array", source); return nullptr; @@ -3892,13 +3753,14 @@ sem::Array* Resolver::Array(const ast::Array* arr) { // sem::Array uses a size of 0 for a runtime-sized array. uint32_t count = 0; if (auto* count_expr = arr->count) { - if (!Expression(count_expr)) { + auto* count_sem = Expression(count_expr); + if (!count_sem) { return nullptr; } auto size_source = count_expr->source; - auto* ty = TypeOf(count_expr)->UnwrapRef(); + auto* ty = count_sem->Type()->UnwrapRef(); if (!ty->is_integer_scalar()) { AddError("array size must be integer scalar", size_source); return nullptr; @@ -3906,21 +3768,21 @@ sem::Array* Resolver::Array(const ast::Array* arr) { if (auto* ident = count_expr->As()) { // Make sure the identifier is a non-overridable module-scope constant. - VariableInfo* var = variable_stack_.Get(ident->symbol); - if (!var || var->kind != VariableKind::kGlobal || - !var->declaration->is_const) { + auto* var = variable_stack_.Get(ident->symbol); + if (!var || !var->Is() || + !var->Declaration()->is_const) { AddError("array size identifier must be a module-scope constant", size_source); return nullptr; } if (ast::HasDecoration( - var->declaration->decorations)) { + var->Declaration()->decorations)) { AddError("array size expression must not be pipeline-overridable", size_source); return nullptr; } - count_expr = var->declaration->constructor; + count_expr = var->Declaration()->constructor; } else if (!count_expr->Is()) { AddError( "array size expression must be either a literal or a module-scope " @@ -3929,7 +3791,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } - auto count_val = ConstantValueOf(count_expr); + auto count_val = count_sem->ConstantValue(); if (!count_val) { TINT_ICE(Resolver, diagnostics_) << "could not resolve array size expression"; @@ -4128,7 +3990,7 @@ bool Resolver::ValidateLocationDecoration( const Source& source, const bool is_input) { std::string inputs_or_output = is_input ? "inputs" : "output"; - if (current_function_ && current_function_->declaration->PipelineStage() == + if (current_function_ && current_function_->Declaration()->PipelineStage() == ast::PipelineStage::kCompute) { AddError("decoration is not valid for compute shader " + inputs_or_output, location->source); @@ -4136,7 +3998,7 @@ bool Resolver::ValidateLocationDecoration( } if (!type->is_numeric_scalar_or_vector()) { - std::string invalid_type = type->FriendlyName(builder_->Symbols()); + std::string invalid_type = TypeNameOf(type); AddError("cannot apply 'location' attribute to declaration of type '" + invalid_type + "'", source); @@ -4201,7 +4063,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { // Validate member type if (!IsPlain(type)) { - AddError(type->FriendlyName(builder_->Symbols()) + + AddError(TypeNameOf(type) + " cannot be used as the type of a structure member", member->source); return nullptr; @@ -4323,7 +4185,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { } bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) { - auto* func_type = current_function_->return_type; + auto* func_type = current_function_->ReturnType(); auto* ret_type = ret->value ? TypeOf(ret->value)->UnwrapRef() : builder_->create(); @@ -4332,13 +4194,13 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) { AddError( "return statement type must match its function " "return type, returned '" + - ret_type->FriendlyName(builder_->Symbols()) + "', expected '" + - current_function_->return_type_name + "'", + TypeNameOf(ret_type) + "', expected '" + TypeNameOf(func_type) + + "'", ret->source); return false; } - auto* sem = builder_->Sem().Get(ret); + auto* sem = Sem(ret); if (auto* continuing = sem->FindFirstParent()) { AddError("continuing blocks must not contain a return statement", @@ -4353,8 +4215,6 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) { } bool Resolver::Return(const ast::ReturnStatement* ret) { - current_function_->return_statements.push_back(ret); - if (auto* value = ret->value) { if (!Expression(value)) { return false; @@ -4435,8 +4295,8 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { } bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) { - auto* sem = - builder_->create(stmt, current_compound_statement_); + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); builder_->Sem().Add(stmt, sem); return Scope(sem, [&] { if (!Expression(stmt->condition)) { @@ -4464,15 +4324,15 @@ bool Resolver::Assignment(const ast::AssignmentStatement* a) { } bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) { - auto const* rhs_type = TypeOf(a->rhs); + auto const* rhs_ty = TypeOf(a->rhs); if (a->lhs->Is()) { // https://www.w3.org/TR/WGSL/#phony-assignment-section - auto* ty = rhs_type->UnwrapRef(); + auto* ty = rhs_ty->UnwrapRef(); if (!ty->IsConstructible() && !ty->IsAnyOf()) { AddError( - "cannot assign '" + TypeNameOf(a->rhs) + + "cannot assign '" + TypeNameOf(rhs_ty) + "' to '_'. '_' can only be assigned a constructible, pointer, " "texture or sampler type", a->rhs->source); @@ -4482,52 +4342,53 @@ bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) { } // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement - auto const* lhs_type = TypeOf(a->lhs); + auto const* lhs_ty = TypeOf(a->lhs); if (auto* ident = a->lhs->As()) { - if (VariableInfo* var = variable_stack_.Get(ident->symbol)) { - if (var->kind == VariableKind::kParameter) { + if (auto* var = variable_stack_.Get(ident->symbol)) { + if (var->Is()) { AddError("cannot assign to function parameter", a->lhs->source); AddNote("'" + builder_->Symbols().NameFor(ident->symbol) + "' is declared here:", - var->declaration->source); + var->Declaration()->source); return false; } - if (var->declaration->is_const) { + if (var->Declaration()->is_const) { AddError("cannot assign to const", a->lhs->source); AddNote("'" + builder_->Symbols().NameFor(ident->symbol) + "' is declared here:", - var->declaration->source); + var->Declaration()->source); return false; } } } - auto* lhs_ref = lhs_type->As(); + auto* lhs_ref = lhs_ty->As(); if (!lhs_ref) { // LHS is not a reference, so it has no storage. - AddError("cannot assign to value of type '" + TypeNameOf(a->lhs) + "'", + AddError("cannot assign to value of type '" + TypeNameOf(lhs_ty) + "'", a->lhs->source); return false; } - auto* storage_type = lhs_ref->StoreType(); - auto* value_type = rhs_type->UnwrapRef(); // Implicit load of RHS + auto* storage_ty = lhs_ref->StoreType(); + auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS // Value type has to match storage type - if (storage_type != value_type) { - AddError("cannot assign '" + TypeNameOf(a->rhs) + "' to '" + - TypeNameOf(a->lhs) + "'", + if (storage_ty != value_type) { + AddError("cannot assign '" + TypeNameOf(rhs_ty) + "' to '" + + TypeNameOf(lhs_ty) + "'", a->source); return false; } - if (!storage_type->IsConstructible()) { + if (!storage_ty->IsConstructible()) { AddError("storage type of assignment must be constructible", a->source); return false; } if (lhs_ref->Access() == ast::Access::kRead) { - AddError("cannot store into a read-only type '" + TypeNameOf(a->lhs) + "'", - a->source); + AddError( + "cannot store into a read-only type '" + RawTypeNameOf(lhs_ty) + "'", + a->source); return false; } return true; @@ -4537,11 +4398,11 @@ bool Resolver::ValidateNoDuplicateDefinition(Symbol sym, const Source& source, bool check_global_scope_only) { if (check_global_scope_only) { - if (VariableInfo* var = variable_stack_.Get(sym)) { - if (var->kind == VariableKind::kGlobal) { + if (auto* var = variable_stack_.Get(sym)) { + if (var->Is()) { AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'", source); - AddNote("previous definition is here", var->declaration->source); + AddNote("previous definition is here", var->Declaration()->source); return false; } } @@ -4549,14 +4410,14 @@ bool Resolver::ValidateNoDuplicateDefinition(Symbol sym, if (it != symbol_to_function_.end()) { AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'", source); - AddNote("previous definition is here", it->second->declaration->source); + AddNote("previous definition is here", it->second->Declaration()->source); return false; } } else { - if (VariableInfo* var = variable_stack_.Get(sym)) { + if (auto* var = variable_stack_.Get(sym)) { AddError("redefinition of '" + builder_->Symbols().NameFor(sym) + "'", source); - AddNote("previous definition is here", var->declaration->source); + AddNote("previous definition is here", var->Declaration()->source); return false; } } @@ -4592,8 +4453,7 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, for (auto* member : str->Members()) { if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) { std::stringstream err; - err << "while analysing structure member " - << str->FriendlyName(builder_->Symbols()) << "." + err << "while analysing structure member " << TypeNameOf(str) << "." << builder_->Symbols().NameFor(member->Declaration()->symbol); AddNote(err.str(), member->Declaration()->source); return false; @@ -4609,9 +4469,8 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) { std::stringstream err; - err << "Type '" << ty->FriendlyName(builder_->Symbols()) - << "' cannot be used in storage class '" << sc - << "' as it is non-host-shareable"; + err << "Type '" << TypeNameOf(ty) << "' cannot be used in storage class '" + << sc << "' as it is non-host-shareable"; AddError(err.str(), usage); return false; } @@ -4630,10 +4489,10 @@ bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) { variable_stack_.Push(); TINT_DEFER({ - TINT_DEFER(variable_stack_.Pop()); current_block_ = prev_current_block; current_compound_statement_ = prev_current_compound_statement; current_statement_ = prev_current_statement; + variable_stack_.Pop(); }); return callback(); @@ -4671,26 +4530,18 @@ void Resolver::AddNote(const std::string& msg, const Source& source) const { diagnostics_.add_note(diag::System::Resolver, msg, source); } -Resolver::VariableInfo::VariableInfo(const ast::Variable* decl, - sem::Type* ty, - const std::string& tn, - ast::StorageClass sc, - ast::Access ac, - VariableKind k, - uint32_t idx) - : declaration(decl), - type(ty), - type_name(tn), - storage_class(sc), - access(ac), - kind(k), - index(idx) {} - -Resolver::VariableInfo::~VariableInfo() = default; - -Resolver::FunctionInfo::FunctionInfo(const ast::Function* decl) - : declaration(decl) {} -Resolver::FunctionInfo::~FunctionInfo() = default; +template +const sem::Info::GetResultType* Resolver::Sem( + const AST_OR_TYPE* ast) { + auto* sem = builder_->Sem().Get(ast); + if (!sem) { + TINT_ICE(Resolver, diagnostics_) + << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n" + << "At: " << ast->source << "\n" + << "Pointer: " << ast; + } + return sem; +} } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index edd8759a05..5730a55356 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -95,79 +95,9 @@ class Resolver { /// Describes the context in which a variable is declared enum class VariableKind { kParameter, kLocal, kGlobal }; - /// Structure holding semantic information about a variable. - /// Used to build the sem::Variable nodes at the end of resolving. - struct VariableInfo { - VariableInfo(const ast::Variable* decl, - sem::Type* type, - const std::string& type_name, - ast::StorageClass storage_class, - ast::Access ac, - VariableKind k, - uint32_t idx); - ~VariableInfo(); - - ast::Variable const* const declaration; - sem::Type* type; - std::string const type_name; - ast::StorageClass storage_class; - ast::Access const access; - std::vector users; - sem::BindingPoint binding_point; - VariableKind kind; - uint32_t index = 0; // Parameter index, if kind == kParameter - uint16_t constant_id = 0; - }; - - struct IntrinsicCallInfo { - const ast::CallExpression* call; - const sem::Intrinsic* intrinsic; - }; - std::set> valid_struct_storage_layouts_; - /// Structure holding semantic information about a function. - /// Used to build the sem::Function nodes at the end of resolving. - struct FunctionInfo { - explicit FunctionInfo(const ast::Function* decl); - ~FunctionInfo(); - - const ast::Function* const declaration; - std::vector parameters; - utils::UniqueVector referenced_module_vars; - utils::UniqueVector local_referenced_module_vars; - std::vector return_statements; - std::vector callsites; - sem::Type* return_type = nullptr; - std::string return_type_name; - std::array workgroup_size; - std::vector intrinsic_calls; - - // List of transitive calls this function makes - utils::UniqueVector transitive_calls; - - // List of entry point functions that transitively call this function - utils::UniqueVector ancestor_entry_points; - }; - - /// Structure holding semantic information about an expression. - /// Used to build the sem::Expression nodes at the end of resolving. - struct ExpressionInfo { - sem::Type const* type; - std::string const type_name; // Declared type name - sem::Statement* statement; - sem::Constant constant_value; - }; - - /// Structure holding semantic information about a call expression to an - /// ast::Function. - /// Used to build the sem::Call nodes at the end of resolving. - struct FunctionCallInfo { - FunctionInfo* function; - sem::Statement* statement; - }; - /// Structure holding semantic information about a block (i.e. scope), such as /// parent block and variables declared in the block. /// Used to validate variable scoping rules. @@ -231,35 +161,40 @@ class Resolver { const ast::ExpressionList& params, uint32_t* id); - void set_referenced_from_function_if_needed(VariableInfo* var, bool local); - + ////////////////////////////////////////////////////////////////////////////// // AST and Type traversal methods + ////////////////////////////////////////////////////////////////////////////// + + // Expression resolving methods + // Returns the semantic node pointer on success, nullptr on failure. + sem::Expression* ArrayAccessor(const ast::ArrayAccessorExpression*); + sem::Expression* Binary(const ast::BinaryExpression*); + sem::Expression* Bitcast(const ast::BitcastExpression*); + sem::Expression* Call(const ast::CallExpression*); + sem::Expression* Constructor(const ast::ConstructorExpression*); + sem::Expression* Expression(const ast::Expression*); + sem::Function* Function(const ast::Function*); + sem::Call* FunctionCall(const ast::CallExpression*); + sem::Expression* Identifier(const ast::IdentifierExpression*); + sem::Call* IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType); + sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); + sem::Expression* UnaryOp(const ast::UnaryOpExpression*); + + // Statement resolving methods // Each return true on success, false on failure. - bool ArrayAccessor(const ast::ArrayAccessorExpression*); bool Assignment(const ast::AssignmentStatement* a); - bool Binary(const ast::BinaryExpression*); - bool Bitcast(const ast::BitcastExpression*); bool BlockStatement(const ast::BlockStatement*); - bool Call(const ast::CallExpression*); bool CaseStatement(const ast::CaseStatement*); - bool Constructor(const ast::ConstructorExpression*); bool ElseStatement(const ast::ElseStatement*); - bool Expression(const ast::Expression*); bool ForLoopStatement(const ast::ForLoopStatement*); - bool Function(const ast::Function*); - bool FunctionCall(const ast::CallExpression* call); - bool GlobalVariable(const ast::Variable* var); - bool Identifier(const ast::IdentifierExpression*); - bool IfStatement(const ast::IfStatement*); - bool IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType); - bool LoopStatement(const ast::LoopStatement*); - bool MemberAccessor(const ast::MemberAccessorExpression*); bool Parameter(const ast::Variable* param); + bool GlobalVariable(const ast::Variable* var); + bool IfStatement(const ast::IfStatement*); + bool LoopStatement(const ast::LoopStatement*); bool Return(const ast::ReturnStatement* ret); bool Statement(const ast::Statement*); bool Statements(const ast::StatementList&); bool SwitchStatement(const ast::SwitchStatement* s); - bool UnaryOp(const ast::UnaryOpExpression*); bool VariableDeclStatement(const ast::VariableDeclStatement*); // AST and Type validation methods @@ -270,18 +205,16 @@ class Resolver { uint32_t el_align, const Source& source); bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); - bool ValidateAtomicVariable(const VariableInfo* info); + bool ValidateAtomicVariable(const sem::Variable* var); bool ValidateAssignment(const ast::AssignmentStatement* a); bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, const sem::Type* storage_type, const bool is_input); - bool ValidateCall(const ast::CallExpression* call); - bool ValidateCallStatement(const ast::CallStatement* stmt); - bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); - bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); - bool ValidateFunctionCall(const ast::CallExpression* call, - const FunctionInfo* target); - bool ValidateGlobalVariable(const VariableInfo* var); + bool ValidateCall(const sem::Call* call); + bool ValidateEntryPoint(const sem::Function* func); + bool ValidateFunction(const sem::Function* func); + bool ValidateFunctionCall(const sem::Call* call); + bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, const sem::Type* storage_type); bool ValidateLocationDecoration(const ast::LocationDecoration* location, @@ -291,11 +224,11 @@ class Resolver { const bool is_input = false); bool ValidateMatrix(const sem::Matrix* ty, const Source& source); bool ValidateFunctionParameter(const ast::Function* func, - const VariableInfo* info); + const sem::Variable* var); bool ValidateNoDuplicateDefinition(Symbol sym, const Source& source, bool check_global_scope_only = false); - bool ValidateParameter(const ast::Function* func, const VariableInfo* info); + bool ValidateParameter(const ast::Function* func, const sem::Variable* var); bool ValidateReturn(const ast::ReturnStatement* ret); bool ValidateStatements(const ast::StatementList& stmts); bool ValidateStorageTexture(const ast::StorageTexture* t); @@ -303,33 +236,30 @@ class Resolver { bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor, const sem::Struct* struct_type); bool ValidateSwitch(const ast::SwitchStatement* s); - bool ValidateVariable(const VariableInfo* info); + bool ValidateVariable(const sem::Variable* var); bool ValidateVariableConstructor(const ast::Variable* var, ast::StorageClass storage_class, const sem::Type* storage_type, - const std::string& type_name, - const sem::Type* rhs_type, - const std::string& rhs_type_name); + const sem::Type* rhs_type); bool ValidateVector(const sem::Vector* ty, const Source& source); bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor, - const sem::Vector* vec_type, - const std::string& type_name); + const sem::Vector* vec_type); bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, - const sem::Matrix* matrix_type, - const std::string& type_name); + const sem::Matrix* matrix_type); bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor, - const sem::Type* type, - const std::string& type_name); + const sem::Type* type); bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor, const sem::Array* arr_type); bool ValidateTypeDecl(const ast::TypeDecl* named_type) const; - bool ValidateTextureIntrinsicFunction(const ast::CallExpression* ast_call, - const sem::Call* sem_call); + bool ValidateTextureIntrinsicFunction(const sem::Call* call); bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations); // sem::Struct is assumed to have at least one member bool ValidateStorageClassLayout(const sem::Struct* type, ast::StorageClass sc); - bool ValidateStorageClassLayout(const VariableInfo* info); + bool ValidateStorageClassLayout(const sem::Variable* var); + + /// Resolves the WorkgroupSize for the given function + bool WorkgroupSizeFor(const ast::Function*, sem::WorkgroupSize& ws); /// @returns the sem::Type for the ast::Type `ty`, building it if it /// hasn't been constructed already. If an error is raised, nullptr is @@ -355,16 +285,16 @@ class Resolver { /// raised. raised, nullptr is returned. sem::Struct* Structure(const ast::Struct* str); - /// @returns the VariableInfo for the variable `var`, building it if it hasn't - /// been constructed already. If an error is raised, nullptr is returned. + /// @returns the semantic info for the variable `var`. If an error is raised, + /// nullptr is returned. /// @note this method does not resolve the decorations as these are /// context-dependent (global, local, parameter) /// @param var the variable to create or return the `VariableInfo` for /// @param kind what kind of variable we are declaring /// @param index the index of the parameter, if this variable is a parameter - VariableInfo* Variable(const ast::Variable* var, - VariableKind kind, - uint32_t index = 0); + sem::Variable* Variable(const ast::Variable* var, + VariableKind kind, + uint32_t index = 0); /// Records the storage class usage for the given type, and any transient /// dependencies of the type. Validates that the type can be used for the @@ -389,23 +319,17 @@ class Resolver { /// @param expr the expression sem::Type* TypeOf(const ast::Expression* expr); - /// @returns the declared type name of the ast::Expression `expr` - /// @param expr the type name - std::string TypeNameOf(const ast::Expression* expr); + /// @returns the type name of the given semantic type, unwrapping references. + std::string TypeNameOf(const sem::Type* ty); + + /// @returns the type name of the given semantic type, without unwrapping + /// references. + std::string RawTypeNameOf(const sem::Type* ty); /// @returns the semantic type of the AST literal `lit` /// @param lit the literal sem::Type* TypeOf(const ast::Literal* lit); - /// Records the semantic information for the expression node with the resolved - /// type `type` and optional declared type name `type_name`. - /// @param expr the expression - /// @param type the resolved type - /// @param type_name the declared type name - void SetExprInfo(const ast::Expression* expr, - const sem::Type* type, - std::string type_name = ""); - /// Assigns `stmt` to #current_statement_, #current_compound_statement_, and /// possibly #current_block_, pushes the variable scope, then calls /// `callback`. Before returning #current_statement_, @@ -437,16 +361,13 @@ class Resolver { void AddNote(const std::string& msg, const Source& source) const; template - void TraverseCallChain(FunctionInfo* from, - FunctionInfo* to, + void TraverseCallChain(const sem::Function* from, + const sem::Function* to, CALLBACK&& callback) const; ////////////////////////////////////////////////////////////////////////////// /// Constant value evaluation methods ////////////////////////////////////////////////////////////////////////////// - /// @return the Constant value of the given Expression - sem::Constant ConstantValueOf(const ast::Expression* expr); - /// Cast `Value` to `target_type` /// @return the casted value sem::Constant ConstantCast(const sem::Constant& value, @@ -461,29 +382,27 @@ class Resolver { const ast::TypeConstructorExpression* type_ctor, const sem::Type* type); + /// Sem is a helper for obtaining the semantic node for the given AST node. + template + const sem::Info::GetResultType* Sem(const AST_OR_TYPE* ast); + ProgramBuilder* const builder_; diag::List& diagnostics_; std::unique_ptr const intrinsic_table_; - ScopeStack variable_stack_; - std::unordered_map symbol_to_function_; - std::vector entry_points_; + ScopeStack variable_stack_; + std::unordered_map symbol_to_function_; + std::vector entry_points_; std::unordered_map atomic_composite_info_; - std::unordered_map function_to_info_; - std::unordered_map variable_to_info_; - std::unordered_map - function_calls_; - std::unordered_map expr_info_; std::unordered_map named_type_info_; std::unordered_set marked_; - std::unordered_map constant_ids_; + std::unordered_map constant_ids_; - FunctionInfo* current_function_ = nullptr; + sem::Function* current_function_ = nullptr; sem::Statement* current_statement_ = nullptr; sem::CompoundStatement* current_compound_statement_ = nullptr; sem::BlockStatement* current_block_ = nullptr; - BlockAllocator variable_infos_; - BlockAllocator function_infos_; }; } // namespace resolver diff --git a/src/resolver/resolver_constants.cc b/src/resolver/resolver_constants.cc index e28ed038d9..fb59ff3895 100644 --- a/src/resolver/resolver_constants.cc +++ b/src/resolver/resolver_constants.cc @@ -15,6 +15,7 @@ #include "src/resolver/resolver.h" #include "src/sem/constant.h" +#include "src/utils/get_or_create.h" namespace tint { namespace resolver { @@ -26,46 +27,6 @@ using f32 = ProgramBuilder::f32; } // namespace -sem::Constant Resolver::ConstantCast(const sem::Constant& value, - const sem::Type* target_elem_type) { - if (value.ElementType() == target_elem_type) { - return value; - } - - sem::Constant::Scalars elems; - for (size_t i = 0; i < value.Elements().size(); ++i) { - if (target_elem_type->Is()) { - elems.emplace_back( - value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); - } else if (target_elem_type->Is()) { - elems.emplace_back( - value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); - } else if (target_elem_type->Is()) { - elems.emplace_back( - value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); - } else if (target_elem_type->Is()) { - elems.emplace_back( - value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); - } - } - - auto* target_type = - value.Type()->Is() - ? builder_->create(target_elem_type, - static_cast(elems.size())) - : target_elem_type; - - return sem::Constant(target_type, elems); -} - -sem::Constant Resolver::ConstantValueOf(const ast::Expression* expr) { - auto it = expr_info_.find(expr); - if (it != expr_info_.end()) { - return it->second.constant_value; - } - return {}; -} - sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) { if (auto* e = expr->As()) { @@ -131,11 +92,11 @@ sem::Constant Resolver::EvaluateConstantValue( // type_ctor's type. sem::Constant::Scalars elems; for (auto* cv : ctor_values) { - auto value = ConstantValueOf(cv); - if (!value.IsValid()) { + auto* expr = builder_->Sem().Get(cv); + if (!expr || !expr->ConstantValue()) { return {}; } - auto cast = ConstantCast(value, elem_type); + auto cast = ConstantCast(expr->ConstantValue(), elem_type); elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end()); } @@ -149,5 +110,37 @@ sem::Constant Resolver::EvaluateConstantValue( return sem::Constant(type, std::move(elems)); } +sem::Constant Resolver::ConstantCast(const sem::Constant& value, + const sem::Type* target_elem_type) { + if (value.ElementType() == target_elem_type) { + return value; + } + + sem::Constant::Scalars elems; + for (size_t i = 0; i < value.Elements().size(); ++i) { + if (target_elem_type->Is()) { + elems.emplace_back( + value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); + } else if (target_elem_type->Is()) { + elems.emplace_back( + value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); + } else if (target_elem_type->Is()) { + elems.emplace_back( + value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); + } else if (target_elem_type->Is()) { + elems.emplace_back( + value.WithScalarAt(i, [](auto&& s) { return static_cast(s); })); + } + } + + auto* target_type = + value.Type()->Is() + ? builder_->create(target_elem_type, + static_cast(elems.size())) + : target_elem_type; + + return sem::Constant(target_type, elems); +} + } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index ddf4d36d4d..99f206acac 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -922,8 +922,8 @@ TEST_F(ResolverTest, Function_CallSites) { auto* foo_sem = Sem().Get(foo); ASSERT_NE(foo_sem, nullptr); ASSERT_EQ(foo_sem->CallSites().size(), 2u); - EXPECT_EQ(foo_sem->CallSites()[0], call_1); - EXPECT_EQ(foo_sem->CallSites()[1], call_2); + EXPECT_EQ(foo_sem->CallSites()[0]->Declaration(), call_1); + EXPECT_EQ(foo_sem->CallSites()[1]->Declaration(), call_2); auto* bar_sem = Sem().Get(bar); ASSERT_NE(bar_sem, nullptr); @@ -1908,17 +1908,17 @@ TEST_F(ResolverTest, Function_EntryPoints_StageDecoration) { const auto& b_eps = func_b_sem->AncestorEntryPoints(); ASSERT_EQ(2u, b_eps.size()); - EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]); - EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]); + EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]->Declaration()->symbol); + EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]->Declaration()->symbol); const auto& a_eps = func_a_sem->AncestorEntryPoints(); ASSERT_EQ(1u, a_eps.size()); - EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]); + EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]->Declaration()->symbol); const auto& c_eps = func_c_sem->AncestorEntryPoints(); ASSERT_EQ(2u, c_eps.size()); - EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]); - EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]); + EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]->Declaration()->symbol); + EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]->Declaration()->symbol); EXPECT_TRUE(ep_1_sem->AncestorEntryPoints().empty()); EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty()); diff --git a/src/resolver/storage_class_layout_validation_test.cc b/src/resolver/storage_class_layout_validation_test.cc index 5339f13cf3..3ad7d3f1fc 100644 --- a/src/resolver/storage_class_layout_validation_test.cc +++ b/src/resolver/storage_class_layout_validation_test.cc @@ -179,11 +179,11 @@ TEST_F(ResolverStorageClassLayoutValidationTest, ASSERT_FALSE(r()->Resolve()); EXPECT_EQ( r()->error(), - R"(56:78 error: the offset of a struct member of type 'Inner' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting [[align(16)]] on this member + R"(56:78 error: the offset of a struct member of type '[[stride(16)]] array' in storage class 'uniform' must be a multiple of 16 bytes, but 'inner' is currently at offset 4. Consider setting [[align(16)]] on this member 12:34 note: see layout of struct: /* align(4) size(164) */ struct Outer { /* offset( 0) align(4) size( 4) */ scalar : f32; -/* offset( 4) align(4) size(160) */ inner : Inner; +/* offset( 4) align(4) size(160) */ inner : [[stride(16)]] array; /* */ }; 78:90 note: see declaration of variable)"); } @@ -351,7 +351,7 @@ TEST_F(ResolverStorageClassLayoutValidationTest, R"(34:56 error: uniform storage requires that array elements be aligned to 16 bytes, but array stride of 'inner' is currently 8. Consider setting [[stride(16)]] on the array type 12:34 note: see layout of struct: /* align(4) size(84) */ struct Outer { -/* offset( 0) align(4) size(80) */ inner : Inner; +/* offset( 0) align(4) size(80) */ inner : [[stride(8)]] array; /* offset(80) align(4) size( 4) */ scalar : i32; /* */ }; 78:90 note: see declaration of variable)"); diff --git a/src/resolver/storage_class_validation_test.cc b/src/resolver/storage_class_validation_test.cc index 9a60ac4be9..f902a9efb6 100644 --- a/src/resolver/storage_class_validation_test.cc +++ b/src/resolver/storage_class_validation_test.cc @@ -96,7 +96,8 @@ TEST_F(ResolverStorageClassValidationTest, StorageBufferBoolAlias) { EXPECT_EQ( r()->error(), - R"(56:78 error: variables declared in the storage class must be of a structure type)"); + R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable +56:78 note: while instantiating variable g)"); } TEST_F(ResolverStorageClassValidationTest, NotStorage_AccessMode) { @@ -194,7 +195,8 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferBool) { EXPECT_EQ( r()->error(), - R"(56:78 error: variables declared in the storage class must be of a structure type)"); + R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable +56:78 note: while instantiating variable g)"); } TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) { @@ -243,7 +245,8 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferBoolAlias) { EXPECT_EQ( r()->error(), - R"(56:78 error: variables declared in the storage class must be of a structure type)"); + R"(56:78 error: Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable +56:78 note: while instantiating variable g)"); } TEST_F(ResolverStorageClassValidationTest, UniformBufferNoBlockDecoration) { diff --git a/src/resolver/type_constructor_validation_test.cc b/src/resolver/type_constructor_validation_test.cc index d72cd82b57..c369f0bf6c 100644 --- a/src/resolver/type_constructor_validation_test.cc +++ b/src/resolver/type_constructor_validation_test.cc @@ -1553,7 +1553,7 @@ TEST_F(ResolverTypeConstructorValidationTest, EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: type in vector constructor does not match vector " - "type: expected 'f32', found 'UnsignedInt'"); + "type: expected 'f32', found 'u32'"); } TEST_F(ResolverTypeConstructorValidationTest, @@ -1638,10 +1638,9 @@ struct MatrixDimensions { uint32_t columns; }; -static std::string MatrixStr(const MatrixDimensions& dimensions, - std::string subtype = "f32") { +static std::string MatrixStr(const MatrixDimensions& dimensions) { return "mat" + std::to_string(dimensions.columns) + "x" + - std::to_string(dimensions.rows) + "<" + subtype + ">"; + std::to_string(dimensions.rows) + ""; } using MatrixConstructorTest = ResolverTestWithParam; @@ -1919,9 +1918,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) { WrapInFunction(tc); EXPECT_FALSE(r()->Resolve()); - EXPECT_THAT(r()->error(), HasSubstr("12:1 error: invalid constructor for " + - MatrixStr(param, "Float32") + - "\n\n3 candidates available:")); + EXPECT_THAT(r()->error(), + HasSubstr("12:1 error: invalid constructor for " + + MatrixStr(param) + "\n\n3 candidates available:")); } TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) { diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 55a9c1f136..bdad5fef50 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -499,9 +499,10 @@ TEST_F(ResolverValidationTest, EXpr_MemberAccessor_FuncBadParent) { Func("func", {p}, ty.f32(), {Decl(x), Return(x)}); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "error: invalid member accessor expression. Expected vector or " - "struct, got 'ptr>'"); + EXPECT_EQ( + r()->error(), + "error: invalid member accessor expression. " + "Expected vector or struct, got 'ptr, read_write>'"); } TEST_F(ResolverValidationTest, diff --git a/src/resolver/var_let_validation_test.cc b/src/resolver/var_let_validation_test.cc index 762ac007bb..48469ada49 100644 --- a/src/resolver/var_let_validation_test.cc +++ b/src/resolver/var_let_validation_test.cc @@ -120,7 +120,7 @@ TEST_F(ResolverVarLetValidationTest, LetConstructorWrongTypeViaAlias) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( r()->error(), - R"(3:3 error: cannot initialize let of type 'I32' with value of type 'u32')"); + R"(3:3 error: cannot initialize let of type 'i32' with value of type 'u32')"); } TEST_F(ResolverVarLetValidationTest, VarConstructorWrongTypeViaAlias) { @@ -131,7 +131,7 @@ TEST_F(ResolverVarLetValidationTest, VarConstructorWrongTypeViaAlias) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( r()->error(), - R"(3:3 error: cannot initialize var of type 'I32' with value of type 'u32')"); + R"(3:3 error: cannot initialize var of type 'i32' with value of type 'u32')"); } TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) { @@ -147,7 +147,7 @@ TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) { EXPECT_EQ( r()->error(), - R"(12:34 error: cannot initialize let of type 'ptr' with value of type 'f32')"); + R"(12:34 error: cannot initialize let of type 'ptr' with value of type 'f32')"); } TEST_F(ResolverVarLetValidationTest, LocalVarRedeclared) { diff --git a/src/sem/block_statement.cc b/src/sem/block_statement.cc index c0dbd68e8a..89a0bf9f7b 100644 --- a/src/sem/block_statement.cc +++ b/src/sem/block_statement.cc @@ -26,8 +26,9 @@ namespace tint { namespace sem { BlockStatement::BlockStatement(const ast::BlockStatement* declaration, - const CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) {} BlockStatement::~BlockStatement() = default; @@ -39,15 +40,19 @@ void BlockStatement::AddDecl(const ast::Variable* var) { decls_.push_back(var); } -FunctionBlockStatement::FunctionBlockStatement(const ast::Function* function) - : Base(function->body, nullptr), function_(function) {} +FunctionBlockStatement::FunctionBlockStatement(const sem::Function* function) + : Base(function->Declaration()->body, nullptr, function) { + TINT_ASSERT(Semantic, function); +} FunctionBlockStatement::~FunctionBlockStatement() = default; LoopBlockStatement::LoopBlockStatement(const ast::BlockStatement* declaration, - const CompoundStatement* parent) - : Base(declaration, parent) { + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) { TINT_ASSERT(Semantic, parent); + TINT_ASSERT(Semantic, function); } LoopBlockStatement::~LoopBlockStatement() = default; diff --git a/src/sem/block_statement.h b/src/sem/block_statement.h index 6ab6403c0f..bc96e7d650 100644 --- a/src/sem/block_statement.h +++ b/src/sem/block_statement.h @@ -40,8 +40,10 @@ class BlockStatement : public Castable { /// Constructor /// @param declaration the AST node for this block statement /// @param parent the owning statement + /// @param function the owning function BlockStatement(const ast::BlockStatement* declaration, - const CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~BlockStatement() override; @@ -67,16 +69,10 @@ class FunctionBlockStatement public: /// Constructor /// @param function the owning function - explicit FunctionBlockStatement(const ast::Function* function); + explicit FunctionBlockStatement(const sem::Function* function); /// Destructor ~FunctionBlockStatement() override; - - /// @returns the function owning this block - const ast::Function* Function() const { return function_; } - - private: - ast::Function const* const function_; }; /// Holds semantic information about a loop body block or for-loop body block @@ -85,8 +81,10 @@ class LoopBlockStatement : public Castable { /// Constructor /// @param declaration the AST node for this block statement /// @param parent the owning statement + /// @param function the owning function LoopBlockStatement(const ast::BlockStatement* declaration, - const CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~LoopBlockStatement() override; diff --git a/src/sem/call.cc b/src/sem/call.cc index 774056d2f9..ac95e78d42 100644 --- a/src/sem/call.cc +++ b/src/sem/call.cc @@ -14,16 +14,21 @@ #include "src/sem/call.h" +#include +#include + TINT_INSTANTIATE_TYPEINFO(tint::sem::Call); namespace tint { namespace sem { -Call::Call(const ast::Expression* declaration, +Call::Call(const ast::CallExpression* declaration, const CallTarget* target, + std::vector arguments, Statement* statement) : Base(declaration, target->ReturnType(), statement, Constant{}), - target_(target) {} + target_(target), + arguments_(std::move(arguments)) {} Call::~Call() = default; diff --git a/src/sem/call.h b/src/sem/call.h index d2fdb314b4..e7ce6bd07e 100644 --- a/src/sem/call.h +++ b/src/sem/call.h @@ -15,6 +15,8 @@ #ifndef SRC_SEM_CALL_H_ #define SRC_SEM_CALL_H_ +#include + #include "src/sem/expression.h" #include "src/sem/intrinsic.h" @@ -28,9 +30,11 @@ class Call : public Castable { /// Constructor /// @param declaration the AST node /// @param target the call target + /// @param arguments the call arguments /// @param statement the statement that owns this expression - Call(const ast::Expression* declaration, + Call(const ast::CallExpression* declaration, const CallTarget* target, + std::vector arguments, Statement* statement); /// Destructor @@ -39,8 +43,19 @@ class Call : public Castable { /// @return the target of the call const CallTarget* Target() const { return target_; } + /// @return the call arguments + const std::vector& Arguments() const { + return arguments_; + } + + /// @returns the AST node + const ast::CallExpression* Declaration() const { + return static_cast(declaration_); + } + private: CallTarget const* const target_; + std::vector arguments_; }; } // namespace sem diff --git a/src/sem/constant.cc b/src/sem/constant.cc index 3b9711bd19..b39a115fd0 100644 --- a/src/sem/constant.cc +++ b/src/sem/constant.cc @@ -59,5 +59,7 @@ Constant::Constant(const Constant&) = default; Constant::~Constant() = default; +Constant& Constant::operator=(const Constant& rhs) = default; + } // namespace sem } // namespace tint diff --git a/src/sem/constant.h b/src/sem/constant.h index 27895ff2e5..d28ee7b6b3 100644 --- a/src/sem/constant.h +++ b/src/sem/constant.h @@ -77,6 +77,11 @@ class Constant { /// Destructor ~Constant(); + /// Copy assignment + /// @param other the Constant to copy + /// @returns this Constant + Constant& operator=(const Constant& other); + /// @returns true if the Constant has been initialized bool IsValid() const { return type_ != nullptr; } diff --git a/src/sem/expression.h b/src/sem/expression.h index ba001cb15b..e605ef844b 100644 --- a/src/sem/expression.h +++ b/src/sem/expression.h @@ -53,8 +53,11 @@ class Expression : public Castable { /// @returns the AST node const ast::Expression* Declaration() const { return declaration_; } - private: + protected: + /// The AST expression node for this semantic expression const ast::Expression* const declaration_; + + private: const sem::Type* const type_; const Statement* const statement_; const Constant constant_; diff --git a/src/sem/for_loop_statement.cc b/src/sem/for_loop_statement.cc index 59b6738996..db1370760a 100644 --- a/src/sem/for_loop_statement.cc +++ b/src/sem/for_loop_statement.cc @@ -22,8 +22,9 @@ namespace tint { namespace sem { ForLoopStatement::ForLoopStatement(const ast::ForLoopStatement* declaration, - CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) {} ForLoopStatement::~ForLoopStatement() = default; diff --git a/src/sem/for_loop_statement.h b/src/sem/for_loop_statement.h index 2ee92e4fed..ff89241440 100644 --- a/src/sem/for_loop_statement.h +++ b/src/sem/for_loop_statement.h @@ -32,8 +32,10 @@ class ForLoopStatement : public Castable { /// Constructor /// @param declaration the AST node for this for-loop statement /// @param parent the owning statement + /// @param function the owning function ForLoopStatement(const ast::ForLoopStatement* declaration, - CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~ForLoopStatement() override; diff --git a/src/sem/function.cc b/src/sem/function.cc index 070938bbca..709dd03dfc 100644 --- a/src/sem/function.cc +++ b/src/sem/function.cc @@ -28,23 +28,13 @@ TINT_INSTANTIATE_TYPEINFO(tint::sem::Function); namespace tint { namespace sem { -Function::Function( - const ast::Function* declaration, - Type* return_type, - std::vector parameters, - std::vector transitively_referenced_globals, - std::vector directly_referenced_globals, - std::vector callsites, - std::vector ancestor_entry_points, - sem::WorkgroupSize workgroup_size) +Function::Function(const ast::Function* declaration, + Type* return_type, + std::vector parameters, + sem::WorkgroupSize workgroup_size) : Base(return_type, utils::ToConstPtrVec(parameters)), declaration_(declaration), - workgroup_size_(std::move(workgroup_size)), - directly_referenced_globals_(std::move(directly_referenced_globals)), - transitively_referenced_globals_( - std::move(transitively_referenced_globals)), - callsites_(callsites), - ancestor_entry_points_(std::move(ancestor_entry_points)) { + workgroup_size_(std::move(workgroup_size)) { for (auto* parameter : parameters) { parameter->SetOwner(this); } @@ -150,8 +140,8 @@ Function::VariableBindings Function::TransitivelyReferencedVariablesOfType( } bool Function::HasAncestorEntryPoint(Symbol symbol) const { - for (const auto& point : ancestor_entry_points_) { - if (point == symbol) { + for (const auto* point : ancestor_entry_points_) { + if (point->Declaration()->symbol == symbol) { return true; } } diff --git a/src/sem/function.h b/src/sem/function.h index c10301d7fd..d8b666281e 100644 --- a/src/sem/function.h +++ b/src/sem/function.h @@ -62,18 +62,10 @@ class Function : public Castable { /// @param declaration the ast::Function /// @param return_type the return type of the function /// @param parameters the parameters to the function - /// @param transitively_referenced_globals the referenced module variables - /// @param directly_referenced_globals the locally referenced module - /// @param callsites the callsites of the function - /// @param ancestor_entry_points the ancestor entry points /// @param workgroup_size the workgroup size Function(const ast::Function* declaration, Type* return_type, std::vector parameters, - std::vector transitively_referenced_globals, - std::vector directly_referenced_globals, - std::vector callsites, - std::vector ancestor_entry_points, sem::WorkgroupSize workgroup_size); /// Destructor @@ -85,22 +77,98 @@ class Function : public Castable { /// @returns the workgroup size {x, y, z} for the function. const sem::WorkgroupSize& WorkgroupSize() const { return workgroup_size_; } + /// @returns all directly referenced global variables + const utils::UniqueVector& DirectlyReferencedGlobals() + const { + return directly_referenced_globals_; + } + + /// Records that this function directly references the given global variable. + /// Note: Implicitly adds this global to the transtively-called globals. + /// @param global the module-scope variable + void AddDirectlyReferencedGlobal(const sem::GlobalVariable* global) { + directly_referenced_globals_.add(global); + transitively_referenced_globals_.add(global); + } + /// @returns all transitively referenced global variables const utils::UniqueVector& TransitivelyReferencedGlobals() const { return transitively_referenced_globals_; } - /// @returns the list of callsites of this function - std::vector CallSites() const { - return callsites_; + /// Records that this function transitively references the given global + /// variable. + /// @param global the module-scoped variable + void AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global) { + transitively_referenced_globals_.add(global); } - /// @returns the names of the ancestor entry points - const std::vector& AncestorEntryPoints() const { + /// @returns the list of functions that this function transitively calls. + const utils::UniqueVector& TransitivelyCalledFunctions() + const { + return transitively_called_functions_; + } + + /// Records that this function transitively calls `function`. + /// @param function the function this function transitively calls + void AddTransitivelyCalledFunction(const Function* function) { + transitively_called_functions_.add(function); + } + + /// @returns the list of intrinsics that this function directly calls. + const utils::UniqueVector& DirectlyCalledIntrinsics() + const { + return directly_called_intrinsics_; + } + + /// Records that this function transitively calls `intrinsic`. + /// @param intrinsic the intrinsic this function directly calls + void AddDirectlyCalledIntrinsic(const Intrinsic* intrinsic) { + directly_called_intrinsics_.add(intrinsic); + } + + /// @returns the list of direct calls to functions / intrinsics made by this + /// function + std::vector DirectCallStatements() const { + return direct_calls_; + } + + /// Adds a record of the direct function / intrinsic calls made by this + /// function + /// @param call the call + void AddDirectCall(const Call* call) { direct_calls_.emplace_back(call); } + + /// @param target the target of a call + /// @returns the Call to the given CallTarget, or nullptr the target was not + /// called by this function. + const Call* FindDirectCallTo(const CallTarget* target) const { + for (auto* call : direct_calls_) { + if (call->Target() == target) { + return call; + } + } + return nullptr; + } + + /// @returns the list of callsites of this function + std::vector CallSites() const { return callsites_; } + + /// Adds a record of a callsite to this function + /// @param call the callsite + void AddCallSite(const Call* call) { callsites_.emplace_back(call); } + + /// @returns the ancestor entry points + const std::vector& AncestorEntryPoints() const { return ancestor_entry_points_; } + /// Adds a record that the given entry point transitively calls this function + /// @param entry_point the entry point that transtively calls this function + void AddAncestorEntryPoint(const sem::Function* entry_point) { + ancestor_entry_points_.emplace_back(entry_point); + } + /// Retrieves any referenced location variables /// @returns the pair. std::vector> @@ -174,8 +242,9 @@ class Function : public Castable { utils::UniqueVector transitively_referenced_globals_; utils::UniqueVector transitively_called_functions_; utils::UniqueVector directly_called_intrinsics_; - std::vector callsites_; - std::vector ancestor_entry_points_; + std::vector direct_calls_; + std::vector callsites_; + std::vector ancestor_entry_points_; }; } // namespace sem diff --git a/src/sem/if_statement.cc b/src/sem/if_statement.cc index 4b5d843fe4..cf34d8d336 100644 --- a/src/sem/if_statement.cc +++ b/src/sem/if_statement.cc @@ -23,14 +23,16 @@ namespace tint { namespace sem { IfStatement::IfStatement(const ast::IfStatement* declaration, - CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) {} IfStatement::~IfStatement() = default; ElseStatement::ElseStatement(const ast::ElseStatement* declaration, - CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) {} ElseStatement::~ElseStatement() = default; diff --git a/src/sem/if_statement.h b/src/sem/if_statement.h index d615a449b3..6c25fcab01 100644 --- a/src/sem/if_statement.h +++ b/src/sem/if_statement.h @@ -34,7 +34,10 @@ class IfStatement : public Castable { /// Constructor /// @param declaration the AST node for this if statement /// @param parent the owning statement - IfStatement(const ast::IfStatement* declaration, CompoundStatement* parent); + /// @param function the owning function + IfStatement(const ast::IfStatement* declaration, + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~IfStatement() override; @@ -46,8 +49,10 @@ class ElseStatement : public Castable { /// Constructor /// @param declaration the AST node for this else statement /// @param parent the owning statement + /// @param function the owning function ElseStatement(const ast::ElseStatement* declaration, - CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~ElseStatement() override; diff --git a/src/sem/info.h b/src/sem/info.h index 78b82d5e83..72e2d3ecf9 100644 --- a/src/sem/info.h +++ b/src/sem/info.h @@ -27,10 +27,18 @@ namespace sem { /// Info holds all the resolved semantic information for a Program. class Info { + public: /// Placeholder type used by Get() to provide a default value for EXPLICIT_SEM using InferFromAST = std::nullptr_t; - public: + /// Resolves to the return type of the Get() method given the desired sementic + /// type and AST type. + template + using GetResultType = + std::conditional_t::value, + SemanticNodeTypeFor, + SEM>; + /// Constructor Info(); @@ -50,10 +58,7 @@ class Info { /// @returns a pointer to the semantic node if found, otherwise nullptr template ::value, - SemanticNodeTypeFor, - SEM>> + typename RESULT = GetResultType> const RESULT* Get(const AST_OR_TYPE* node) const { auto it = map.find(node); if (it == map.end()) { diff --git a/src/sem/loop_statement.cc b/src/sem/loop_statement.cc index 62ea00d6e7..6c43b934b0 100644 --- a/src/sem/loop_statement.cc +++ b/src/sem/loop_statement.cc @@ -23,16 +23,22 @@ namespace tint { namespace sem { LoopStatement::LoopStatement(const ast::LoopStatement* declaration, - CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) { + TINT_ASSERT(Semantic, parent); + TINT_ASSERT(Semantic, function); +} LoopStatement::~LoopStatement() = default; LoopContinuingBlockStatement::LoopContinuingBlockStatement( const ast::BlockStatement* declaration, - const CompoundStatement* parent) - : Base(declaration, parent) { + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) { TINT_ASSERT(Semantic, parent); + TINT_ASSERT(Semantic, function); } LoopContinuingBlockStatement::~LoopContinuingBlockStatement() = default; diff --git a/src/sem/loop_statement.h b/src/sem/loop_statement.h index c80bc8c40b..ad1f0a84c6 100644 --- a/src/sem/loop_statement.h +++ b/src/sem/loop_statement.h @@ -33,8 +33,10 @@ class LoopStatement : public Castable { /// Constructor /// @param declaration the AST node for this loop statement /// @param parent the owning statement + /// @param function the owning function LoopStatement(const ast::LoopStatement* declaration, - CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~LoopStatement() override; @@ -47,8 +49,10 @@ class LoopContinuingBlockStatement /// Constructor /// @param declaration the AST node for this block statement /// @param parent the owning statement + /// @param function the owning function LoopContinuingBlockStatement(const ast::BlockStatement* declaration, - const CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~LoopContinuingBlockStatement() override; diff --git a/src/sem/statement.cc b/src/sem/statement.cc index 0accb48b1b..098fe14e5d 100644 --- a/src/sem/statement.cc +++ b/src/sem/statement.cc @@ -27,23 +27,18 @@ namespace tint { namespace sem { Statement::Statement(const ast::Statement* declaration, - const CompoundStatement* parent) - : declaration_(declaration), parent_(parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : declaration_(declaration), parent_(parent), function_(function) {} const BlockStatement* Statement::Block() const { return FindFirstParent(); } -const ast::Function* Statement::Function() const { - if (auto* fbs = FindFirstParent()) { - return fbs->Function(); - } - return nullptr; -} - CompoundStatement::CompoundStatement(const ast::Statement* declaration, - const CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) {} CompoundStatement::~CompoundStatement() = default; diff --git a/src/sem/statement.h b/src/sem/statement.h index 0449c3aeb3..e5297dc674 100644 --- a/src/sem/statement.h +++ b/src/sem/statement.h @@ -33,6 +33,7 @@ namespace sem { /// Forward declaration class CompoundStatement; +class Function; namespace detail { /// FindFirstParentReturn is a traits helper for determining the return type for @@ -64,7 +65,10 @@ class Statement : public Castable { /// Constructor /// @param declaration the AST node for this statement /// @param parent the owning statement - Statement(const ast::Statement* declaration, const CompoundStatement* parent); + /// @param function the owning function + Statement(const ast::Statement* declaration, + const CompoundStatement* parent, + const sem::Function* function); /// @return the AST node for this statement const ast::Statement* Declaration() const { return declaration_; } @@ -90,11 +94,12 @@ class Statement : public Castable { const BlockStatement* Block() const; /// @returns the function that owns this statement - const ast::Function* Function() const; + const sem::Function* Function() const { return function_; } private: - ast::Statement const* const declaration_; - CompoundStatement const* const parent_; + const ast::Statement* const declaration_; + const CompoundStatement* const parent_; + const sem::Function* const function_; }; /// CompoundStatement is the base class of statements that can hold other @@ -103,9 +108,11 @@ class CompoundStatement : public Castable { public: /// Constructor /// @param declaration the AST node for this statement - /// @param parent the owning statement + /// @param statement the owning statement + /// @param function the owning function CompoundStatement(const ast::Statement* declaration, - const CompoundStatement* parent); + const CompoundStatement* statement, + const sem::Function* function); /// Destructor ~CompoundStatement() override; diff --git a/src/sem/switch_statement.cc b/src/sem/switch_statement.cc index c99ff719bd..fe13c3ef5e 100644 --- a/src/sem/switch_statement.cc +++ b/src/sem/switch_statement.cc @@ -23,16 +23,22 @@ namespace tint { namespace sem { SwitchStatement::SwitchStatement(const ast::SwitchStatement* declaration, - CompoundStatement* parent) - : Base(declaration, parent) {} + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) { + TINT_ASSERT(Semantic, parent); + TINT_ASSERT(Semantic, function); +} SwitchStatement::~SwitchStatement() = default; SwitchCaseBlockStatement::SwitchCaseBlockStatement( const ast::BlockStatement* declaration, - const CompoundStatement* parent) - : Base(declaration, parent) { + const CompoundStatement* parent, + const sem::Function* function) + : Base(declaration, parent, function) { TINT_ASSERT(Semantic, parent); + TINT_ASSERT(Semantic, function); } SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default; diff --git a/src/sem/switch_statement.h b/src/sem/switch_statement.h index f7bb735a1a..8e5a2cddeb 100644 --- a/src/sem/switch_statement.h +++ b/src/sem/switch_statement.h @@ -33,8 +33,10 @@ class SwitchStatement : public Castable { /// Constructor /// @param declaration the AST node for this switch statement /// @param parent the owning statement + /// @param function the owning function SwitchStatement(const ast::SwitchStatement* declaration, - CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~SwitchStatement() override; @@ -47,8 +49,10 @@ class SwitchCaseBlockStatement /// Constructor /// @param declaration the AST node for this block statement /// @param parent the owning statement + /// @param function the owning function SwitchCaseBlockStatement(const ast::BlockStatement* declaration, - const CompoundStatement* parent); + const CompoundStatement* parent, + const sem::Function* function); /// Destructor ~SwitchCaseBlockStatement() override; diff --git a/src/sem/variable.cc b/src/sem/variable.cc index a1e0eb8e2a..5eac5b839a 100644 --- a/src/sem/variable.cc +++ b/src/sem/variable.cc @@ -31,19 +31,26 @@ namespace sem { Variable::Variable(const ast::Variable* declaration, const sem::Type* type, ast::StorageClass storage_class, - ast::Access access) + ast::Access access, + Constant constant_value) : declaration_(declaration), type_(type), storage_class_(storage_class), - access_(access) {} + access_(access), + constant_value_(constant_value) {} Variable::~Variable() = default; LocalVariable::LocalVariable(const ast::Variable* declaration, const sem::Type* type, ast::StorageClass storage_class, - ast::Access access) - : Base(declaration, type, storage_class, access) {} + ast::Access access, + Constant constant_value) + : Base(declaration, + type, + storage_class, + access, + std::move(constant_value)) {} LocalVariable::~LocalVariable() = default; @@ -51,20 +58,12 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration, const sem::Type* type, ast::StorageClass storage_class, ast::Access access, + Constant constant_value, sem::BindingPoint binding_point) - : Base(declaration, type, storage_class, access), + : Base(declaration, type, storage_class, access, std::move(constant_value)), binding_point_(binding_point), is_pipeline_constant_(false) {} -GlobalVariable::GlobalVariable(const ast::Variable* declaration, - const sem::Type* type, - uint16_t constant_id) - : Base(declaration, - type, - ast::StorageClass::kNone, - ast::Access::kReadWrite), - is_pipeline_constant_(true), - constant_id_(constant_id) {} GlobalVariable::~GlobalVariable() = default; @@ -74,18 +73,16 @@ Parameter::Parameter(const ast::Variable* declaration, ast::StorageClass storage_class, ast::Access access, const ParameterUsage usage /* = ParameterUsage::kNone */) - : Base(declaration, type, storage_class, access), + : Base(declaration, type, storage_class, access, Constant{}), index_(index), usage_(usage) {} Parameter::~Parameter() = default; VariableUser::VariableUser(const ast::IdentifierExpression* declaration, - const sem::Type* type, Statement* statement, - sem::Variable* variable, - Constant constant_value) - : Base(declaration, type, statement, std::move(constant_value)), + sem::Variable* variable) + : Base(declaration, variable->Type(), statement, variable->ConstantValue()), variable_(variable) {} } // namespace sem diff --git a/src/sem/variable.h b/src/sem/variable.h index 782fda926f..5d4ac05669 100644 --- a/src/sem/variable.h +++ b/src/sem/variable.h @@ -47,10 +47,12 @@ class Variable : public Castable { /// @param type the variable type /// @param storage_class the variable storage class /// @param access the variable access control type + /// @param constant_value the constant value for the variable. May be invalid Variable(const ast::Variable* declaration, const sem::Type* type, ast::StorageClass storage_class, - ast::Access access); + ast::Access access, + Constant constant_value); /// Destructor ~Variable() override; @@ -67,6 +69,9 @@ class Variable : public Castable { /// @returns the access control for the variable ast::Access Access() const { return access_; } + /// @return the constant value of this expression + const Constant& ConstantValue() const { return constant_value_; } + /// @returns the expressions that use the variable const std::vector& Users() const { return users_; } @@ -78,6 +83,7 @@ class Variable : public Castable { const sem::Type* const type_; ast::StorageClass const storage_class_; ast::Access const access_; + const Constant constant_value_; std::vector users_; }; @@ -89,10 +95,12 @@ class LocalVariable : public Castable { /// @param type the variable type /// @param storage_class the variable storage class /// @param access the variable access control type + /// @param constant_value the constant value for the variable. May be invalid LocalVariable(const ast::Variable* declaration, const sem::Type* type, ast::StorageClass storage_class, - ast::Access access); + ast::Access access, + Constant constant_value); /// Destructor ~LocalVariable() override; @@ -101,26 +109,20 @@ class LocalVariable : public Castable { /// GlobalVariable is a module-scope variable class GlobalVariable : public Castable { public: - /// Constructor for non-overridable constants + /// Constructor /// @param declaration the AST declaration node /// @param type the variable type /// @param storage_class the variable storage class /// @param access the variable access control type + /// @param constant_value the constant value for the variable. May be invalid /// @param binding_point the optional resource binding point of the variable GlobalVariable(const ast::Variable* declaration, const sem::Type* type, ast::StorageClass storage_class, ast::Access access, + Constant constant_value, sem::BindingPoint binding_point = {}); - /// Constructor for overridable pipeline constants - /// @param declaration the AST declaration node - /// @param type the variable type - /// @param constant_id the pipeline constant ID - GlobalVariable(const ast::Variable* declaration, - const sem::Type* type, - uint16_t constant_id); - /// Destructor ~GlobalVariable() override; @@ -130,13 +132,20 @@ class GlobalVariable : public Castable { /// @returns the pipeline constant ID associated with the variable uint16_t ConstantId() const { return constant_id_; } + /// @param id the constant identifier to assign to this variable + void SetConstantId(uint16_t id) { + constant_id_ = id; + is_pipeline_constant_ = true; + } + /// @returns true if this variable is an overridable pipeline constant bool IsPipelineConstant() const { return is_pipeline_constant_; } private: - sem::BindingPoint binding_point_; - const bool is_pipeline_constant_; - uint16_t const constant_id_ = 0; + const sem::BindingPoint binding_point_; + + bool is_pipeline_constant_ = false; + uint16_t constant_id_ = 0; }; /// Parameter is a function parameter @@ -186,15 +195,11 @@ class VariableUser : public Castable { public: /// Constructor /// @param declaration the AST identifier node - /// @param type the resolved type of the expression /// @param statement the statement that owns this expression /// @param variable the semantic variable - /// @param constant_value the constant value for the variable. May be invalid VariableUser(const ast::IdentifierExpression* declaration, - const sem::Type* type, Statement* statement, - sem::Variable* variable, - Constant constant_value); + sem::Variable* variable); /// @returns the variable that this expression refers to const sem::Variable* Variable() const { return variable_; } diff --git a/src/transform/module_scope_var_to_entry_point_param.cc b/src/transform/module_scope_var_to_entry_point_param.cc index 17076eb6d6..0efb0b41c6 100644 --- a/src/transform/module_scope_var_to_entry_point_param.cc +++ b/src/transform/module_scope_var_to_entry_point_param.cc @@ -109,8 +109,8 @@ struct ModuleScopeVarToEntryPointParam::State { // Find all of the calls to this function that will need to be replaced. for (auto* call : func_sem->CallSites()) { - auto* call_sem = ctx.src->Sem().Get(call); - calls_to_replace[call_sem->Stmt()->Function()].push_back(call); + calls_to_replace[call->Stmt()->Function()->Declaration()].push_back( + call->Declaration()); } } } @@ -268,7 +268,7 @@ struct ModuleScopeVarToEntryPointParam::State { // Replace all uses of the module-scope variable. // For non-entry points, dereference non-handle pointer parameters. for (auto* user : var->Users()) { - if (user->Stmt()->Function() == func_ast) { + if (user->Stmt()->Function()->Declaration() == func_ast) { const ast::Expression* expr = ctx.dst->Expr(new_var_symbol); if (is_pointer) { // If this identifier is used by an address-of operator, just diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc index 0d3eb301f6..4dfe0c1e09 100644 --- a/src/transform/robustness_test.cc +++ b/src/transform/robustness_test.cc @@ -39,7 +39,7 @@ var a : array; let c : u32 = 1u; fn f() { - let b : f32 = a[min(c, 2u)]; + let b : f32 = a[1u]; } )"; @@ -807,7 +807,7 @@ struct S { let c : u32 = 1u; fn f() { - let b : f32 = s.b[min(c, (arrayLength(&(s.b)) - 1u))]; + let b : f32 = s.b[min(1u, (arrayLength(&(s.b)) - 1u))]; let x : i32 = min(1, 2); let y : u32 = arrayLength(&(s.b)); } diff --git a/test/bug/tint/1121.wgsl.expected.spvasm b/test/bug/tint/1121.wgsl.expected.spvasm index 6d3354b7da..e9ee84c048 100644 --- a/test/bug/tint/1121.wgsl.expected.spvasm +++ b/test/bug/tint/1121.wgsl.expected.spvasm @@ -177,6 +177,7 @@ %209 = OpConstantComposite %v2float %float_1 %float_1 %_ptr_Function_v2float = OpTypePointer Function %v2float %213 = OpConstantNull %v2float + %216 = OpConstantComposite %v2int %int_16 %int_16 %int_1 = OpConstant %int 1 %_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint %_ptr_StorageBuffer_uint_0 = OpTypePointer StorageBuffer %uint @@ -359,7 +360,6 @@ %210 = OpFSub %v2float %208 %209 OpStore %floorCoord %210 %215 = OpLoad %v2int %tilePixel0Idx - %216 = OpCompositeConstruct %v2int %int_16 %int_16 %217 = OpIAdd %v2int %215 %216 %214 = OpConvertSToF %v2float %217 %218 = OpVectorTimesScalar %v2float %214 %float_2