From 87c78ddabc62d8db0392320231cc0d6e2f554951 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 3 Feb 2021 16:43:20 +0000 Subject: [PATCH] Add semantic::Function, use it. Pull the mutable semantic fields from ast::Function and into a new semantic::Function node. Have the TypeDeterminer create these semantic::Function nodes. Bug: tint:390 Change-Id: I237b1bed8709dd9a3cfa24d85d48fc77b7e532da Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/39902 Reviewed-by: David Neto Commit-Queue: Ben Clayton --- BUILD.gn | 1 + src/CMakeLists.txt | 1 + src/ast/function.cc | 222 ---------------- src/ast/function.h | 93 +------ src/ast/function_test.cc | 108 -------- src/inspector/inspector.cc | 30 ++- src/semantic/expression.h | 2 - src/semantic/function.h | 148 +++++++++++ src/semantic/sem_function.cc | 237 ++++++++++++++++++ src/semantic/sem_function_test.cc | 107 ++++++++ src/semantic/test_helper.h | 39 +++ src/semantic/type_mappings.h | 3 + src/transform/first_index_offset.cc | 4 +- src/type_determiner.cc | 48 +++- src/type_determiner.h | 42 +++- src/type_determiner_test.cc | 37 ++- src/writer/hlsl/generator_impl.cc | 60 +++-- src/writer/hlsl/generator_impl.h | 6 +- src/writer/msl/generator_impl.cc | 46 ++-- src/writer/spirv/builder.cc | 7 +- .../spirv/builder_function_decoration_test.cc | 9 +- src/writer/wgsl/generator_impl.cc | 6 +- 22 files changed, 744 insertions(+), 512 deletions(-) create mode 100644 src/semantic/function.h create mode 100644 src/semantic/sem_function.cc create mode 100644 src/semantic/sem_function_test.cc create mode 100644 src/semantic/test_helper.h diff --git a/BUILD.gn b/BUILD.gn index 23d89b1c07..8e15afc8b5 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -382,6 +382,7 @@ source_set("libtint_core_src") { "src/semantic/info.h", "src/semantic/node.h", "src/semantic/sem_expression.cc", + "src/semantic/sem_function.cc", "src/semantic/sem_info.cc", "src/semantic/sem_node.cc", "src/semantic/type_mappings.h", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2ab4a49695..56a1efb9da 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -196,6 +196,7 @@ set(TINT_LIB_SRCS semantic/info.h semantic/node.h semantic/sem_expression.cc + semantic/sem_function.cc semantic/sem_info.cc semantic/sem_node.cc semantic/type_mappings.h diff --git a/src/ast/function.cc b/src/ast/function.cc index f17421a19b..a2b5ef7e91 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -65,161 +65,6 @@ PipelineStage Function::pipeline_stage() const { return PipelineStage::kNone; } -void Function::add_referenced_module_variable(Variable* var) { - for (const auto* v : referenced_module_vars_) { - if (v->symbol() == var->symbol()) { - return; - } - } - referenced_module_vars_.push_back(var); -} - -void Function::add_local_referenced_module_variable(Variable* var) { - for (const auto* v : local_referenced_module_vars_) { - if (v->symbol() == var->symbol()) { - return; - } - } - local_referenced_module_vars_.push_back(var); -} - -const std::vector> -Function::referenced_location_variables() const { - std::vector> ret; - - for (auto* var : referenced_module_variables()) { - for (auto* deco : var->decorations()) { - if (auto* location = deco->As()) { - ret.push_back({var, location}); - break; - } - } - } - return ret; -} - -const std::vector> -Function::referenced_uniform_variables() const { - std::vector> ret; - - for (auto* var : referenced_module_variables()) { - if (var->storage_class() != StorageClass::kUniform) { - continue; - } - - BindingDecoration* binding = nullptr; - GroupDecoration* group = nullptr; - for (auto* deco : var->decorations()) { - if (auto* b = deco->As()) { - binding = b; - } else if (auto* g = deco->As()) { - group = g; - } - } - if (binding == nullptr || group == nullptr) { - continue; - } - - ret.push_back({var, BindingInfo{binding, group}}); - } - return ret; -} - -const std::vector> -Function::referenced_storagebuffer_variables() const { - std::vector> ret; - - for (auto* var : referenced_module_variables()) { - if (var->storage_class() != StorageClass::kStorage) { - continue; - } - - BindingDecoration* binding = nullptr; - GroupDecoration* group = nullptr; - for (auto* deco : var->decorations()) { - if (auto* b = deco->As()) { - binding = b; - } else if (auto* s = deco->As()) { - group = s; - } - } - if (binding == nullptr || group == nullptr) { - continue; - } - - ret.push_back({var, BindingInfo{binding, group}}); - } - return ret; -} - -const std::vector> -Function::referenced_builtin_variables() const { - std::vector> ret; - - for (auto* var : referenced_module_variables()) { - for (auto* deco : var->decorations()) { - if (auto* builtin = deco->As()) { - ret.push_back({var, builtin}); - break; - } - } - } - return ret; -} - -const std::vector> -Function::referenced_sampler_variables() const { - return ReferencedSamplerVariablesImpl(type::SamplerKind::kSampler); -} - -const std::vector> -Function::referenced_comparison_sampler_variables() const { - return ReferencedSamplerVariablesImpl(type::SamplerKind::kComparisonSampler); -} - -const std::vector> -Function::referenced_sampled_texture_variables() const { - return ReferencedSampledTextureVariablesImpl(false); -} - -const std::vector> -Function::referenced_multisampled_texture_variables() const { - return ReferencedSampledTextureVariablesImpl(true); -} - -const std::vector> -Function::local_referenced_builtin_variables() const { - std::vector> ret; - - for (auto* var : local_referenced_module_variables()) { - for (auto* deco : var->decorations()) { - if (auto* builtin = deco->As()) { - ret.push_back({var, builtin}); - break; - } - } - } - return ret; -} - -void Function::add_ancestor_entry_point(Symbol ep) { - for (const auto& point : ancestor_entry_points_) { - if (point == ep) { - return; - } - } - ancestor_entry_points_.push_back(ep); -} - -bool Function::HasAncestorEntryPoint(Symbol symbol) const { - for (const auto& point : ancestor_entry_points_) { - if (point == symbol) { - return true; - } - } - return false; -} - const Statement* Function::get_last_statement() const { return body_->last(); } @@ -295,73 +140,6 @@ std::string Function::type_name() const { return out.str(); } -const std::vector> -Function::ReferencedSamplerVariablesImpl(type::SamplerKind kind) const { - std::vector> ret; - - for (auto* var : referenced_module_variables()) { - auto* unwrapped_type = var->type()->UnwrapIfNeeded(); - auto* sampler = unwrapped_type->As(); - if (sampler == nullptr || sampler->kind() != kind) { - continue; - } - - BindingDecoration* binding = nullptr; - GroupDecoration* group = nullptr; - for (auto* deco : var->decorations()) { - if (auto* b = deco->As()) { - binding = b; - } - if (auto* s = deco->As()) { - group = s; - } - } - if (binding == nullptr || group == nullptr) { - continue; - } - - ret.push_back({var, BindingInfo{binding, group}}); - } - return ret; -} - -const std::vector> -Function::ReferencedSampledTextureVariablesImpl(bool multisampled) const { - std::vector> ret; - - for (auto* var : referenced_module_variables()) { - auto* unwrapped_type = var->type()->UnwrapIfNeeded(); - auto* texture = unwrapped_type->As(); - if (texture == nullptr) { - continue; - } - - auto is_multisampled = texture->Is(); - auto is_sampled = texture->Is(); - - if ((multisampled && !is_multisampled) || (!multisampled && !is_sampled)) { - continue; - } - - BindingDecoration* binding = nullptr; - GroupDecoration* group = nullptr; - for (auto* deco : var->decorations()) { - if (auto* b = deco->As()) { - binding = b; - } else if (auto* s = deco->As()) { - group = s; - } - } - if (binding == nullptr || group == nullptr) { - continue; - } - - ret.push_back({var, BindingInfo{binding, group}}); - } - - return ret; -} - Function* FunctionList::Find(Symbol sym) const { for (auto* func : *this) { if (func->symbol() == sym) { diff --git a/src/ast/function.h b/src/ast/function.h index ff72b14b38..02c6cbafd0 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -43,14 +43,6 @@ namespace ast { /// A Function statement. class Function : public Castable { public: - /// Information about a binding - struct BindingInfo { - /// The binding decoration - BindingDecoration* binding = nullptr; - /// The group decoration - GroupDecoration* group = nullptr; - }; - /// Create a function /// @param source the variable source /// @param symbol the function symbol @@ -87,82 +79,9 @@ class Function : public Castable { /// @returns true if this function is an entry point bool IsEntryPoint() const { return pipeline_stage() != PipelineStage::kNone; } - /// Adds the given variable to the list of referenced module variables if it - /// is not already included. - /// @param var the module variable to add - void add_referenced_module_variable(Variable* var); - /// Adds the given variable to the list of locally referenced module variables - /// if it is not already included. - /// @param var the module variable to add - void add_local_referenced_module_variable(Variable* var); - /// Note: If this function calls other functions, the return will also include - /// all of the referenced variables from the callees. - /// @returns the referenced module variables - const std::vector& referenced_module_variables() const { - return referenced_module_vars_; - } - /// @returns the locally referenced module variables - const std::vector& local_referenced_module_variables() const { - return local_referenced_module_vars_; - } - /// Retrieves any referenced location variables - /// @returns the pair. - const std::vector> - referenced_location_variables() const; - /// Retrieves any referenced builtin variables - /// @returns the pair. - const std::vector> - referenced_builtin_variables() const; - /// Retrieves any referenced uniform variables. Note, the variables must be - /// decorated with both binding and group decorations. - /// @returns the referenced uniforms - const std::vector> - referenced_uniform_variables() const; - /// Retrieves any referenced storagebuffer variables. Note, the variables - /// must be decorated with both binding and group decorations. - /// @returns the referenced storagebuffers - const std::vector> - referenced_storagebuffer_variables() const; - /// Retrieves any referenced regular Sampler variables. Note, the - /// variables must be decorated with both binding and group decorations. - /// @returns the referenced storagebuffers - const std::vector> - referenced_sampler_variables() const; - /// Retrieves any referenced comparison Sampler variables. Note, the - /// variables must be decorated with both binding and group decorations. - /// @returns the referenced storagebuffers - const std::vector> - referenced_comparison_sampler_variables() const; - /// Retrieves any referenced sampled textures variables. Note, the - /// variables must be decorated with both binding and group decorations. - /// @returns the referenced sampled textures - const std::vector> - referenced_sampled_texture_variables() const; - /// Retrieves any referenced multisampled textures variables. Note, the - /// variables must be decorated with both binding and group decorations. - /// @returns the referenced sampled textures - const std::vector> - referenced_multisampled_texture_variables() const; - - /// Retrieves any locally referenced builtin variables - /// @returns the pairs. - const std::vector> - local_referenced_builtin_variables() const; - - /// Adds an ancestor entry point - /// @param ep the entry point ancestor - void add_ancestor_entry_point(Symbol ep); - /// @returns the ancestor entry points - const std::vector& ancestor_entry_points() const { - return ancestor_entry_points_; - } - /// Checks if the given entry point is an ancestor - /// @param sym the entry point symbol - /// @returns true if `sym` is an ancestor entry point of this function - bool HasAncestorEntryPoint(Symbol sym) const; - /// @returns the function return type. type::Type* return_type() const { return return_type_; } + /// @returns a pointer to the last statement of the function or nullptr if // function is empty const Statement* get_last_statement() const; @@ -196,20 +115,12 @@ class Function : public Castable { private: Function(const Function&) = delete; - const std::vector> - ReferencedSamplerVariablesImpl(type::SamplerKind kind) const; - const std::vector> - ReferencedSampledTextureVariablesImpl(bool multisampled) const; Symbol const symbol_; VariableList const params_; type::Type* const return_type_; BlockStatement* const body_; - - std::vector referenced_module_vars_; // Semantic info - std::vector local_referenced_module_vars_; // Semantic info - std::vector ancestor_entry_points_; // Semantic info - FunctionDecorationList decorations_; // Semantic info + FunctionDecorationList const decorations_; }; /// A list of functions diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index f148bedf44..a74cb6aa0c 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -53,114 +53,6 @@ TEST_F(FunctionTest, Creation_WithSource) { EXPECT_EQ(src.range.begin.column, 2u); } -TEST_F(FunctionTest, AddDuplicateReferencedVariables) { - auto* v = Var("var", StorageClass::kInput, ty.i32()); - auto* f = Func("func", VariableList{}, ty.void_(), StatementList{}, - FunctionDecorationList{}); - - f->add_referenced_module_variable(v); - ASSERT_EQ(f->referenced_module_variables().size(), 1u); - EXPECT_EQ(f->referenced_module_variables()[0], v); - - f->add_referenced_module_variable(v); - ASSERT_EQ(f->referenced_module_variables().size(), 1u); - - auto* v2 = Var("var2", StorageClass::kOutput, ty.i32()); - f->add_referenced_module_variable(v2); - ASSERT_EQ(f->referenced_module_variables().size(), 2u); - EXPECT_EQ(f->referenced_module_variables()[1], v2); -} - -TEST_F(FunctionTest, GetReferenceLocations) { - auto* loc1 = Var("loc1", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(0), - }); - - auto* loc2 = Var("loc2", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(1), - }); - - auto* builtin1 = Var("builtin1", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(Builtin::kPosition), - }); - - auto* builtin2 = Var("builtin2", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(Builtin::kFragDepth), - }); - - auto* f = Func("func", VariableList{}, ty.void_(), StatementList{}, - FunctionDecorationList{}); - - f->add_referenced_module_variable(loc1); - f->add_referenced_module_variable(builtin1); - f->add_referenced_module_variable(loc2); - f->add_referenced_module_variable(builtin2); - ASSERT_EQ(f->referenced_module_variables().size(), 4u); - - auto ref_locs = f->referenced_location_variables(); - ASSERT_EQ(ref_locs.size(), 2u); - EXPECT_EQ(ref_locs[0].first, loc1); - EXPECT_EQ(ref_locs[0].second->value(), 0u); - EXPECT_EQ(ref_locs[1].first, loc2); - EXPECT_EQ(ref_locs[1].second->value(), 1u); -} - -TEST_F(FunctionTest, GetReferenceBuiltins) { - auto* loc1 = Var("loc1", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(0), - }); - - auto* loc2 = Var("loc2", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(1), - }); - - auto* builtin1 = Var("builtin1", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(Builtin::kPosition), - }); - - auto* builtin2 = Var("builtin2", StorageClass::kInput, ty.i32(), nullptr, - ast::VariableDecorationList{ - create(Builtin::kFragDepth), - }); - - auto* f = Func("func", VariableList{}, ty.void_(), StatementList{}, - FunctionDecorationList{}); - - f->add_referenced_module_variable(loc1); - f->add_referenced_module_variable(builtin1); - f->add_referenced_module_variable(loc2); - f->add_referenced_module_variable(builtin2); - ASSERT_EQ(f->referenced_module_variables().size(), 4u); - - auto ref_locs = f->referenced_builtin_variables(); - ASSERT_EQ(ref_locs.size(), 2u); - EXPECT_EQ(ref_locs[0].first, builtin1); - EXPECT_EQ(ref_locs[0].second->value(), Builtin::kPosition); - EXPECT_EQ(ref_locs[1].first, builtin2); - EXPECT_EQ(ref_locs[1].second->value(), Builtin::kFragDepth); -} - -TEST_F(FunctionTest, AddDuplicateEntryPoints) { - auto* f = Func("func", VariableList{}, ty.void_(), StatementList{}, - FunctionDecorationList{}); - - auto main_sym = Symbols().Get("main"); - f->add_ancestor_entry_point(main_sym); - ASSERT_EQ(1u, f->ancestor_entry_points().size()); - EXPECT_EQ(main_sym, f->ancestor_entry_points()[0]); - - f->add_ancestor_entry_point(main_sym); - ASSERT_EQ(1u, f->ancestor_entry_points().size()); - EXPECT_EQ(main_sym, f->ancestor_entry_points()[0]); -} - TEST_F(FunctionTest, IsValid) { VariableList params; params.push_back(Var("var", StorageClass::kNone, ty.i32())); diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index 3c2e7a46f9..ccfe471bc9 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -29,6 +29,7 @@ #include "src/ast/uint_literal.h" #include "src/ast/variable.h" #include "src/program.h" +#include "src/semantic/function.h" #include "src/type/access_control_type.h" #include "src/type/array_type.h" #include "src/type/f32_type.h" @@ -64,7 +65,7 @@ std::vector Inspector::GetEntryPoints() { std::tie(entry_point.workgroup_size_x, entry_point.workgroup_size_y, entry_point.workgroup_size_z) = func->workgroup_size(); - for (auto* var : func->referenced_module_variables()) { + for (auto* var : program_->Sem().Get(func)->ReferencedModuleVariables()) { auto name = program_->Symbols().NameFor(var->symbol()); if (var->HasBuiltinDecoration()) { continue; @@ -185,10 +186,11 @@ std::vector Inspector::GetUniformBufferResourceBindings( std::vector result; - for (auto& ruv : func->referenced_uniform_variables()) { + auto* func_sem = program_->Sem().Get(func); + for (auto& ruv : func_sem->ReferencedUniformVariables()) { ResourceBinding entry; ast::Variable* var = nullptr; - ast::Function::BindingInfo binding_info; + semantic::Function::BindingInfo binding_info; std::tie(var, binding_info) = ruv; if (!var->type()->Is()) { continue; @@ -235,10 +237,11 @@ std::vector Inspector::GetSamplerResourceBindings( std::vector result; - for (auto& rs : func->referenced_sampler_variables()) { + auto* func_sem = program_->Sem().Get(func); + for (auto& rs : func_sem->ReferencedSamplerVariables()) { ResourceBinding entry; ast::Variable* var = nullptr; - ast::Function::BindingInfo binding_info; + semantic::Function::BindingInfo binding_info; std::tie(var, binding_info) = rs; entry.bind_group = binding_info.group->value(); @@ -259,10 +262,11 @@ std::vector Inspector::GetComparisonSamplerResourceBindings( std::vector result; - for (auto& rcs : func->referenced_comparison_sampler_variables()) { + auto* func_sem = program_->Sem().Get(func); + for (auto& rcs : func_sem->ReferencedComparisonSamplerVariables()) { ResourceBinding entry; ast::Variable* var = nullptr; - ast::Function::BindingInfo binding_info; + semantic::Function::BindingInfo binding_info; std::tie(var, binding_info) = rcs; entry.bind_group = binding_info.group->value(); @@ -307,11 +311,12 @@ std::vector Inspector::GetStorageBufferResourceBindingsImpl( return {}; } + auto* func_sem = program_->Sem().Get(func); std::vector result; - for (auto& rsv : func->referenced_storagebuffer_variables()) { + for (auto& rsv : func_sem->ReferencedStoragebufferVariables()) { ResourceBinding entry; ast::Variable* var = nullptr; - ast::Function::BindingInfo binding_info; + semantic::Function::BindingInfo binding_info; std::tie(var, binding_info) = rsv; auto* ac_type = var->type()->As(); @@ -347,13 +352,14 @@ std::vector Inspector::GetSampledTextureResourceBindingsImpl( } std::vector result; + auto* func_sem = program_->Sem().Get(func); auto& referenced_variables = - multisampled_only ? func->referenced_multisampled_texture_variables() - : func->referenced_sampled_texture_variables(); + multisampled_only ? func_sem->ReferencedMultisampledTextureVariables() + : func_sem->ReferencedSampledTextureVariables(); for (auto& ref : referenced_variables) { ResourceBinding entry; ast::Variable* var = nullptr; - ast::Function::BindingInfo binding_info; + semantic::Function::BindingInfo binding_info; std::tie(var, binding_info) = ref; entry.bind_group = binding_info.group->value(); diff --git a/src/semantic/expression.h b/src/semantic/expression.h index 40b15e4c28..76052be37e 100644 --- a/src/semantic/expression.h +++ b/src/semantic/expression.h @@ -17,8 +17,6 @@ #include "src/semantic/node.h" -#include "src/semantic/type_mappings.h" - namespace tint { // Forward declarations diff --git a/src/semantic/function.h b/src/semantic/function.h new file mode 100644 index 0000000000..0c751138b9 --- /dev/null +++ b/src/semantic/function.h @@ -0,0 +1,148 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0(the "License"); + +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_SEMANTIC_FUNCTION_H_ +#define SRC_SEMANTIC_FUNCTION_H_ + +#include +#include + +#include "src/semantic/node.h" +#include "src/type/sampler_type.h" + +namespace tint { + +// Forward declarations +namespace ast { +class BindingDecoration; +class GroupDecoration; +class Variable; +class LocationDecoration; +class BuiltinDecoration; +} // namespace ast +namespace type { +class Type; +} // namespace type + +namespace semantic { + +/// Function holds the semantic information for function nodes. +class Function : public Castable { + public: + /// Information about a binding + struct BindingInfo { + /// The binding decoration + ast::BindingDecoration* binding = nullptr; + /// The group decoration + ast::GroupDecoration* group = nullptr; + }; + + /// Constructor + /// @param referenced_module_vars the referenced module variables + /// @param local_referenced_module_vars the locally referenced module + /// variables + /// @param ancestor_entry_points the ancestor entry points + explicit Function(std::vector referenced_module_vars, + std::vector local_referenced_module_vars, + std::vector ancestor_entry_points); + + /// Destructor + ~Function() override; + + /// Note: If this function calls other functions, the return will also include + /// all of the referenced variables from the callees. + /// @returns the referenced module variables + const std::vector& ReferencedModuleVariables() const { + return referenced_module_vars_; + } + /// @returns the locally referenced module variables + const std::vector& LocalReferencedModuleVariables() const { + return local_referenced_module_vars_; + } + /// @returns the ancestor entry points + const std::vector& AncestorEntryPoints() const { + return ancestor_entry_points_; + } + /// Retrieves any referenced location variables + /// @returns the pair. + const std::vector> + ReferencedLocationVariables() const; + + /// Retrieves any referenced builtin variables + /// @returns the pair. + const std::vector> + ReferencedBuiltinVariables() const; + + /// Retrieves any referenced uniform variables. Note, the variables must be + /// decorated with both binding and group decorations. + /// @returns the referenced uniforms + const std::vector> + ReferencedUniformVariables() const; + + /// Retrieves any referenced storagebuffer variables. Note, the variables + /// must be decorated with both binding and group decorations. + /// @returns the referenced storagebuffers + const std::vector> + ReferencedStoragebufferVariables() const; + + /// Retrieves any referenced regular Sampler variables. Note, the + /// variables must be decorated with both binding and group decorations. + /// @returns the referenced storagebuffers + const std::vector> + ReferencedSamplerVariables() const; + + /// Retrieves any referenced comparison Sampler variables. Note, the + /// variables must be decorated with both binding and group decorations. + /// @returns the referenced storagebuffers + const std::vector> + ReferencedComparisonSamplerVariables() const; + + /// Retrieves any referenced sampled textures variables. Note, the + /// variables must be decorated with both binding and group decorations. + /// @returns the referenced sampled textures + const std::vector> + ReferencedSampledTextureVariables() const; + + /// Retrieves any referenced multisampled textures variables. Note, the + /// variables must be decorated with both binding and group decorations. + /// @returns the referenced sampled textures + const std::vector> + ReferencedMultisampledTextureVariables() const; + + /// Retrieves any locally referenced builtin variables + /// @returns the pairs. + const std::vector> + LocalReferencedBuiltinVariables() const; + + /// Checks if the given entry point is an ancestor + /// @param sym the entry point symbol + /// @returns true if `sym` is an ancestor entry point of this function + bool HasAncestorEntryPoint(Symbol sym) const; + + private: + const std::vector> + ReferencedSamplerVariablesImpl(type::SamplerKind kind) const; + const std::vector> + ReferencedSampledTextureVariablesImpl(bool multisampled) const; + + std::vector const referenced_module_vars_; + std::vector const local_referenced_module_vars_; + std::vector const ancestor_entry_points_; +}; + +} // namespace semantic +} // namespace tint + +#endif // SRC_SEMANTIC_FUNCTION_H_ diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc new file mode 100644 index 0000000000..561369702f --- /dev/null +++ b/src/semantic/sem_function.cc @@ -0,0 +1,237 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/semantic/function.h" + +#include "src/ast/binding_decoration.h" +#include "src/ast/builtin_decoration.h" +#include "src/ast/group_decoration.h" +#include "src/ast/location_decoration.h" +#include "src/ast/variable.h" +#include "src/ast/variable_decoration.h" +#include "src/type/multisampled_texture_type.h" +#include "src/type/sampled_texture_type.h" +#include "src/type/texture_type.h" + +TINT_INSTANTIATE_CLASS_ID(tint::semantic::Function); + +namespace tint { +namespace semantic { + +Function::Function(std::vector referenced_module_vars, + std::vector local_referenced_module_vars, + std::vector ancestor_entry_points) + : referenced_module_vars_(std::move(referenced_module_vars)), + local_referenced_module_vars_(std::move(local_referenced_module_vars)), + ancestor_entry_points_(std::move(ancestor_entry_points)) {} + +Function::~Function() = default; + +const std::vector> +Function::ReferencedLocationVariables() const { + std::vector> ret; + + for (auto* var : ReferencedModuleVariables()) { + for (auto* deco : var->decorations()) { + if (auto* location = deco->As()) { + ret.push_back({var, location}); + break; + } + } + } + return ret; +} + +const std::vector> +Function::ReferencedUniformVariables() const { + std::vector> ret; + + for (auto* var : ReferencedModuleVariables()) { + if (var->storage_class() != ast::StorageClass::kUniform) { + continue; + } + + ast::BindingDecoration* binding = nullptr; + ast::GroupDecoration* group = nullptr; + for (auto* deco : var->decorations()) { + if (auto* b = deco->As()) { + binding = b; + } else if (auto* g = deco->As()) { + group = g; + } + } + if (binding == nullptr || group == nullptr) { + continue; + } + + ret.push_back({var, BindingInfo{binding, group}}); + } + return ret; +} + +const std::vector> +Function::ReferencedStoragebufferVariables() const { + std::vector> ret; + + for (auto* var : ReferencedModuleVariables()) { + if (var->storage_class() != ast::StorageClass::kStorage) { + continue; + } + + ast::BindingDecoration* binding = nullptr; + ast::GroupDecoration* group = nullptr; + for (auto* deco : var->decorations()) { + if (auto* b = deco->As()) { + binding = b; + } else if (auto* s = deco->As()) { + group = s; + } + } + if (binding == nullptr || group == nullptr) { + continue; + } + + ret.push_back({var, BindingInfo{binding, group}}); + } + return ret; +} + +const std::vector> +Function::ReferencedBuiltinVariables() const { + std::vector> ret; + + for (auto* var : ReferencedModuleVariables()) { + for (auto* deco : var->decorations()) { + if (auto* builtin = deco->As()) { + ret.push_back({var, builtin}); + break; + } + } + } + return ret; +} + +const std::vector> +Function::ReferencedSamplerVariables() const { + return ReferencedSamplerVariablesImpl(type::SamplerKind::kSampler); +} + +const std::vector> +Function::ReferencedComparisonSamplerVariables() const { + return ReferencedSamplerVariablesImpl(type::SamplerKind::kComparisonSampler); +} + +const std::vector> +Function::ReferencedSampledTextureVariables() const { + return ReferencedSampledTextureVariablesImpl(false); +} + +const std::vector> +Function::ReferencedMultisampledTextureVariables() const { + return ReferencedSampledTextureVariablesImpl(true); +} + +const std::vector> +Function::LocalReferencedBuiltinVariables() const { + std::vector> ret; + + for (auto* var : LocalReferencedModuleVariables()) { + for (auto* deco : var->decorations()) { + if (auto* builtin = deco->As()) { + ret.push_back({var, builtin}); + break; + } + } + } + return ret; +} + +bool Function::HasAncestorEntryPoint(Symbol symbol) const { + for (const auto& point : ancestor_entry_points_) { + if (point == symbol) { + return true; + } + } + return false; +} + +const std::vector> +Function::ReferencedSamplerVariablesImpl(type::SamplerKind kind) const { + std::vector> ret; + + for (auto* var : ReferencedModuleVariables()) { + auto* unwrapped_type = var->type()->UnwrapIfNeeded(); + auto* sampler = unwrapped_type->As(); + if (sampler == nullptr || sampler->kind() != kind) { + continue; + } + + ast::BindingDecoration* binding = nullptr; + ast::GroupDecoration* group = nullptr; + for (auto* deco : var->decorations()) { + if (auto* b = deco->As()) { + binding = b; + } + if (auto* s = deco->As()) { + group = s; + } + } + if (binding == nullptr || group == nullptr) { + continue; + } + + ret.push_back({var, BindingInfo{binding, group}}); + } + return ret; +} + +const std::vector> +Function::ReferencedSampledTextureVariablesImpl(bool multisampled) const { + std::vector> ret; + + for (auto* var : ReferencedModuleVariables()) { + auto* unwrapped_type = var->type()->UnwrapIfNeeded(); + auto* texture = unwrapped_type->As(); + if (texture == nullptr) { + continue; + } + + auto is_multisampled = texture->Is(); + auto is_sampled = texture->Is(); + + if ((multisampled && !is_multisampled) || (!multisampled && !is_sampled)) { + continue; + } + + ast::BindingDecoration* binding = nullptr; + ast::GroupDecoration* group = nullptr; + for (auto* deco : var->decorations()) { + if (auto* b = deco->As()) { + binding = b; + } else if (auto* s = deco->As()) { + group = s; + } + } + if (binding == nullptr || group == nullptr) { + continue; + } + + ret.push_back({var, BindingInfo{binding, group}}); + } + + return ret; +} + +} // namespace semantic +} // namespace tint diff --git a/src/semantic/sem_function_test.cc b/src/semantic/sem_function_test.cc new file mode 100644 index 0000000000..49eb8b8423 --- /dev/null +++ b/src/semantic/sem_function_test.cc @@ -0,0 +1,107 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/semantic/function.h" + +#include "src/ast/builtin_decoration.h" +#include "src/ast/location_decoration.h" +#include "src/semantic/test_helper.h" + +namespace tint { +namespace semantic { +namespace { + +using FunctionTest = TestHelper; + +TEST_F(FunctionTest, GetReferenceLocations) { + auto* loc1 = Var("loc1", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(0), + }); + + auto* loc2 = Var("loc2", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(1), + }); + + auto* builtin1 = + Var("builtin1", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(ast::Builtin::kPosition), + }); + + auto* builtin2 = + Var("builtin2", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(ast::Builtin::kFragDepth), + }); + + auto* f = create( + /* referenced_module_vars */ std::vector{loc1, builtin1, + loc2, builtin2}, + /* local_referenced_module_vars */ std::vector{}, + /* ancestor_entry_points */ std::vector{}); + + ASSERT_EQ(f->ReferencedModuleVariables().size(), 4u); + + auto ref_locs = f->ReferencedLocationVariables(); + ASSERT_EQ(ref_locs.size(), 2u); + EXPECT_EQ(ref_locs[0].first, loc1); + EXPECT_EQ(ref_locs[0].second->value(), 0u); + EXPECT_EQ(ref_locs[1].first, loc2); + EXPECT_EQ(ref_locs[1].second->value(), 1u); +} + +TEST_F(FunctionTest, GetReferenceBuiltins) { + auto* loc1 = Var("loc1", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(0), + }); + + auto* loc2 = Var("loc2", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(1), + }); + + auto* builtin1 = + Var("builtin1", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(ast::Builtin::kPosition), + }); + + auto* builtin2 = + Var("builtin2", ast::StorageClass::kInput, ty.i32(), nullptr, + ast::VariableDecorationList{ + create(ast::Builtin::kFragDepth), + }); + + auto* f = create( + /* referenced_module_vars */ std::vector{loc1, builtin1, + loc2, builtin2}, + /* local_referenced_module_vars */ std::vector{}, + /* ancestor_entry_points */ std::vector{}); + + ASSERT_EQ(f->ReferencedModuleVariables().size(), 4u); + + auto ref_locs = f->ReferencedBuiltinVariables(); + ASSERT_EQ(ref_locs.size(), 2u); + EXPECT_EQ(ref_locs[0].first, builtin1); + EXPECT_EQ(ref_locs[0].second->value(), ast::Builtin::kPosition); + EXPECT_EQ(ref_locs[1].first, builtin2); + EXPECT_EQ(ref_locs[1].second->value(), ast::Builtin::kFragDepth); +} + +} // namespace +} // namespace semantic +} // namespace tint diff --git a/src/semantic/test_helper.h b/src/semantic/test_helper.h new file mode 100644 index 0000000000..0a32422781 --- /dev/null +++ b/src/semantic/test_helper.h @@ -0,0 +1,39 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_SEMANTIC_TEST_HELPER_H_ +#define SRC_SEMANTIC_TEST_HELPER_H_ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "src/program_builder.h" + +namespace tint { +namespace semantic { + +/// Helper class for testing +template +class TestHelperBase : public BASE, public ProgramBuilder {}; +using TestHelper = TestHelperBase; + +template +using TestParamHelper = TestHelperBase>; + +} // namespace semantic +} // namespace tint + +#endif // SRC_SEMANTIC_TEST_HELPER_H_ diff --git a/src/semantic/type_mappings.h b/src/semantic/type_mappings.h index 4c59bd3163..ac099a8c59 100644 --- a/src/semantic/type_mappings.h +++ b/src/semantic/type_mappings.h @@ -23,12 +23,14 @@ namespace tint { namespace ast { class Expression; +class Function; } // namespace ast namespace semantic { class Expression; +class Function; /// TypeMappings is a struct that holds dummy `operator()` methods that's used /// by SemanticNodeTypeFor to map AST node types to their corresponding semantic @@ -38,6 +40,7 @@ class Expression; struct TypeMappings { //! @cond Doxygen_Suppress semantic::Expression* operator()(ast::Expression*); + semantic::Function* operator()(ast::Function*); //! @endcond }; diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 64b25dce5d..fec1068960 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -49,6 +49,7 @@ #include "src/clone_context.h" #include "src/program.h" #include "src/program_builder.h" +#include "src/semantic/function.h" #include "src/type/struct_type.h" #include "src/type/u32_type.h" #include "src/type_determiner.h" @@ -143,9 +144,10 @@ Transform::Output FirstIndexOffset::Run(const Program* in) { if (buffer_var == nullptr) { return nullptr; // no transform need, just clone func } + auto* func_sem = in->Sem().Get(func); ast::StatementList statements; for (const auto& data : - func->local_referenced_builtin_variables()) { + func_sem->LocalReferencedBuiltinVariables()) { if (data.second->value() == ast::Builtin::kVertexIndex) { statements.emplace_back(CreateFirstIndexOffset( in->Symbols().NameFor(vertex_index_sym), kFirstVertexName, diff --git a/src/type_determiner.cc b/src/type_determiner.cc index d31cfc190e..61723843ab 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -44,6 +44,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/program_builder.h" #include "src/semantic/expression.h" +#include "src/semantic/function.h" #include "src/type/array_type.h" #include "src/type/bool_type.h" #include "src/type/depth_texture_type.h" @@ -97,9 +98,9 @@ void TypeDeterminer::set_referenced_from_function_if_needed(ast::Variable* var, return; } - current_function_->add_referenced_module_variable(var); + current_function_->referenced_module_vars.Add(var); if (local) { - current_function_->add_local_referenced_module_variable(var); + current_function_->local_referenced_module_vars.Add(var); } } @@ -145,11 +146,14 @@ bool TypeDeterminer::Determine() { } } + CreateSemanticFunctions(); + return true; } void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) { - symbol_to_function_[fn_sym]->add_ancestor_entry_point(ep_sym); + auto* info = symbol_to_function_.at(fn_sym); + info->ancestor_entry_points.Add(ep_sym); for (const auto& callee : caller_to_callee_[fn_sym]) { set_entry_points(callee, ep_sym); @@ -166,9 +170,11 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) { } bool TypeDeterminer::DetermineFunction(ast::Function* func) { - symbol_to_function_[func->symbol()] = func; + auto* info = function_infos_.Create(func); + symbol_to_function_[func->symbol()] = info; + function_to_info_.emplace(func, info); - current_function_ = func; + current_function_ = info; variable_stack_.push_scope(); for (auto* param : func->params()) { @@ -409,19 +415,20 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) { } } else { if (current_function_) { - caller_to_callee_[current_function_->symbol()].push_back( + caller_to_callee_[current_function_->declaration->symbol()].push_back( ident->symbol()); - auto* callee_func = builder_->AST().Functions().Find(ident->symbol()); - if (callee_func == nullptr) { + auto callee_func_it = symbol_to_function_.find(ident->symbol()); + if (callee_func_it == symbol_to_function_.end()) { set_error(expr->source(), "unable to find called function: " + builder_->Symbols().NameFor(ident->symbol())); return false; } + auto* callee_func = callee_func_it->second; // We inherit any referenced variables from the callee. - for (auto* var : callee_func->referenced_module_variables()) { + for (auto* var : callee_func->referenced_module_vars) { set_referenced_from_function_if_needed(var, false); } } @@ -828,7 +835,7 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) { auto iter = symbol_to_function_.find(symbol); if (iter != symbol_to_function_.end()) { - SetType(expr, iter->second->return_type()); + SetType(expr, iter->second->declaration->return_type()); return true; } @@ -1204,4 +1211,25 @@ void TypeDeterminer::SetType(ast::Expression* expr, type::Type* type) const { builder_->create(type)); } +void TypeDeterminer::CreateSemanticFunctions() const { + for (auto it : function_to_info_) { + auto* func = it.first; + auto* info = it.second; + if (builder_->Sem().Get(func)) { + // ast::Function already has a semantic::Function node. + // This is likely via explicit call to DetermineXXX() in test. + continue; + } + builder_->Sem().Add(func, builder_->create( + info->referenced_module_vars, + info->local_referenced_module_vars, + info->ancestor_entry_points)); + } +} + +TypeDeterminer::FunctionInfo::FunctionInfo(ast::Function* decl) + : declaration(decl) {} + +TypeDeterminer::FunctionInfo::~FunctionInfo() = default; + } // namespace tint diff --git a/src/type_determiner.h b/src/type_determiner.h index 663b5886ea..6f7e3f38f8 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -17,6 +17,7 @@ #include #include +#include #include #include "src/ast/module.h" @@ -98,6 +99,10 @@ class TypeDeterminer { /// @returns false on error bool DetermineStorageTextureSubtype(type::StorageTexture* tex); + /// Creates the semantic::Function nodes and adds them to the semantic::Info + /// of the ProgramBuilder. + void CreateSemanticFunctions() const; + /// Testing method to set a given variable into the type stack /// @param var the variable to set void RegisterVariableForTesting(ast::Variable* var) { @@ -123,6 +128,37 @@ class TypeDeterminer { bool SetIntrinsicIfNeeded(ast::IdentifierExpression* ident); private: + template + struct UniqueVector { + using ConstIterator = typename std::vector::const_iterator; + + void Add(const T& val) { + if (set.count(val) == 0) { + vector.emplace_back(val); + set.emplace(val); + } + } + ConstIterator begin() const { return vector.begin(); } + ConstIterator end() const { return vector.end(); } + operator const std::vector &() const { return vector; } + + private: + std::vector vector; + std::unordered_set set; + }; + + /// Structure holding semantic information about a function. + /// Used to build the semantic::Function nodes at the end of resolving. + struct FunctionInfo { + explicit FunctionInfo(ast::Function* decl); + ~FunctionInfo(); + + ast::Function* const declaration; + UniqueVector referenced_module_vars; + UniqueVector local_referenced_module_vars; + UniqueVector ancestor_entry_points; + }; + void set_error(const Source& src, const std::string& msg); void set_referenced_from_function_if_needed(ast::Variable* var, bool local); void set_entry_points(const Symbol& fn_sym, Symbol ep_sym); @@ -153,8 +189,10 @@ class TypeDeterminer { ProgramBuilder* builder_; std::string error_; ScopeStack variable_stack_; - std::unordered_map symbol_to_function_; - ast::Function* current_function_ = nullptr; + std::unordered_map symbol_to_function_; + std::unordered_map function_to_info_; + FunctionInfo* current_function_ = nullptr; + BlockAllocator function_infos_; // Map from caller functions to callee functions. std::unordered_map> caller_to_callee_; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 8559765cbe..2c749d9cf1 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -52,6 +52,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/program_builder.h" #include "src/semantic/expression.h" +#include "src/semantic/function.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" #include "src/type/bool_type.h" @@ -659,7 +660,10 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) { // Register the function EXPECT_TRUE(td()->Determine()); - const auto& vars = func->referenced_module_variables(); + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + const auto& vars = func_sem->ReferencedModuleVariables(); ASSERT_EQ(vars.size(), 5u); EXPECT_EQ(vars[0], out_var); EXPECT_EQ(vars[1], in_var); @@ -700,7 +704,10 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { // Register the function EXPECT_TRUE(td()->Determine()); - const auto& vars = func2->referenced_module_variables(); + auto* func2_sem = Sem().Get(func2); + ASSERT_NE(func2_sem, nullptr); + + const auto& vars = func2_sem->ReferencedModuleVariables(); ASSERT_EQ(vars.size(), 5u); EXPECT_EQ(vars[0], out_var); EXPECT_EQ(vars[1], in_var); @@ -726,7 +733,10 @@ TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) { // Register the function EXPECT_TRUE(td()->Determine()) << td()->error(); - EXPECT_EQ(func->referenced_module_variables().size(), 0u); + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u); } TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) { @@ -2284,22 +2294,33 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { // Register the functions and calculate the callers ASSERT_TRUE(td()->Determine()) << td()->error(); - const auto& b_eps = func_b->ancestor_entry_points(); + auto* func_b_sem = Sem().Get(func_b); + auto* func_a_sem = Sem().Get(func_a); + auto* func_c_sem = Sem().Get(func_c); + auto* ep_1_sem = Sem().Get(ep_1); + auto* ep_2_sem = Sem().Get(ep_2); + ASSERT_NE(func_b_sem, nullptr); + ASSERT_NE(func_a_sem, nullptr); + ASSERT_NE(func_c_sem, nullptr); + ASSERT_NE(ep_1_sem, nullptr); + ASSERT_NE(ep_2_sem, nullptr); + + const auto& b_eps = func_b_sem->AncestorEntryPoints(); ASSERT_EQ(2u, b_eps.size()); EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]); EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]); - const auto& a_eps = func_a->ancestor_entry_points(); + const auto& a_eps = func_a_sem->AncestorEntryPoints(); ASSERT_EQ(1u, a_eps.size()); EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]); - const auto& c_eps = func_c->ancestor_entry_points(); + const auto& c_eps = func_c_sem->AncestorEntryPoints(); ASSERT_EQ(2u, c_eps.size()); EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]); EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]); - EXPECT_TRUE(ep_1->ancestor_entry_points().empty()); - EXPECT_TRUE(ep_2->ancestor_entry_points().empty()); + EXPECT_TRUE(ep_1_sem->AncestorEntryPoints().empty()); + EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty()); } using TypeDeterminerTextureIntrinsicTest = diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 09040829fc..62e52e5ce6 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -46,6 +46,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/program_builder.h" #include "src/semantic/expression.h" +#include "src/semantic/function.h" #include "src/type/access_control_type.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" @@ -591,15 +592,17 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, out << name << "("; + auto* func_sem = builder_.Sem().Get(func); + bool first = true; - if (has_referenced_in_var_needing_struct(func)) { + if (has_referenced_in_var_needing_struct(func_sem)) { auto var_name = current_ep_var_name(VarType::kIn); if (!var_name.empty()) { out << var_name; first = false; } } - if (has_referenced_out_var_needing_struct(func)) { + if (has_referenced_out_var_needing_struct(func_sem)) { auto var_name = current_ep_var_name(VarType::kOut); if (!var_name.empty()) { if (!first) { @@ -1223,15 +1226,16 @@ bool GeneratorImpl::EmitIf(std::ostream& out, ast::IfStatement* stmt) { return true; } -bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) { - for (auto data : func->referenced_location_variables()) { +bool GeneratorImpl::has_referenced_in_var_needing_struct( + const semantic::Function* func) { + for (auto data : func->ReferencedLocationVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kInput) { return true; } } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kInput) { return true; @@ -1240,15 +1244,16 @@ bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) { return false; } -bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) { - for (auto data : func->referenced_location_variables()) { +bool GeneratorImpl::has_referenced_out_var_needing_struct( + const semantic::Function* func) { + for (auto data : func->ReferencedLocationVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kOutput) { return true; } } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kOutput) { return true; @@ -1257,8 +1262,9 @@ bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) { return false; } -bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) { - for (auto data : func->referenced_location_variables()) { +bool GeneratorImpl::has_referenced_var_needing_struct( + const semantic::Function* func) { + for (auto data : func->ReferencedLocationVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kOutput || var->storage_class() == ast::StorageClass::kInput) { @@ -1266,7 +1272,7 @@ bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) { } } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kOutput || var->storage_class() == ast::StorageClass::kInput) { @@ -1284,14 +1290,16 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) { return true; } + auto* func_sem = builder_.Sem().Get(func); + // TODO(dsinclair): This could be smarter. If the input/outputs for multiple // entry points are the same we could generate a single struct and then have // this determine it's the same struct and just emit once. - bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 && - has_referenced_var_needing_struct(func); + bool emit_duplicate_functions = func_sem->AncestorEntryPoints().size() > 0 && + has_referenced_var_needing_struct(func_sem); if (emit_duplicate_functions) { - for (const auto& ep_sym : func->ancestor_entry_points()) { + for (const auto& ep_sym : func_sem->AncestorEntryPoints()) { if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_sym)) { return false; } @@ -1394,7 +1402,10 @@ bool GeneratorImpl::EmitEntryPointData( std::unordered_set& emitted_globals) { std::vector> in_variables; std::vector> outvariables; - for (auto data : func->referenced_location_variables()) { + auto* func_sem = builder_.Sem().Get(func); + auto func_sym = func->symbol(); + + for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; auto* deco = data.second; @@ -1405,7 +1416,7 @@ bool GeneratorImpl::EmitEntryPointData( } } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; auto* deco = data.second; @@ -1417,7 +1428,7 @@ bool GeneratorImpl::EmitEntryPointData( } bool emitted_uniform = false; - for (auto data : func->referenced_uniform_variables()) { + for (auto data : func_sem->ReferencedUniformVariables()) { auto* var = data.first; // TODO(dsinclair): We're using the binding to make up the buffer number but // we should instead be using a provided mapping that uses both buffer and @@ -1471,7 +1482,7 @@ bool GeneratorImpl::EmitEntryPointData( } bool emitted_storagebuffer = false; - for (auto data : func->referenced_storagebuffer_variables()) { + for (auto data : func_sem->ReferencedStoragebufferVariables()) { auto* var = data.first; auto* binding = data.second.binding; @@ -1500,10 +1511,10 @@ bool GeneratorImpl::EmitEntryPointData( } if (!in_variables.empty()) { - auto in_struct_name = generate_name( - builder_.Symbols().NameFor(func->symbol()) + "_" + kInStructNameSuffix); + auto in_struct_name = generate_name(builder_.Symbols().NameFor(func_sym) + + "_" + kInStructNameSuffix); auto in_var_name = generate_name(kTintStructInVarPrefix); - ep_sym_to_in_data_[func->symbol()] = {in_struct_name, in_var_name}; + ep_sym_to_in_data_[func_sym] = {in_struct_name, in_var_name}; make_indent(out); out << "struct " << in_struct_name << " {" << std::endl; @@ -1547,11 +1558,10 @@ bool GeneratorImpl::EmitEntryPointData( } if (!outvariables.empty()) { - auto outstruct_name = - generate_name(builder_.Symbols().NameFor(func->symbol()) + "_" + - kOutStructNameSuffix); + auto outstruct_name = generate_name(builder_.Symbols().NameFor(func_sym) + + "_" + kOutStructNameSuffix); auto outvar_name = generate_name(kTintStructOutVarPrefix); - ep_sym_to_out_data_[func->symbol()] = {outstruct_name, outvar_name}; + ep_sym_to_out_data_[func_sym] = {outstruct_name, outvar_name}; make_indent(out); out << "struct " << outstruct_name << " {" << std::endl; diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 07f65c6798..24089ae64f 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -356,16 +356,16 @@ class GeneratorImpl { /// Determines if the function needs the input struct passed to it. /// @param func the function to check /// @returns true if there are input struct variables used in the function - bool has_referenced_in_var_needing_struct(ast::Function* func); + bool has_referenced_in_var_needing_struct(const semantic::Function* func); /// Determines if the function needs the output struct passed to it. /// @param func the function to check /// @returns true if there are output struct variables used in the function - bool has_referenced_out_var_needing_struct(ast::Function* func); + bool has_referenced_out_var_needing_struct(const semantic::Function* func); /// Determines if any used program variable requires an input or output /// struct. /// @param func the function to check /// @returns true if an input or output struct is required. - bool has_referenced_var_needing_struct(ast::Function* func); + bool has_referenced_var_needing_struct(const semantic::Function* func); /// @returns the namer for testing Namer* namer_for_testing() { return &namer_; } diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index beb90f1582..0f9d249679 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -51,6 +51,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/program.h" #include "src/semantic/expression.h" +#include "src/semantic/function.h" #include "src/type/access_control_type.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" @@ -504,7 +505,8 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { } } - for (const auto& data : func->referenced_builtin_variables()) { + auto* func_sem = program_->Sem().Get(func); + for (const auto& data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() != ast::StorageClass::kInput) { continue; @@ -516,7 +518,7 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { out_ << program_->Symbols().NameFor(var->symbol()); } - for (const auto& data : func->referenced_uniform_variables()) { + for (const auto& data : func_sem->ReferencedUniformVariables()) { auto* var = data.first; if (!first) { out_ << ", "; @@ -525,7 +527,7 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { out_ << program_->Symbols().NameFor(var->symbol()); } - for (const auto& data : func->referenced_storagebuffer_variables()) { + for (const auto& data : func_sem->ReferencedStoragebufferVariables()) { auto* var = data.first; if (!first) { out_ << ", "; @@ -1021,10 +1023,13 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { } bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { + auto* func_sem = program_->Sem().Get(func); + std::vector> in_locations; std::vector> out_variables; - for (auto data : func->referenced_location_variables()) { + + for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; auto* deco = data.second; @@ -1035,7 +1040,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { } } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; auto* deco = data.second; @@ -1183,7 +1188,8 @@ void GeneratorImpl::EmitStage(ast::PipelineStage stage) { } bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) { - for (auto data : func->referenced_location_variables()) { + auto* func_sem = program_->Sem().Get(func); + for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kInput) { return true; @@ -1193,14 +1199,16 @@ bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) { } bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) { - for (auto data : func->referenced_location_variables()) { + auto* func_sem = program_->Sem().Get(func); + + for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kOutput) { return true; } } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() == ast::StorageClass::kOutput) { return true; @@ -1215,6 +1223,8 @@ bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) { } bool GeneratorImpl::EmitFunction(ast::Function* func) { + auto* func_sem = program_->Sem().Get(func); + make_indent(); // Entry points will be emitted later, skip for now. @@ -1225,11 +1235,11 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { // TODO(dsinclair): This could be smarter. If the input/outputs for multiple // entry points are the same we could generate a single struct and then have // this determine it's the same struct and just emit once. - bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 && + bool emit_duplicate_functions = func_sem->AncestorEntryPoints().size() > 0 && has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { - for (const auto& ep_sym : func->ancestor_entry_points()) { + for (const auto& ep_sym : func_sem->AncestorEntryPoints()) { if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) { return false; } @@ -1249,6 +1259,8 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, bool emit_duplicate_functions, Symbol ep_sym) { + auto* func_sem = program_->Sem().Get(func); + auto name = func->symbol().to_str(); if (!EmitType(func->return_type(), "")) { return false; @@ -1294,7 +1306,7 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, } } - for (const auto& data : func->referenced_builtin_variables()) { + for (const auto& data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() != ast::StorageClass::kInput) { continue; @@ -1311,7 +1323,7 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, out_ << "& " << program_->Symbols().NameFor(var->symbol()); } - for (const auto& data : func->referenced_uniform_variables()) { + for (const auto& data : func_sem->ReferencedUniformVariables()) { auto* var = data.first; if (!first) { out_ << ", "; @@ -1326,7 +1338,7 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, out_ << "& " << program_->Symbols().NameFor(var->symbol()); } - for (const auto& data : func->referenced_storagebuffer_variables()) { + for (const auto& data : func_sem->ReferencedStoragebufferVariables()) { auto* var = data.first; if (!first) { out_ << ", "; @@ -1404,6 +1416,8 @@ std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const { } bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { + auto* func_sem = program_->Sem().Get(func); + make_indent(); current_ep_sym_ = func->symbol(); @@ -1431,7 +1445,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { first = false; } - for (auto data : func->referenced_builtin_variables()) { + for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->storage_class() != ast::StorageClass::kInput) { continue; @@ -1457,7 +1471,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { << "]]"; } - for (auto data : func->referenced_uniform_variables()) { + for (auto data : func_sem->ReferencedUniformVariables()) { if (!first) { out_ << ", "; } @@ -1485,7 +1499,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { << binding->value() << ")]]"; } - for (auto data : func->referenced_storagebuffer_variables()) { + for (auto data : func_sem->ReferencedStoragebufferVariables()) { if (!first) { out_ << ", "; } diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index d9b936fafa..645c6b4e1b 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -60,6 +60,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/program.h" #include "src/semantic/expression.h" +#include "src/semantic/function.h" #include "src/type/access_control_type.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" @@ -457,7 +458,8 @@ bool Builder::GenerateEntryPoint(ast::Function* func, uint32_t id) { Operand::Int(stage), Operand::Int(id), Operand::String(builder_.Symbols().NameFor(func->symbol()))}; - for (const auto* var : func->referenced_module_variables()) { + auto* func_sem = builder_.Sem().Get(func); + for (const auto* var : func_sem->ReferencedModuleVariables()) { // For SPIR-V 1.3 we only output Input/output variables. If we update to // SPIR-V 1.4 or later this should be all variables. if (var->storage_class() != ast::StorageClass::kInput && @@ -496,7 +498,8 @@ bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) { Operand::Int(x), Operand::Int(y), Operand::Int(z)}); } - for (auto builtin : func->referenced_builtin_variables()) { + auto* func_sem = builder_.Sem().Get(func); + for (auto builtin : func_sem->ReferencedBuiltinVariables()) { if (builtin.second->value() == ast::Builtin::kFragDepth) { push_execution_mode( spv::Op::OpExecutionMode, diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc index d904ca718b..71e8291c15 100644 --- a/src/writer/spirv/builder_function_decoration_test.cc +++ b/src/writer/spirv/builder_function_decoration_test.cc @@ -24,6 +24,7 @@ #include "src/ast/stage_decoration.h" #include "src/ast/variable.h" #include "src/ast/workgroup_decoration.h" +#include "src/semantic/function.h" #include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -153,12 +154,6 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUsedInterfaceIds) { AST().AddGlobalVariable(v_out); AST().AddGlobalVariable(v_wg); - td.RegisterVariableForTesting(v_in); - td.RegisterVariableForTesting(v_out); - td.RegisterVariableForTesting(v_wg); - - ASSERT_TRUE(td.DetermineFunction(func)) << td.error(); - spirv::Builder& b = Build(); EXPECT_TRUE(b.GenerateGlobalVariable(v_in)) << b.error(); @@ -285,8 +280,6 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_FragDepth) { }, ast::FunctionDecorationList{}); - func->add_referenced_module_variable(fragdepth); - spirv::Builder& b = Build(); ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error(); diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index eeb4dcde62..7410675778 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -58,6 +58,7 @@ #include "src/ast/variable_decl_statement.h" #include "src/ast/workgroup_decoration.h" #include "src/program.h" +#include "src/semantic/function.h" #include "src/type/access_control_type.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" @@ -146,7 +147,7 @@ bool GeneratorImpl::GenerateEntryPoint(ast::PipelineStage stage, } bool found_func_variable = false; - for (auto* var : func->referenced_module_variables()) { + for (auto* var : program_->Sem().Get(func)->ReferencedModuleVariables()) { if (!EmitVariable(var)) { return false; } @@ -157,7 +158,8 @@ bool GeneratorImpl::GenerateEntryPoint(ast::PipelineStage stage, } for (auto* f : program_->AST().Functions()) { - if (!f->HasAncestorEntryPoint(program_->Symbols().Get(name))) { + auto* f_sem = program_->Sem().Get(f); + if (!f_sem->HasAncestorEntryPoint(program_->Symbols().Get(name))) { continue; }