From 3832b8e05d563ae5405728bc47d8a42c8bb6d3d2 Mon Sep 17 00:00:00 2001 From: Ryan Harrison Date: Wed, 9 Jun 2021 20:45:09 +0000 Subject: [PATCH] Report referenced pipeline overridable constants Adding this information to each entry point reported by the inspector. BUG=tint:855 Change-Id: I043e48afed1503a4267dc4cb198fb86245984551 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53820 Auto-Submit: Ryan Harrison Reviewed-by: Ben Clayton Reviewed-by: James Price Commit-Queue: Ryan Harrison Kokoro: Kokoro --- src/inspector/entry_point.h | 9 ++ src/inspector/inspector.cc | 58 +++++++------ src/inspector/inspector_test.cc | 144 ++++++++++++++++++++++++++------ src/resolver/resolver.cc | 12 +-- src/resolver/resolver.h | 10 ++- src/resolver/resolver_test.cc | 28 +++++++ 6 files changed, 203 insertions(+), 58 deletions(-) diff --git a/src/inspector/entry_point.h b/src/inspector/entry_point.h index aac9a4c34e..b3570bbb53 100644 --- a/src/inspector/entry_point.h +++ b/src/inspector/entry_point.h @@ -45,6 +45,13 @@ struct StageVariable { ComponentType component_type; }; +/// Reflection data about a pipeline overridable constant referenced by an entry +/// point +struct OverridableConstant { + /// Name of the constant + std::string name; +}; + /// Reflection data for an entry point in the shader. struct EntryPoint { /// Constructors @@ -71,6 +78,8 @@ struct EntryPoint { std::vector input_variables; /// List of the output variable accessed via this entry point. std::vector output_variables; + /// List of the pipeline overridable constants accessed via this entry point. + std::vector overridable_constants; /// @returns the size of the workgroup in {x,y,z} format std::tuple workgroup_size() { diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index c0a100c888..8fe3b3246d 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -222,7 +222,6 @@ std::vector Inspector::GetEntryPoints() { entry_point.output_variables); } - // TODO(crbug.com/tint/697): Remove this. for (auto* var : sem->ReferencedModuleVariables()) { auto* decl = var->Declaration(); @@ -231,32 +230,43 @@ std::vector Inspector::GetEntryPoints() { continue; } - StageVariable stage_variable; - stage_variable.name = name; + // TODO(crbug.com/tint/697): Remove this. + { + StageVariable stage_variable; + stage_variable.name = name; - stage_variable.component_type = ComponentType::kUnknown; - auto* type = var->Type()->UnwrapRef(); - if (type->is_float_scalar_or_vector() || type->is_float_matrix()) { - stage_variable.component_type = ComponentType::kFloat; - } else if (type->is_unsigned_scalar_or_vector()) { - stage_variable.component_type = ComponentType::kUInt; - } else if (type->is_signed_scalar_or_vector()) { - stage_variable.component_type = ComponentType::kSInt; + stage_variable.component_type = ComponentType::kUnknown; + auto* type = var->Type()->UnwrapRef(); + if (type->is_float_scalar_or_vector() || type->is_float_matrix()) { + stage_variable.component_type = ComponentType::kFloat; + } else if (type->is_unsigned_scalar_or_vector()) { + stage_variable.component_type = ComponentType::kUInt; + } else if (type->is_signed_scalar_or_vector()) { + stage_variable.component_type = ComponentType::kSInt; + } + + auto* location_decoration = + ast::GetDecoration(decl->decorations()); + if (location_decoration) { + stage_variable.has_location_decoration = true; + stage_variable.location_decoration = location_decoration->value(); + } else { + stage_variable.has_location_decoration = false; + } + + if (var->StorageClass() == ast::StorageClass::kInput) { + entry_point.input_variables.push_back(stage_variable); + } else if (var->StorageClass() == ast::StorageClass::kOutput) { + entry_point.output_variables.push_back(stage_variable); + } } - auto* location_decoration = - ast::GetDecoration(decl->decorations()); - if (location_decoration) { - stage_variable.has_location_decoration = true; - stage_variable.location_decoration = location_decoration->value(); - } else { - stage_variable.has_location_decoration = false; - } - - if (var->StorageClass() == ast::StorageClass::kInput) { - entry_point.input_variables.push_back(stage_variable); - } else if (var->StorageClass() == ast::StorageClass::kOutput) { - entry_point.output_variables.push_back(stage_variable); + { + if (var->IsPipelineConstant()) { + OverridableConstant overridable_constant; + overridable_constant.name = name; + entry_point.overridable_constants.push_back(overridable_constant); + } } } diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 994945264f..0e050f6c9b 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc @@ -148,10 +148,10 @@ class InspectorHelper : public ProgramBuilder { /// will be added. /// @returns the constant that was created template - ast::Variable* AddConstantWithID(std::string name, - uint32_t id, - ast::Type* type, - T* val) { + ast::Variable* AddOverridableConstantWithID(std::string name, + uint32_t id, + ast::Type* type, + T* val) { ast::Expression* constructor = nullptr; if (val) { constructor = Expr(*val); @@ -169,9 +169,9 @@ class InspectorHelper : public ProgramBuilder { /// will be added. /// @returns the constant that was created template - ast::Variable* AddConstantWithoutID(std::string name, - ast::Type* type, - T* val) { + ast::Variable* AddOverridableConstantWithoutID(std::string name, + ast::Type* type, + T* val) { ast::Expression* constructor = nullptr; if (val) { constructor = Expr(*val); @@ -182,6 +182,25 @@ class InspectorHelper : public ProgramBuilder { }); } + /// Generates a function that references module constant + /// @param func name of the function created + /// @param var name of the constant to be reference + /// @param type type of the const being referenced + /// @param decorations the function decorations + /// @returns a function object + ast::Function* MakeConstReferenceBodyFunction( + std::string func, + std::string var, + ast::Type* type, + ast::DecorationList decorations) { + ast::StatementList stmts; + stmts.emplace_back(Decl(Var("local_" + var, type))); + stmts.emplace_back(Assign("local_" + var, var)); + stmts.emplace_back(Return()); + + return Func(func, ast::VariableList(), ty.void_(), stmts, decorations); + } + /// @param vec Vector of StageVariable to be searched /// @param name Name to be searching for /// @returns true if name is in vec, otherwise false @@ -1446,6 +1465,81 @@ TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables_Legacy) { EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type); } +TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) { + AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); + MakeEmptyBodyFunction("ep_func", {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_EQ(0u, result[0].overridable_constants.size()); +} + +TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) { + AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); + MakeConstReferenceBodyFunction("ep_func", "foo", ty.f32(), + {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + tint::writer::wgsl::Generator writer(program_.get()); + writer.Generate(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overridable_constants.size()); + EXPECT_EQ("foo", result[0].overridable_constants[0].name); +} + +TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) { + AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); + MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); + MakeCallerBodyFunction("ep_func", {"callee_func"}, + {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overridable_constants.size()); + EXPECT_EQ("foo", result[0].overridable_constants[0].name); +} + +TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) { + AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr); + AddOverridableConstantWithID("bar", 2, ty.f32(), nullptr); + MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); + MakeCallerBodyFunction("ep_func", {"callee_func"}, + {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overridable_constants.size()); + EXPECT_EQ("foo", result[0].overridable_constants[0].name); +} + +TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) { + ast::Struct* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()}); + AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0); + MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}}); + MakeCallerBodyFunction("ep_func", {"ub_func"}, + {Stage(ast::PipelineStage::kFragment)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_EQ(0u, result[0].overridable_constants.size()); +} + // TODO(rharrison): Reenable once GetRemappedNameForEntryPoint isn't a pass // through TEST_F(InspectorGetRemappedNameForEntryPointTest, DISABLED_NoFunctions) { @@ -1518,9 +1612,9 @@ TEST_F(InspectorGetRemappedNameForEntryPointTest, TEST_F(InspectorGetConstantIDsTest, Bool) { bool val_true = true; bool val_false = false; - AddConstantWithID("foo", 1, ty.bool_(), nullptr); - AddConstantWithID("bar", 20, ty.bool_(), &val_true); - AddConstantWithID("baz", 300, ty.bool_(), &val_false); + AddOverridableConstantWithID("foo", 1, ty.bool_(), nullptr); + AddOverridableConstantWithID("bar", 20, ty.bool_(), &val_true); + AddOverridableConstantWithID("baz", 300, ty.bool_(), &val_false); Inspector& inspector = Build(); @@ -1541,8 +1635,8 @@ TEST_F(InspectorGetConstantIDsTest, Bool) { TEST_F(InspectorGetConstantIDsTest, U32) { uint32_t val = 42; - AddConstantWithID("foo", 1, ty.u32(), nullptr); - AddConstantWithID("bar", 20, ty.u32(), &val); + AddOverridableConstantWithID("foo", 1, ty.u32(), nullptr); + AddOverridableConstantWithID("bar", 20, ty.u32(), &val); Inspector& inspector = Build(); @@ -1560,9 +1654,9 @@ TEST_F(InspectorGetConstantIDsTest, U32) { TEST_F(InspectorGetConstantIDsTest, I32) { int32_t val_neg = -42; int32_t val_pos = 42; - AddConstantWithID("foo", 1, ty.i32(), nullptr); - AddConstantWithID("bar", 20, ty.i32(), &val_neg); - AddConstantWithID("baz", 300, ty.i32(), &val_pos); + AddOverridableConstantWithID("foo", 1, ty.i32(), nullptr); + AddOverridableConstantWithID("bar", 20, ty.i32(), &val_neg); + AddOverridableConstantWithID("baz", 300, ty.i32(), &val_pos); Inspector& inspector = Build(); @@ -1585,10 +1679,10 @@ TEST_F(InspectorGetConstantIDsTest, Float) { float val_zero = 0.0f; float val_neg = -10.0f; float val_pos = 15.0f; - AddConstantWithID("foo", 1, ty.f32(), nullptr); - AddConstantWithID("bar", 20, ty.f32(), &val_zero); - AddConstantWithID("baz", 300, ty.f32(), &val_neg); - AddConstantWithID("x", 4000, ty.f32(), &val_pos); + AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr); + AddOverridableConstantWithID("bar", 20, ty.f32(), &val_zero); + AddOverridableConstantWithID("baz", 300, ty.f32(), &val_neg); + AddOverridableConstantWithID("x", 4000, ty.f32(), &val_pos); Inspector& inspector = Build(); @@ -1612,12 +1706,12 @@ TEST_F(InspectorGetConstantIDsTest, Float) { } TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) { - AddConstantWithID("v1", 1, ty.f32(), nullptr); - AddConstantWithID("v20", 20, ty.f32(), nullptr); - AddConstantWithID("v300", 300, ty.f32(), nullptr); - auto* a = AddConstantWithoutID("a", ty.f32(), nullptr); - auto* b = AddConstantWithoutID("b", ty.f32(), nullptr); - auto* c = AddConstantWithoutID("c", ty.f32(), nullptr); + AddOverridableConstantWithID("v1", 1, ty.f32(), nullptr); + AddOverridableConstantWithID("v20", 20, ty.f32(), nullptr); + AddOverridableConstantWithID("v300", 300, ty.f32(), nullptr); + auto* a = AddOverridableConstantWithoutID("a", ty.f32(), nullptr); + auto* b = AddOverridableConstantWithoutID("b", ty.f32(), nullptr); + auto* c = AddOverridableConstantWithoutID("c", ty.f32(), nullptr); Inspector& inspector = Build(); diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 41c411423e..e80d11d7c0 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -135,8 +135,8 @@ void Resolver::set_referenced_from_function_if_needed(VariableInfo* var, if (current_function_ == nullptr) { return; } - if (var->storage_class == ast::StorageClass::kNone || - var->storage_class == ast::StorageClass::kFunction) { + + if (var->kind != VariableKind::kGlobal) { return; } @@ -496,7 +496,7 @@ Resolver::VariableInfo* Resolver::Variable(ast::Variable* var, } auto* info = variable_infos_.Create(var, const_cast(type), - type_name, storage_class, access); + type_name, storage_class, access, kind); variable_to_info_.emplace(var, info); return info; @@ -3377,12 +3377,14 @@ Resolver::VariableInfo::VariableInfo(const ast::Variable* decl, sem::Type* ty, const std::string& tn, ast::StorageClass sc, - ast::Access ac) + ast::Access ac, + VariableKind k) : declaration(decl), type(ty), type_name(tn), storage_class(sc), - access(ac) {} + access(ac), + kind(k) {} Resolver::VariableInfo::~VariableInfo() = default; diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 9f9102729e..b58d2b504d 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -86,6 +86,9 @@ class Resolver { bool IsHostShareable(const sem::Type* type); private: + /// Describes the context in which a variable is declared + enum class VariableKind { kParameter, kLocal, kGlobal }; + /// Structure holding semantic information about a variable. /// Used to build the sem::Variable nodes at the end of resolving. struct VariableInfo { @@ -93,7 +96,8 @@ class Resolver { sem::Type* type, const std::string& type_name, ast::StorageClass storage_class, - ast::Access ac); + ast::Access ac, + VariableKind k); ~VariableInfo(); ast::Variable const* const declaration; @@ -103,6 +107,7 @@ class Resolver { ast::Access const access; std::vector users; sem::BindingPoint binding_point; + VariableKind kind; }; struct IntrinsicCallInfo { @@ -190,9 +195,6 @@ class Resolver { sem::Type* const sem; }; - /// Describes the context in which a variable is declared - enum class VariableKind { kParameter, kLocal, kGlobal }; - /// Resolves the program, without creating final the semantic nodes. /// @returns true on success, false on error bool ResolveInternal(); diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index e060682c01..1250deba81 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -23,6 +23,7 @@ #include "src/ast/break_statement.h" #include "src/ast/call_statement.h" #include "src/ast/continue_statement.h" +#include "src/ast/float_literal.h" #include "src/ast/if_statement.h" #include "src/ast/intrinsic_texture_helper_test.h" #include "src/ast/loop_statement.h" @@ -903,6 +904,33 @@ TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) { EXPECT_TRUE(func_sem->ReturnType()->Is()); } +TEST_F(ResolverTest, Function_NotRegisterFunctionConstant) { + auto* func = Func("my_func", ast::VariableList{}, ty.void_(), + { + Decl(Const("var", ty.f32(), Construct(ty.f32()))), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u); + EXPECT_TRUE(func_sem->ReturnType()->Is()); +} + +TEST_F(ResolverTest, Function_NotRegisterFunctionParams) { + auto* func = Func("my_func", {Const("var", ty.f32(), Construct(ty.f32()))}, + ty.void_(), {}); + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u); + EXPECT_TRUE(func_sem->ReturnType()->Is()); +} + TEST_F(ResolverTest, Function_ReturnStatements) { auto* var = Var("foo", ty.f32());