Report referenced pipeline overridable constants

Adding this information to each entry point reported by the inspector.

BUG=tint:855

Change-Id: I043e48afed1503a4267dc4cb198fb86245984551
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53820
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ryan Harrison 2021-06-09 20:45:09 +00:00 committed by Tint LUCI CQ
parent 14b3403148
commit 3832b8e05d
6 changed files with 203 additions and 58 deletions

View File

@ -45,6 +45,13 @@ struct StageVariable {
ComponentType component_type; ComponentType component_type;
}; };
/// Reflection data about a pipeline overridable constant referenced by an entry
/// point
struct OverridableConstant {
/// Name of the constant
std::string name;
};
/// Reflection data for an entry point in the shader. /// Reflection data for an entry point in the shader.
struct EntryPoint { struct EntryPoint {
/// Constructors /// Constructors
@ -71,6 +78,8 @@ struct EntryPoint {
std::vector<StageVariable> input_variables; std::vector<StageVariable> input_variables;
/// List of the output variable accessed via this entry point. /// List of the output variable accessed via this entry point.
std::vector<StageVariable> output_variables; std::vector<StageVariable> output_variables;
/// List of the pipeline overridable constants accessed via this entry point.
std::vector<OverridableConstant> overridable_constants;
/// @returns the size of the workgroup in {x,y,z} format /// @returns the size of the workgroup in {x,y,z} format
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() { std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() {

View File

@ -222,7 +222,6 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
entry_point.output_variables); entry_point.output_variables);
} }
// TODO(crbug.com/tint/697): Remove this.
for (auto* var : sem->ReferencedModuleVariables()) { for (auto* var : sem->ReferencedModuleVariables()) {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
@ -231,32 +230,43 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
continue; continue;
} }
StageVariable stage_variable; // TODO(crbug.com/tint/697): Remove this.
stage_variable.name = name; {
StageVariable stage_variable;
stage_variable.name = name;
stage_variable.component_type = ComponentType::kUnknown; stage_variable.component_type = ComponentType::kUnknown;
auto* type = var->Type()->UnwrapRef(); auto* type = var->Type()->UnwrapRef();
if (type->is_float_scalar_or_vector() || type->is_float_matrix()) { if (type->is_float_scalar_or_vector() || type->is_float_matrix()) {
stage_variable.component_type = ComponentType::kFloat; stage_variable.component_type = ComponentType::kFloat;
} else if (type->is_unsigned_scalar_or_vector()) { } else if (type->is_unsigned_scalar_or_vector()) {
stage_variable.component_type = ComponentType::kUInt; stage_variable.component_type = ComponentType::kUInt;
} else if (type->is_signed_scalar_or_vector()) { } else if (type->is_signed_scalar_or_vector()) {
stage_variable.component_type = ComponentType::kSInt; stage_variable.component_type = ComponentType::kSInt;
}
auto* location_decoration =
ast::GetDecoration<ast::LocationDecoration>(decl->decorations());
if (location_decoration) {
stage_variable.has_location_decoration = true;
stage_variable.location_decoration = location_decoration->value();
} else {
stage_variable.has_location_decoration = false;
}
if (var->StorageClass() == ast::StorageClass::kInput) {
entry_point.input_variables.push_back(stage_variable);
} else if (var->StorageClass() == ast::StorageClass::kOutput) {
entry_point.output_variables.push_back(stage_variable);
}
} }
auto* location_decoration = {
ast::GetDecoration<ast::LocationDecoration>(decl->decorations()); if (var->IsPipelineConstant()) {
if (location_decoration) { OverridableConstant overridable_constant;
stage_variable.has_location_decoration = true; overridable_constant.name = name;
stage_variable.location_decoration = location_decoration->value(); entry_point.overridable_constants.push_back(overridable_constant);
} else { }
stage_variable.has_location_decoration = false;
}
if (var->StorageClass() == ast::StorageClass::kInput) {
entry_point.input_variables.push_back(stage_variable);
} else if (var->StorageClass() == ast::StorageClass::kOutput) {
entry_point.output_variables.push_back(stage_variable);
} }
} }

View File

@ -148,10 +148,10 @@ class InspectorHelper : public ProgramBuilder {
/// will be added. /// will be added.
/// @returns the constant that was created /// @returns the constant that was created
template <class T> template <class T>
ast::Variable* AddConstantWithID(std::string name, ast::Variable* AddOverridableConstantWithID(std::string name,
uint32_t id, uint32_t id,
ast::Type* type, ast::Type* type,
T* val) { T* val) {
ast::Expression* constructor = nullptr; ast::Expression* constructor = nullptr;
if (val) { if (val) {
constructor = Expr(*val); constructor = Expr(*val);
@ -169,9 +169,9 @@ class InspectorHelper : public ProgramBuilder {
/// will be added. /// will be added.
/// @returns the constant that was created /// @returns the constant that was created
template <class T> template <class T>
ast::Variable* AddConstantWithoutID(std::string name, ast::Variable* AddOverridableConstantWithoutID(std::string name,
ast::Type* type, ast::Type* type,
T* val) { T* val) {
ast::Expression* constructor = nullptr; ast::Expression* constructor = nullptr;
if (val) { if (val) {
constructor = Expr(*val); constructor = Expr(*val);
@ -182,6 +182,25 @@ class InspectorHelper : public ProgramBuilder {
}); });
} }
/// Generates a function that references module constant
/// @param func name of the function created
/// @param var name of the constant to be reference
/// @param type type of the const being referenced
/// @param decorations the function decorations
/// @returns a function object
ast::Function* MakeConstReferenceBodyFunction(
std::string func,
std::string var,
ast::Type* type,
ast::DecorationList decorations) {
ast::StatementList stmts;
stmts.emplace_back(Decl(Var("local_" + var, type)));
stmts.emplace_back(Assign("local_" + var, var));
stmts.emplace_back(Return());
return Func(func, ast::VariableList(), ty.void_(), stmts, decorations);
}
/// @param vec Vector of StageVariable to be searched /// @param vec Vector of StageVariable to be searched
/// @param name Name to be searching for /// @param name Name to be searching for
/// @returns true if name is in vec, otherwise false /// @returns true if name is in vec, otherwise false
@ -1446,6 +1465,81 @@ TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables_Legacy) {
EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type); EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
MakeEmptyBodyFunction("ep_func", {Stage(ast::PipelineStage::kCompute)});
Inspector& inspector = Build();
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
EXPECT_EQ(0u, result[0].overridable_constants.size());
}
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
MakeConstReferenceBodyFunction("ep_func", "foo", ty.f32(),
{Stage(ast::PipelineStage::kCompute)});
Inspector& inspector = Build();
tint::writer::wgsl::Generator writer(program_.get());
writer.Generate();
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
}
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction("ep_func", {"callee_func"},
{Stage(ast::PipelineStage::kCompute)});
Inspector& inspector = Build();
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
}
TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr);
AddOverridableConstantWithID<float>("bar", 2, ty.f32(), nullptr);
MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction("ep_func", {"callee_func"},
{Stage(ast::PipelineStage::kCompute)});
Inspector& inspector = Build();
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
}
TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) {
ast::Struct* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()});
AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0);
MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}});
MakeCallerBodyFunction("ep_func", {"ub_func"},
{Stage(ast::PipelineStage::kFragment)});
Inspector& inspector = Build();
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
EXPECT_EQ(0u, result[0].overridable_constants.size());
}
// TODO(rharrison): Reenable once GetRemappedNameForEntryPoint isn't a pass // TODO(rharrison): Reenable once GetRemappedNameForEntryPoint isn't a pass
// through // through
TEST_F(InspectorGetRemappedNameForEntryPointTest, DISABLED_NoFunctions) { TEST_F(InspectorGetRemappedNameForEntryPointTest, DISABLED_NoFunctions) {
@ -1518,9 +1612,9 @@ TEST_F(InspectorGetRemappedNameForEntryPointTest,
TEST_F(InspectorGetConstantIDsTest, Bool) { TEST_F(InspectorGetConstantIDsTest, Bool) {
bool val_true = true; bool val_true = true;
bool val_false = false; bool val_false = false;
AddConstantWithID<bool>("foo", 1, ty.bool_(), nullptr); AddOverridableConstantWithID<bool>("foo", 1, ty.bool_(), nullptr);
AddConstantWithID<bool>("bar", 20, ty.bool_(), &val_true); AddOverridableConstantWithID<bool>("bar", 20, ty.bool_(), &val_true);
AddConstantWithID<bool>("baz", 300, ty.bool_(), &val_false); AddOverridableConstantWithID<bool>("baz", 300, ty.bool_(), &val_false);
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -1541,8 +1635,8 @@ TEST_F(InspectorGetConstantIDsTest, Bool) {
TEST_F(InspectorGetConstantIDsTest, U32) { TEST_F(InspectorGetConstantIDsTest, U32) {
uint32_t val = 42; uint32_t val = 42;
AddConstantWithID<uint32_t>("foo", 1, ty.u32(), nullptr); AddOverridableConstantWithID<uint32_t>("foo", 1, ty.u32(), nullptr);
AddConstantWithID<uint32_t>("bar", 20, ty.u32(), &val); AddOverridableConstantWithID<uint32_t>("bar", 20, ty.u32(), &val);
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -1560,9 +1654,9 @@ TEST_F(InspectorGetConstantIDsTest, U32) {
TEST_F(InspectorGetConstantIDsTest, I32) { TEST_F(InspectorGetConstantIDsTest, I32) {
int32_t val_neg = -42; int32_t val_neg = -42;
int32_t val_pos = 42; int32_t val_pos = 42;
AddConstantWithID<int32_t>("foo", 1, ty.i32(), nullptr); AddOverridableConstantWithID<int32_t>("foo", 1, ty.i32(), nullptr);
AddConstantWithID<int32_t>("bar", 20, ty.i32(), &val_neg); AddOverridableConstantWithID<int32_t>("bar", 20, ty.i32(), &val_neg);
AddConstantWithID<int32_t>("baz", 300, ty.i32(), &val_pos); AddOverridableConstantWithID<int32_t>("baz", 300, ty.i32(), &val_pos);
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -1585,10 +1679,10 @@ TEST_F(InspectorGetConstantIDsTest, Float) {
float val_zero = 0.0f; float val_zero = 0.0f;
float val_neg = -10.0f; float val_neg = -10.0f;
float val_pos = 15.0f; float val_pos = 15.0f;
AddConstantWithID<float>("foo", 1, ty.f32(), nullptr); AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr);
AddConstantWithID<float>("bar", 20, ty.f32(), &val_zero); AddOverridableConstantWithID<float>("bar", 20, ty.f32(), &val_zero);
AddConstantWithID<float>("baz", 300, ty.f32(), &val_neg); AddOverridableConstantWithID<float>("baz", 300, ty.f32(), &val_neg);
AddConstantWithID<float>("x", 4000, ty.f32(), &val_pos); AddOverridableConstantWithID<float>("x", 4000, ty.f32(), &val_pos);
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -1612,12 +1706,12 @@ TEST_F(InspectorGetConstantIDsTest, Float) {
} }
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) { TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
AddConstantWithID<float>("v1", 1, ty.f32(), nullptr); AddOverridableConstantWithID<float>("v1", 1, ty.f32(), nullptr);
AddConstantWithID<float>("v20", 20, ty.f32(), nullptr); AddOverridableConstantWithID<float>("v20", 20, ty.f32(), nullptr);
AddConstantWithID<float>("v300", 300, ty.f32(), nullptr); AddOverridableConstantWithID<float>("v300", 300, ty.f32(), nullptr);
auto* a = AddConstantWithoutID<float>("a", ty.f32(), nullptr); auto* a = AddOverridableConstantWithoutID<float>("a", ty.f32(), nullptr);
auto* b = AddConstantWithoutID<float>("b", ty.f32(), nullptr); auto* b = AddOverridableConstantWithoutID<float>("b", ty.f32(), nullptr);
auto* c = AddConstantWithoutID<float>("c", ty.f32(), nullptr); auto* c = AddOverridableConstantWithoutID<float>("c", ty.f32(), nullptr);
Inspector& inspector = Build(); Inspector& inspector = Build();

View File

@ -135,8 +135,8 @@ void Resolver::set_referenced_from_function_if_needed(VariableInfo* var,
if (current_function_ == nullptr) { if (current_function_ == nullptr) {
return; return;
} }
if (var->storage_class == ast::StorageClass::kNone ||
var->storage_class == ast::StorageClass::kFunction) { if (var->kind != VariableKind::kGlobal) {
return; return;
} }
@ -496,7 +496,7 @@ Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
} }
auto* info = variable_infos_.Create(var, const_cast<sem::Type*>(type), auto* info = variable_infos_.Create(var, const_cast<sem::Type*>(type),
type_name, storage_class, access); type_name, storage_class, access, kind);
variable_to_info_.emplace(var, info); variable_to_info_.emplace(var, info);
return info; return info;
@ -3377,12 +3377,14 @@ Resolver::VariableInfo::VariableInfo(const ast::Variable* decl,
sem::Type* ty, sem::Type* ty,
const std::string& tn, const std::string& tn,
ast::StorageClass sc, ast::StorageClass sc,
ast::Access ac) ast::Access ac,
VariableKind k)
: declaration(decl), : declaration(decl),
type(ty), type(ty),
type_name(tn), type_name(tn),
storage_class(sc), storage_class(sc),
access(ac) {} access(ac),
kind(k) {}
Resolver::VariableInfo::~VariableInfo() = default; Resolver::VariableInfo::~VariableInfo() = default;

View File

@ -86,6 +86,9 @@ class Resolver {
bool IsHostShareable(const sem::Type* type); bool IsHostShareable(const sem::Type* type);
private: private:
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
/// Structure holding semantic information about a variable. /// Structure holding semantic information about a variable.
/// Used to build the sem::Variable nodes at the end of resolving. /// Used to build the sem::Variable nodes at the end of resolving.
struct VariableInfo { struct VariableInfo {
@ -93,7 +96,8 @@ class Resolver {
sem::Type* type, sem::Type* type,
const std::string& type_name, const std::string& type_name,
ast::StorageClass storage_class, ast::StorageClass storage_class,
ast::Access ac); ast::Access ac,
VariableKind k);
~VariableInfo(); ~VariableInfo();
ast::Variable const* const declaration; ast::Variable const* const declaration;
@ -103,6 +107,7 @@ class Resolver {
ast::Access const access; ast::Access const access;
std::vector<ast::IdentifierExpression*> users; std::vector<ast::IdentifierExpression*> users;
sem::BindingPoint binding_point; sem::BindingPoint binding_point;
VariableKind kind;
}; };
struct IntrinsicCallInfo { struct IntrinsicCallInfo {
@ -190,9 +195,6 @@ class Resolver {
sem::Type* const sem; sem::Type* const sem;
}; };
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
/// Resolves the program, without creating final the semantic nodes. /// Resolves the program, without creating final the semantic nodes.
/// @returns true on success, false on error /// @returns true on success, false on error
bool ResolveInternal(); bool ResolveInternal();

View File

@ -23,6 +23,7 @@
#include "src/ast/break_statement.h" #include "src/ast/break_statement.h"
#include "src/ast/call_statement.h" #include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h" #include "src/ast/continue_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/if_statement.h" #include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h" #include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h" #include "src/ast/loop_statement.h"
@ -903,6 +904,33 @@ TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) {
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>()); EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
} }
TEST_F(ResolverTest, Function_NotRegisterFunctionConstant) {
auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
{
Decl(Const("var", ty.f32(), Construct(ty.f32()))),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
TEST_F(ResolverTest, Function_NotRegisterFunctionParams) {
auto* func = Func("my_func", {Const("var", ty.f32(), Construct(ty.f32()))},
ty.void_(), {});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
TEST_F(ResolverTest, Function_ReturnStatements) { TEST_F(ResolverTest, Function_ReturnStatements) {
auto* var = Var("foo", ty.f32()); auto* var = Var("foo", ty.f32());