CloneContext: Pass the vector to InsertBefore()
There's usually only ever one vector we want to insert into. Inserting into *all* vectors that happen to contain the reference object is likely unintended, and is a foot-gun waiting to go off. Change-Id: I533084ccad102fc998b851fd238fd6bea9299450 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46445 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
90f43cf87f
commit
b4275c870e
|
@ -20,6 +20,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::Cloneable);
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
|
||||||
|
CloneContext::ListTransforms::ListTransforms() = default;
|
||||||
|
CloneContext::ListTransforms::~ListTransforms() = default;
|
||||||
|
|
||||||
CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
|
CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
|
||||||
: dst(to), src(from) {}
|
: dst(to), src(from) {}
|
||||||
CloneContext::~CloneContext() = default;
|
CloneContext::~CloneContext() = default;
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#ifndef SRC_CLONE_CONTEXT_H_
|
#ifndef SRC_CLONE_CONTEXT_H_
|
||||||
#define SRC_CLONE_CONTEXT_H_
|
#define SRC_CLONE_CONTEXT_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -178,14 +179,29 @@ class CloneContext {
|
||||||
std::vector<T*> Clone(const std::vector<T*>& v) {
|
std::vector<T*> Clone(const std::vector<T*>& v) {
|
||||||
std::vector<T*> out;
|
std::vector<T*> out;
|
||||||
out.reserve(v.size());
|
out.reserve(v.size());
|
||||||
|
|
||||||
|
auto list_transform_it = list_transforms_.find(&v);
|
||||||
|
if (list_transform_it != list_transforms_.end()) {
|
||||||
|
const auto& transforms = list_transform_it->second;
|
||||||
for (auto& el : v) {
|
for (auto& el : v) {
|
||||||
auto it = insert_before_.find(el);
|
auto insert_before_it = transforms.insert_before_.find(el);
|
||||||
if (it != insert_before_.end()) {
|
if (insert_before_it != transforms.insert_before_.end()) {
|
||||||
for (auto insert : it->second) {
|
for (auto insert : insert_before_it->second) {
|
||||||
out.emplace_back(CheckedCast<T>(insert));
|
out.emplace_back(CheckedCast<T>(insert));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out.emplace_back(Clone(el));
|
out.emplace_back(Clone(el));
|
||||||
|
auto insert_after_it = transforms.insert_after_.find(el);
|
||||||
|
if (insert_after_it != transforms.insert_after_.end()) {
|
||||||
|
for (auto insert : insert_after_it->second) {
|
||||||
|
out.emplace_back(CheckedCast<T>(insert));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (auto& el : v) {
|
||||||
|
out.emplace_back(Clone(el));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
@ -293,15 +309,46 @@ class CloneContext {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inserts `object` before `before` whenever a vector containing `object` is
|
/// Inserts `object` before `before` whenever `vector` is cloned.
|
||||||
/// cloned.
|
/// @param vector the vector in #src
|
||||||
/// @param before a pointer to the object in #src
|
/// @param before a pointer to the object in #src
|
||||||
/// @param object a pointer to the object in #dst that will be inserted before
|
/// @param object a pointer to the object in #dst that will be inserted before
|
||||||
/// any occurrence of the clone of `before`
|
/// any occurrence of the clone of `before`
|
||||||
/// @returns this CloneContext so calls can be chained
|
/// @returns this CloneContext so calls can be chained
|
||||||
template <typename BEFORE, typename OBJECT>
|
template <typename T, typename BEFORE, typename OBJECT>
|
||||||
CloneContext& InsertBefore(BEFORE* before, OBJECT* object) {
|
CloneContext& InsertBefore(const std::vector<T>& vector,
|
||||||
auto& list = insert_before_[before];
|
BEFORE* before,
|
||||||
|
OBJECT* object) {
|
||||||
|
if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
|
||||||
|
TINT_ICE(Diagnostics())
|
||||||
|
<< "CloneContext::InsertBefore() vector does not contain before";
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& transforms = list_transforms_[&vector];
|
||||||
|
auto& list = transforms.insert_before_[before];
|
||||||
|
list.emplace_back(object);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inserts `object` after `after` whenever `vector` is cloned.
|
||||||
|
/// @param vector the vector in #src
|
||||||
|
/// @param after a pointer to the object in #src
|
||||||
|
/// @param object a pointer to the object in #dst that will be inserted after
|
||||||
|
/// any occurrence of the clone of `after`
|
||||||
|
/// @returns this CloneContext so calls can be chained
|
||||||
|
template <typename T, typename AFTER, typename OBJECT>
|
||||||
|
CloneContext& InsertAfter(const std::vector<T>& vector,
|
||||||
|
AFTER* after,
|
||||||
|
OBJECT* object) {
|
||||||
|
if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
|
||||||
|
TINT_ICE(Diagnostics())
|
||||||
|
<< "CloneContext::InsertAfter() vector does not contain after";
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& transforms = list_transforms_[&vector];
|
||||||
|
auto& list = transforms.insert_after_[after];
|
||||||
list.emplace_back(object);
|
list.emplace_back(object);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
@ -380,17 +427,33 @@ class CloneContext {
|
||||||
/// A vector of Cloneable*
|
/// A vector of Cloneable*
|
||||||
using CloneableList = std::vector<Cloneable*>;
|
using CloneableList = std::vector<Cloneable*>;
|
||||||
|
|
||||||
|
// Transformations to be applied to a list (vector)
|
||||||
|
struct ListTransforms {
|
||||||
|
/// Constructor
|
||||||
|
ListTransforms();
|
||||||
|
/// Destructor
|
||||||
|
~ListTransforms();
|
||||||
|
|
||||||
|
/// A map of object in #src to the list of cloned objects in #dst.
|
||||||
|
/// Clone(const std::vector<T*>& v) will use this to insert the map-value
|
||||||
|
/// list into the target vector before cloning and inserting the map-key.
|
||||||
|
std::unordered_map<const Cloneable*, CloneableList> insert_before_;
|
||||||
|
|
||||||
|
/// A map of object in #src to the list of cloned objects in #dst.
|
||||||
|
/// Clone(const std::vector<T*>& v) will use this to insert the map-value
|
||||||
|
/// list into the target vector after cloning and inserting the map-key.
|
||||||
|
std::unordered_map<const Cloneable*, CloneableList> insert_after_;
|
||||||
|
};
|
||||||
|
|
||||||
/// A map of object in #src to their cloned equivalent in #dst
|
/// A map of object in #src to their cloned equivalent in #dst
|
||||||
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
|
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
|
||||||
|
|
||||||
/// A map of object in #src to the list of cloned objects in #dst.
|
|
||||||
/// Clone(const std::vector<T*>& v) will use this to insert the map-value list
|
|
||||||
/// into the target vector/ before cloning and inserting the map-key.
|
|
||||||
std::unordered_map<const Cloneable*, CloneableList> insert_before_;
|
|
||||||
|
|
||||||
/// Cloneable transform functions registered with ReplaceAll()
|
/// Cloneable transform functions registered with ReplaceAll()
|
||||||
std::vector<CloneableTransform> transforms_;
|
std::vector<CloneableTransform> transforms_;
|
||||||
|
|
||||||
|
/// Map of std::vector pointer to transforms for that list
|
||||||
|
std::unordered_map<const void*, ListTransforms> list_transforms_;
|
||||||
|
|
||||||
/// Symbol transform registered with ReplaceAll()
|
/// Symbol transform registered with ReplaceAll()
|
||||||
SymbolTransform symbol_transform_;
|
SymbolTransform symbol_transform_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -287,8 +287,9 @@ TEST(CloneContext, CloneWithInsertBefore) {
|
||||||
ProgramBuilder cloned;
|
ProgramBuilder cloned;
|
||||||
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
|
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
|
||||||
|
|
||||||
auto* cloned_root = CloneContext(&cloned, &original)
|
auto* cloned_root =
|
||||||
.InsertBefore(original_root->b, insertion)
|
CloneContext(&cloned, &original)
|
||||||
|
.InsertBefore(original_root->vec, original_root->b, insertion)
|
||||||
.Clone(original_root);
|
.Clone(original_root);
|
||||||
|
|
||||||
EXPECT_EQ(cloned_root->vec.size(), 4u);
|
EXPECT_EQ(cloned_root->vec.size(), 4u);
|
||||||
|
@ -303,6 +304,36 @@ TEST(CloneContext, CloneWithInsertBefore) {
|
||||||
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
|
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CloneContext, CloneWithInsertAfter) {
|
||||||
|
ProgramBuilder builder;
|
||||||
|
auto* original_root =
|
||||||
|
builder.create<Node>(builder.Symbols().Register("root"));
|
||||||
|
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
|
||||||
|
original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
|
||||||
|
original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
|
||||||
|
original_root->vec = {original_root->a, original_root->b, original_root->c};
|
||||||
|
Program original(std::move(builder));
|
||||||
|
|
||||||
|
ProgramBuilder cloned;
|
||||||
|
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
|
||||||
|
|
||||||
|
auto* cloned_root =
|
||||||
|
CloneContext(&cloned, &original)
|
||||||
|
.InsertAfter(original_root->vec, original_root->b, insertion)
|
||||||
|
.Clone(original_root);
|
||||||
|
|
||||||
|
EXPECT_EQ(cloned_root->vec.size(), 4u);
|
||||||
|
EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
|
||||||
|
EXPECT_EQ(cloned_root->vec[1], cloned_root->b);
|
||||||
|
EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
|
||||||
|
|
||||||
|
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
|
||||||
|
EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
|
||||||
|
EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("b"));
|
||||||
|
EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("insertion"));
|
||||||
|
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
|
TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
|
||||||
EXPECT_FATAL_FAILURE(
|
EXPECT_FATAL_FAILURE(
|
||||||
{
|
{
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include "src/program_builder.h"
|
#include "src/program_builder.h"
|
||||||
#include "src/semantic/function.h"
|
#include "src/semantic/function.h"
|
||||||
|
#include "src/semantic/statement.h"
|
||||||
#include "src/semantic/variable.h"
|
#include "src/semantic/variable.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
@ -119,7 +120,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
||||||
// Initialize it with the value extracted from the new struct parameter.
|
// Initialize it with the value extracted from the new struct parameter.
|
||||||
auto* func_const = ctx.dst->Const(
|
auto* func_const = ctx.dst->Const(
|
||||||
func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
|
func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
|
||||||
ctx.InsertBefore(*func->body()->begin(),
|
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
|
||||||
ctx.dst->WrapInStatement(func_const));
|
ctx.dst->WrapInStatement(func_const));
|
||||||
|
|
||||||
// Replace all uses of the function parameter with the function const.
|
// Replace all uses of the function parameter with the function const.
|
||||||
|
@ -134,7 +135,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
||||||
ctx.dst->Symbols().New(),
|
ctx.dst->Symbols().New(),
|
||||||
ctx.dst->create<ast::Struct>(new_struct_members,
|
ctx.dst->create<ast::Struct>(new_struct_members,
|
||||||
ast::DecorationList{}));
|
ast::DecorationList{}));
|
||||||
ctx.InsertBefore(func, in_struct);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct);
|
||||||
|
|
||||||
// Create a new function parameter using this struct type.
|
// Create a new function parameter using this struct type.
|
||||||
auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct,
|
auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct,
|
||||||
|
@ -177,12 +178,13 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
||||||
ctx.dst->Symbols().New(),
|
ctx.dst->Symbols().New(),
|
||||||
ctx.dst->create<ast::Struct>(new_struct_members,
|
ctx.dst->create<ast::Struct>(new_struct_members,
|
||||||
ast::DecorationList{}));
|
ast::DecorationList{}));
|
||||||
ctx.InsertBefore(func, out_struct);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct);
|
||||||
new_ret_type = out_struct;
|
new_ret_type = out_struct;
|
||||||
|
|
||||||
// Replace all return statements.
|
// Replace all return statements.
|
||||||
auto* sem_func = ctx.src->Sem().Get(func);
|
auto* sem_func = ctx.src->Sem().Get(func);
|
||||||
for (auto* ret : sem_func->ReturnStatements()) {
|
for (auto* ret : sem_func->ReturnStatements()) {
|
||||||
|
auto* ret_sem = ctx.src->Sem().Get(ret);
|
||||||
// Reconstruct the return value using the newly created struct.
|
// Reconstruct the return value using the newly created struct.
|
||||||
auto* new_ret_value = ctx.Clone(ret->value());
|
auto* new_ret_value = ctx.Clone(ret->value());
|
||||||
ast::ExpressionList ret_values;
|
ast::ExpressionList ret_values;
|
||||||
|
@ -193,7 +195,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
||||||
auto temp = ctx.dst->Symbols().New();
|
auto temp = ctx.dst->Symbols().New();
|
||||||
auto* temp_var = ctx.dst->Decl(
|
auto* temp_var = ctx.dst->Decl(
|
||||||
ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value));
|
ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value));
|
||||||
ctx.InsertBefore(ret, temp_var);
|
ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var);
|
||||||
new_ret_value = ctx.dst->Expr(temp);
|
new_ret_value = ctx.dst->Expr(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,8 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const {
|
||||||
auto* dst_ident = ctx.dst->Expr(dst_symbol);
|
auto* dst_ident = ctx.dst->Expr(dst_symbol);
|
||||||
|
|
||||||
// Insert the constant before the usage
|
// Insert the constant before the usage
|
||||||
ctx.InsertBefore(src_stmt, dst_var_decl);
|
ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
|
||||||
|
dst_var_decl);
|
||||||
// Replace the inlined array with a reference to the constant
|
// Replace the inlined array with a reference to the constant
|
||||||
ctx.Replace(src_init, dst_ident);
|
ctx.Replace(src_init, dst_ident);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "src/ast/return_statement.h"
|
#include "src/ast/return_statement.h"
|
||||||
#include "src/program_builder.h"
|
#include "src/program_builder.h"
|
||||||
#include "src/semantic/function.h"
|
#include "src/semantic/function.h"
|
||||||
|
#include "src/semantic/statement.h"
|
||||||
#include "src/semantic/variable.h"
|
#include "src/semantic/variable.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
@ -162,13 +163,15 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
||||||
return_func_symbol, ast::VariableList{store_value},
|
return_func_symbol, ast::VariableList{store_value},
|
||||||
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
|
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
|
||||||
ast::DecorationList{}, ast::DecorationList{});
|
ast::DecorationList{}, ast::DecorationList{});
|
||||||
ctx.InsertBefore(func, return_func);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func);
|
||||||
|
|
||||||
// Replace all return statements with calls to the output function.
|
// Replace all return statements with calls to the output function.
|
||||||
auto* sem_func = ctx.src->Sem().Get(func);
|
auto* sem_func = ctx.src->Sem().Get(func);
|
||||||
for (auto* ret : sem_func->ReturnStatements()) {
|
for (auto* ret : sem_func->ReturnStatements()) {
|
||||||
|
auto* ret_sem = ctx.src->Sem().Get(ret);
|
||||||
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
|
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
|
||||||
ctx.InsertBefore(ret, ctx.dst->create<ast::CallStatement>(call));
|
ctx.InsertBefore(ret_sem->Block()->statements(), ret,
|
||||||
|
ctx.dst->create<ast::CallStatement>(call));
|
||||||
ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
|
ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -247,7 +250,7 @@ Symbol Spirv::HoistToInputVariables(
|
||||||
auto* global_var =
|
auto* global_var =
|
||||||
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
|
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
|
||||||
ast::StorageClass::kInput, nullptr, new_decorations);
|
ast::StorageClass::kInput, nullptr, new_decorations);
|
||||||
ctx.InsertBefore(func, global_var);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
|
||||||
return global_var_symbol;
|
return global_var_symbol;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,7 +272,8 @@ Symbol Spirv::HoistToInputVariables(
|
||||||
// Create a function-scope variable for the struct.
|
// Create a function-scope variable for the struct.
|
||||||
auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
|
auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
|
||||||
auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
|
auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
|
||||||
ctx.InsertBefore(*func->body()->begin(), ctx.dst->WrapInStatement(func_var));
|
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
|
||||||
|
ctx.dst->WrapInStatement(func_var));
|
||||||
return func_var_symbol;
|
return func_var_symbol;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,7 +296,7 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
|
||||||
auto* global_var =
|
auto* global_var =
|
||||||
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
|
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
|
||||||
ast::StorageClass::kOutput, nullptr, new_decorations);
|
ast::StorageClass::kOutput, nullptr, new_decorations);
|
||||||
ctx.InsertBefore(func, global_var);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
|
||||||
|
|
||||||
// Create the assignment instruction.
|
// Create the assignment instruction.
|
||||||
ast::Expression* rhs = ctx.dst->Expr(store_value);
|
ast::Expression* rhs = ctx.dst->Expr(store_value);
|
||||||
|
|
Loading…
Reference in New Issue