CloneContext: Don't create named symbols from unnamed

Registering a new Symbol with the NameFor() of the source symbol creates
a new *named* symbol. When mixing these with unnamed symbols we can have
collisions.

Update CloneContext::Clone(Symbol) to properly clone unnamed symbols.

Update (most) the transforms to ctx.Clone() the symbols instead of
registering the names directly.

Fix up the tests where the symbol IDs have changed.

Note: We can still have symbol collisions if a program is authored with
identifiers like 'tint_symbol_3'. This will be fixed up in a later
change.

Change-Id: I0ce559644da3d60e1060f2eef185fa55ae284521
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46866
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-04-07 11:16:01 +00:00 committed by Commit Bot service account
parent 3bfb6817df
commit 1b8d9f227b
16 changed files with 295 additions and 193 deletions

View File

@ -457,6 +457,7 @@ source_set("libtint_core_src") {
"type/vector_type.h", "type/vector_type.h",
"type/void_type.cc", "type/void_type.cc",
"type/void_type.h", "type/void_type.h",
"utils/get_or_create.h",
"utils/hash.h", "utils/hash.h",
"utils/math.h", "utils/math.h",
"utils/unique_vector.h", "utils/unique_vector.h",

View File

@ -272,6 +272,7 @@ set(TINT_LIB_SRCS
type/vector_type.h type/vector_type.h
type/void_type.cc type/void_type.cc
type/void_type.h type/void_type.h
utils/get_or_create.h
utils/hash.h utils/hash.h
utils/math.h utils/math.h
utils/unique_vector.h utils/unique_vector.h
@ -519,6 +520,7 @@ if(${TINT_BUILD_TESTS})
type/vector_type_test.cc type/vector_type_test.cc
utils/command_test.cc utils/command_test.cc
utils/command.h utils/command.h
utils/get_or_create_test.cc
utils/hash_test.cc utils/hash_test.cc
utils/math_test.cc utils/math_test.cc
utils/tmpfile_test.cc utils/tmpfile_test.cc

View File

@ -15,6 +15,7 @@
#include "src/clone_context.h" #include "src/clone_context.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::Cloneable); TINT_INSTANTIATE_TYPEINFO(tint::Cloneable);
@ -27,11 +28,16 @@ CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
: dst(to), src(from) {} : dst(to), src(from) {}
CloneContext::~CloneContext() = default; CloneContext::~CloneContext() = default;
Symbol CloneContext::Clone(const Symbol& s) const { Symbol CloneContext::Clone(Symbol s) {
if (symbol_transform_) { return utils::GetOrCreate(cloned_symbols_, s, [&]() -> Symbol {
return symbol_transform_(s); if (symbol_transform_) {
} return symbol_transform_(s);
return dst->Symbols().Register(src->Symbols().NameFor(s)); }
if (!src->Symbols().HasName(s)) {
return dst->Symbols().New();
}
return dst->Symbols().Register(src->Symbols().NameFor(s));
});
} }
void CloneContext::Clone() { void CloneContext::Clone() {

View File

@ -148,7 +148,7 @@ class CloneContext {
/// ///
/// @param s the Symbol to clone /// @param s the Symbol to clone
/// @return the cloned source /// @return the cloned source
Symbol Clone(const Symbol& s) const; Symbol Clone(Symbol s);
/// Clones each of the elements of the vector `v` into the ProgramBuilder /// Clones each of the elements of the vector `v` into the ProgramBuilder
/// #dst. /// #dst.
@ -448,6 +448,9 @@ class CloneContext {
/// A map of object in #src to their cloned equivalent in #dst /// A map of object in #src to their cloned equivalent in #dst
std::unordered_map<const Cloneable*, Cloneable*> cloned_; std::unordered_map<const Cloneable*, Cloneable*> cloned_;
/// A map of symbol in #src to their cloned equivalent in #dst
std::unordered_map<Symbol, Symbol> cloned_symbols_;
/// Cloneable transform functions registered with ReplaceAll() /// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_; std::vector<CloneableTransform> transforms_;

View File

@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gtest/gtest-spi.h" #include <unordered_set>
#include "gtest/gtest-spi.h"
#include "src/program_builder.h" #include "src/program_builder.h"
namespace tint { namespace tint {
@ -416,6 +417,27 @@ TEST(CloneContext, CloneWithReplace_WithNotANode) {
"internal compiler error"); "internal compiler error");
} }
TEST(CloneContext, CloneUnnamedSymbols) {
ProgramBuilder builder;
Symbol old_a = builder.Symbols().New();
Symbol old_b = builder.Symbols().New();
Symbol old_c = builder.Symbols().New();
Program original(std::move(builder));
ProgramBuilder cloned;
CloneContext ctx(&cloned, &original);
Symbol new_a = ctx.Clone(old_a);
Symbol new_x = cloned.Symbols().New();
Symbol new_b = ctx.Clone(old_b);
Symbol new_y = cloned.Symbols().New();
Symbol new_c = ctx.Clone(old_c);
Symbol new_z = cloned.Symbols().New();
std::unordered_set<Symbol> all{new_a, new_x, new_b, new_y, new_c, new_z};
EXPECT_EQ(all.size(), 6u);
}
} // namespace } // namespace
TINT_INSTANTIATE_TYPEINFO(Node); TINT_INSTANTIATE_TYPEINFO(Node);

View File

@ -413,7 +413,7 @@ class ProgramBuilder {
/// @param subtype the array element type /// @param subtype the array element type
/// @param n the array size. 0 represents a runtime-array. /// @param n the array size. 0 represents a runtime-array.
/// @return the tint AST type for a array of size `n` of type `T` /// @return the tint AST type for a array of size `n` of type `T`
type::Array* array(type::Type* subtype, uint32_t n) const { type::Array* array(type::Type* subtype, uint32_t n = 0) const {
return builder->create<type::Array>(subtype, n, ast::DecorationList{}); return builder->create<type::Array>(subtype, n, ast::DecorationList{});
} }
@ -490,6 +490,14 @@ class ProgramBuilder {
// AST helper methods // AST helper methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// @param name the symbol string
/// @return a Symbol with the given name
Symbol Sym(const std::string& name) { return Symbols().Register(name); }
/// @param sym the symbol
/// @return `sym`
Symbol Sym(Symbol sym) { return sym; }
/// @param expr the expression /// @param expr the expression
/// @return expr /// @return expr
template <typename T> template <typename T>
@ -775,13 +783,14 @@ class ProgramBuilder {
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param decorations variable decorations /// @param decorations variable decorations
/// @returns a `ast::Variable` with the given name, storage and type /// @returns a `ast::Variable` with the given name, storage and type
ast::Variable* Var(const std::string& name, template <typename NAME>
ast::Variable* Var(NAME&& name,
type::Type* type, type::Type* type,
ast::StorageClass storage, ast::StorageClass storage,
ast::Expression* constructor = nullptr, ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) { ast::DecorationList decorations = {}) {
return create<ast::Variable>(Symbols().Register(name), storage, type, false, return create<ast::Variable>(Sym(std::forward<NAME>(name)), storage, type,
constructor, decorations); false, constructor, decorations);
} }
/// @param source the variable source /// @param source the variable source
@ -791,58 +800,28 @@ class ProgramBuilder {
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param decorations variable decorations /// @param decorations variable decorations
/// @returns a `ast::Variable` with the given name, storage and type /// @returns a `ast::Variable` with the given name, storage and type
template <typename NAME>
ast::Variable* Var(const Source& source, ast::Variable* Var(const Source& source,
const std::string& name, NAME&& name,
type::Type* type, type::Type* type,
ast::StorageClass storage, ast::StorageClass storage,
ast::Expression* constructor = nullptr, ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) { ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, Symbols().Register(name), storage, return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), storage,
type, false, constructor, decorations); type, false, constructor, decorations);
} }
/// @param symbol the variable symbol
/// @param type the variable type
/// @param storage the variable storage class
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given symbol, storage and type
ast::Variable* Var(Symbol symbol,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(symbol, storage, type, false, constructor,
decorations);
}
/// @param source the variable source
/// @param symbol the variable symbol
/// @param type the variable type
/// @param storage the variable storage class
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given symbol, storage and type
ast::Variable* Var(const Source& source,
Symbol symbol,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, symbol, storage, type, false,
constructor, decorations);
}
/// @param name the variable name /// @param name the variable name
/// @param type the variable type /// @param type the variable type
/// @param constructor optional constructor expression /// @param constructor optional constructor expression
/// @param decorations optional variable decorations /// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given name, storage and type /// @returns a constant `ast::Variable` with the given name, storage and type
ast::Variable* Const(const std::string& name, template <typename NAME>
ast::Variable* Const(NAME&& name,
type::Type* type, type::Type* type,
ast::Expression* constructor = nullptr, ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) { ast::DecorationList decorations = {}) {
return create<ast::Variable>(Symbols().Register(name), return create<ast::Variable>(Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, type, true, ast::StorageClass::kNone, type, true,
constructor, decorations); constructor, decorations);
} }
@ -853,46 +832,17 @@ class ProgramBuilder {
/// @param constructor optional constructor expression /// @param constructor optional constructor expression
/// @param decorations optional variable decorations /// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given name, storage and type /// @returns a constant `ast::Variable` with the given name, storage and type
template <typename NAME>
ast::Variable* Const(const Source& source, ast::Variable* Const(const Source& source,
const std::string& name, NAME&& name,
type::Type* type, type::Type* type,
ast::Expression* constructor = nullptr, ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) { ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, Symbols().Register(name), return create<ast::Variable>(source, Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, type, true, ast::StorageClass::kNone, type, true,
constructor, decorations); constructor, decorations);
} }
/// @param symbol the variable symbol
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given symbol, storage and
/// type
ast::Variable* Const(Symbol symbol,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(symbol, ast::StorageClass::kNone, type, true,
constructor, decorations);
}
/// @param source the variable source
/// @param symbol the variable symbol
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given symbol, storage and
/// type
ast::Variable* Const(const Source& source,
Symbol symbol,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, symbol, ast::StorageClass::kNone, type,
true, constructor, decorations);
}
/// @param args the arguments to pass to Var() /// @param args the arguments to pass to Var()
/// @returns a `ast::Variable` constructed by calling Var() with the arguments /// @returns a `ast::Variable` constructed by calling Var() with the arguments
/// of `args`, which is automatically registered as a global variable with the /// of `args`, which is automatically registered as a global variable with the
@ -966,6 +916,16 @@ class ProgramBuilder {
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
} }
/// @param lhs the left hand argument to the division operation
/// @param rhs the right hand argument to the division operation
/// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs`
template <typename LHS, typename RHS>
ast::Expression* Div(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kDivide,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param arr the array argument for the array accessor expression /// @param arr the array argument for the array accessor expression
/// @param idx the index argument for the array accessor expression /// @param idx the index argument for the array accessor expression
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx` /// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
@ -1027,19 +987,22 @@ class ProgramBuilder {
/// @param params the function parameters /// @param params the function parameters
/// @param type the function return type /// @param type the function return type
/// @param body the function body /// @param body the function body
/// @param decorations the function decorations /// @param decorations the optional function decorations
/// @param return_type_decorations the function return type decorations /// @param return_type_decorations the optional function return type
/// decorations
/// @returns the function pointer /// @returns the function pointer
template <typename NAME>
ast::Function* Func(Source source, ast::Function* Func(Source source,
std::string name, NAME&& name,
ast::VariableList params, ast::VariableList params,
type::Type* type, type::Type* type,
ast::StatementList body, ast::StatementList body,
ast::DecorationList decorations = {}, ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) { ast::DecorationList return_type_decorations = {}) {
auto* func = create<ast::Function>(source, Symbols().Register(name), params, auto* func =
type, create<ast::BlockStatement>(body), create<ast::Function>(source, Sym(std::forward<NAME>(name)), params,
decorations, return_type_decorations); type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func); AST().AddFunction(func);
return func; return func;
} }
@ -1049,17 +1012,19 @@ class ProgramBuilder {
/// @param params the function parameters /// @param params the function parameters
/// @param type the function return type /// @param type the function return type
/// @param body the function body /// @param body the function body
/// @param decorations the function decorations /// @param decorations the optional function decorations
/// @param return_type_decorations the function return type decorations /// @param return_type_decorations the optional function return type
/// decorations
/// @returns the function pointer /// @returns the function pointer
ast::Function* Func(std::string name, template <typename NAME>
ast::Function* Func(NAME&& name,
ast::VariableList params, ast::VariableList params,
type::Type* type, type::Type* type,
ast::StatementList body, ast::StatementList body,
ast::DecorationList decorations = {}, ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) { ast::DecorationList return_type_decorations = {}) {
auto* func = create<ast::Function>(Symbols().Register(name), params, type, auto* func = create<ast::Function>(Sym(std::forward<NAME>(name)), params,
create<ast::BlockStatement>(body), type, create<ast::BlockStatement>(body),
decorations, return_type_decorations); decorations, return_type_decorations);
AST().AddFunction(func); AST().AddFunction(func);
return func; return func;
@ -1113,12 +1078,13 @@ class ProgramBuilder {
/// @param type the struct member type /// @param type the struct member type
/// @param decorations the optional struct member decorations /// @param decorations the optional struct member decorations
/// @returns the struct member pointer /// @returns the struct member pointer
template <typename NAME>
ast::StructMember* Member(const Source& source, ast::StructMember* Member(const Source& source,
const std::string& name, NAME&& name,
type::Type* type, type::Type* type,
ast::DecorationList decorations = {}) { ast::DecorationList decorations = {}) {
return create<ast::StructMember>(source, Symbols().Register(name), type, return create<ast::StructMember>(source, Sym(std::forward<NAME>(name)),
std::move(decorations)); type, std::move(decorations));
} }
/// Creates a ast::StructMember /// Creates a ast::StructMember
@ -1126,11 +1092,12 @@ class ProgramBuilder {
/// @param type the struct member type /// @param type the struct member type
/// @param decorations the optional struct member decorations /// @param decorations the optional struct member decorations
/// @returns the struct member pointer /// @returns the struct member pointer
ast::StructMember* Member(const std::string& name, template <typename NAME>
ast::StructMember* Member(NAME&& name,
type::Type* type, type::Type* type,
ast::DecorationList decorations = {}) { ast::DecorationList decorations = {}) {
return create<ast::StructMember>(source_, Symbols().Register(name), type, return create<ast::StructMember>(source_, Sym(std::forward<NAME>(name)),
std::move(decorations)); type, std::move(decorations));
} }
/// Creates a ast::StructMember with the given byte offset /// Creates a ast::StructMember with the given byte offset
@ -1138,11 +1105,10 @@ class ProgramBuilder {
/// @param name the struct member name /// @param name the struct member name
/// @param type the struct member type /// @param type the struct member type
/// @returns the struct member pointer /// @returns the struct member pointer
ast::StructMember* Member(uint32_t offset, template <typename NAME>
const std::string& name, ast::StructMember* Member(uint32_t offset, NAME&& name, type::Type* type) {
type::Type* type) {
return create<ast::StructMember>( return create<ast::StructMember>(
source_, Symbols().Register(name), type, source_, Sym(std::forward<NAME>(name)), type,
ast::DecorationList{ ast::DecorationList{
create<ast::StructMemberOffsetDecoration>(offset), create<ast::StructMemberOffsetDecoration>(offset),
}); });

View File

@ -50,6 +50,11 @@ Symbol SymbolTable::Get(const std::string& name) const {
return it != name_to_symbol_.end() ? it->second : Symbol(); return it != name_to_symbol_.end() ? it->second : Symbol();
} }
bool SymbolTable::HasName(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol);
return it != symbol_to_name_.end();
}
std::string SymbolTable::NameFor(const Symbol symbol) const { std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol); auto it = symbol_to_name_.find(symbol);
if (it == symbol_to_name_.end()) { if (it == symbol_to_name_.end()) {

View File

@ -53,6 +53,10 @@ class SymbolTable {
/// @returns the symbol for the name or symbol::kInvalid if not found. /// @returns the symbol for the name or symbol::kInvalid if not found.
Symbol Get(const std::string& name) const; Symbol Get(const std::string& name) const;
/// @returns true if the symbol has a name
/// @param symbol the symbol to query
bool HasName(const Symbol symbol) const;
/// Returns the name for the given symbol /// Returns the name for the given symbol
/// @param symbol the symbol to retrieve the name for /// @param symbol the symbol to retrieve the name for
/// @returns the symbol name or "" if not found /// @returns the symbol name or "" if not found

View File

@ -45,7 +45,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>(); ->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
}); });
new_struct_members.push_back( new_struct_members.push_back(
ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()), ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations)); ctx.Clone(member->type()), new_decorations));
} }
@ -70,10 +70,9 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
auto new_struct_param_symbol = ctx.dst->Symbols().New(); auto new_struct_param_symbol = ctx.dst->Symbols().New();
ast::StructMemberList new_struct_members; ast::StructMemberList new_struct_members;
for (auto* param : func->params()) { for (auto* param : func->params()) {
auto param_name = ctx.src->Symbols().NameFor(param->symbol()); auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type(); auto* param_ty = ctx.src->Sem().Get(param)->Type();
auto func_const_symbol = ctx.dst->Symbols().Register(param_name);
ast::Expression* func_const_initializer = nullptr; ast::Expression* func_const_initializer = nullptr;
if (auto* struct_ty = if (auto* struct_ty =
@ -90,7 +89,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
return !deco->IsAnyOf<ast::BuiltinDecoration, return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>(); ast::LocationDecoration>();
}); });
auto member_name = ctx.src->Symbols().NameFor(member->symbol()); auto member_name = ctx.Clone(member->symbol());
new_struct_members.push_back(ctx.dst->Member( new_struct_members.push_back(ctx.dst->Member(
member_name, ctx.Clone(member->type()), new_decorations)); member_name, ctx.Clone(member->type()), new_decorations));
init_values.push_back( init_values.push_back(
@ -118,15 +117,15 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
// Create a function-scope const to replace the parameter. // Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the new struct parameter. // Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const( auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
func_const_symbol, ctx.Clone(param_ty), func_const_initializer); func_const_initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(), ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const)); ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const. // Replace all uses of the function parameter with the function const.
for (auto* user : ctx.src->Sem().Get(param)->Users()) { for (auto* user : ctx.src->Sem().Get(param)->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(), ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(func_const_symbol)); ctx.dst->Expr(param_name));
} }
} }
@ -163,9 +162,9 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
return !deco->IsAnyOf<ast::BuiltinDecoration, return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>(); ast::LocationDecoration>();
}); });
auto member_name = ctx.src->Symbols().NameFor(member->symbol()); new_struct_members.push_back(
new_struct_members.push_back(ctx.dst->Member( ctx.dst->Member(ctx.Clone(member->symbol()),
member_name, ctx.Clone(member->type()), new_decorations)); ctx.Clone(member->type()), new_decorations));
} }
} else { } else {
new_struct_members.push_back( new_struct_members.push_back(

View File

@ -121,7 +121,7 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>(); ->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
}); });
new_struct_members.push_back( new_struct_members.push_back(
ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()), ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations)); ctx.Clone(member->type()), new_decorations));
} }
@ -215,7 +215,7 @@ void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
} }
// Use the same name as the old variable. // Use the same name as the old variable.
std::string var_name = ctx.src->Symbols().NameFor(var->symbol()); auto var_name = ctx.Clone(var->symbol());
// Use `array<u32, 1>` for the new variable. // Use `array<u32, 1>` for the new variable.
auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u); auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
// Create the new variable. // Create the new variable.

View File

@ -74,7 +74,7 @@ fn frag_main([[location(1)]] loc1 : myf32) -> void {
auto* expect = R"( auto* expect = R"(
type myf32 = f32; type myf32 = f32;
[[location(1)]] var<in> tint_symbol_1 : myf32; [[location(1)]] var<in> tint_symbol_2 : myf32;
[[stage(fragment)]] [[stage(fragment)]]
fn frag_main() -> void { fn frag_main() -> void {
@ -95,15 +95,15 @@ fn vert_main() -> [[builtin(position)]] vec4<f32> {
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(position)]] var<out> tint_symbol_2 : vec4<f32>; [[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_3(tint_symbol_1 : vec4<f32>) -> void { fn tint_symbol_2(tint_symbol_3 : vec4<f32>) -> void {
tint_symbol_2 = tint_symbol_1; tint_symbol_1 = tint_symbol_3;
} }
[[stage(vertex)]] [[stage(vertex)]]
fn vert_main() -> void { fn vert_main() -> void {
tint_symbol_3(vec4<f32>(1.0, 2.0, 3.0, 0.0)); tint_symbol_2(vec4<f32>(1.0, 2.0, 3.0, 0.0));
return; return;
} }
)"; )";
@ -127,19 +127,19 @@ fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
auto* expect = R"( auto* expect = R"(
[[location(0)]] var<in> tint_symbol_1 : u32; [[location(0)]] var<in> tint_symbol_1 : u32;
[[location(0)]] var<out> tint_symbol_3 : f32; [[location(0)]] var<out> tint_symbol_2 : f32;
fn tint_symbol_4(tint_symbol_2 : f32) -> void { fn tint_symbol_3(tint_symbol_4 : f32) -> void {
tint_symbol_3 = tint_symbol_2; tint_symbol_2 = tint_symbol_4;
} }
[[stage(fragment)]] [[stage(fragment)]]
fn frag_main() -> void { fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) { if ((tint_symbol_1 > 10u)) {
tint_symbol_4(0.5); tint_symbol_3(0.5);
return; return;
} }
tint_symbol_4(1.0); tint_symbol_3(1.0);
return; return;
} }
)"; )";
@ -165,21 +165,21 @@ fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
auto* expect = R"( auto* expect = R"(
type myf32 = f32; type myf32 = f32;
[[location(0)]] var<in> tint_symbol_1 : u32; [[location(0)]] var<in> tint_symbol_2 : u32;
[[location(0)]] var<out> tint_symbol_3 : myf32; [[location(0)]] var<out> tint_symbol_3 : myf32;
fn tint_symbol_5(tint_symbol_2 : myf32) -> void { fn tint_symbol_4(tint_symbol_5 : myf32) -> void {
tint_symbol_3 = tint_symbol_2; tint_symbol_3 = tint_symbol_5;
} }
[[stage(fragment)]] [[stage(fragment)]]
fn frag_main() -> void { fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) { if ((tint_symbol_2 > 10u)) {
tint_symbol_5(0.5); tint_symbol_4(0.5);
return; return;
} }
tint_symbol_5(1.0); tint_symbol_4(1.0);
return; return;
} }
)"; )";
@ -214,8 +214,8 @@ struct FragmentInput {
[[stage(fragment)]] [[stage(fragment)]]
fn frag_main() -> void { fn frag_main() -> void {
const tint_symbol_6 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5); const tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5);
var col : f32 = (tint_symbol_6.coord.x * tint_symbol_6.value); var col : f32 = (tint_symbol_7.coord.x * tint_symbol_7.value);
} }
)"; )";
@ -275,23 +275,23 @@ struct VertexOutput {
value : f32; value : f32;
}; };
[[builtin(position)]] var<out> tint_symbol_5 : vec4<f32>; [[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>;
[[location(1)]] var<out> tint_symbol_6 : f32; [[location(1)]] var<out> tint_symbol_5 : f32;
fn tint_symbol_7(tint_symbol_4 : VertexOutput) -> void { fn tint_symbol_6(tint_symbol_7 : VertexOutput) -> void {
tint_symbol_5 = tint_symbol_4.pos; tint_symbol_4 = tint_symbol_7.pos;
tint_symbol_6 = tint_symbol_4.value; tint_symbol_5 = tint_symbol_7.value;
} }
[[stage(vertex)]] [[stage(vertex)]]
fn vert_main() -> void { fn vert_main() -> void {
if (false) { if (false) {
tint_symbol_7(VertexOutput()); tint_symbol_6(VertexOutput());
return; return;
} }
var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0); var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
tint_symbol_7(VertexOutput(pos, 2.0)); tint_symbol_6(VertexOutput(pos, 2.0));
return; return;
} }
)"; )";
@ -320,16 +320,16 @@ struct Interface {
[[location(1)]] var<in> tint_symbol_3 : f32; [[location(1)]] var<in> tint_symbol_3 : f32;
[[location(1)]] var<out> tint_symbol_6 : f32; [[location(1)]] var<out> tint_symbol_4 : f32;
fn tint_symbol_7(tint_symbol_5 : Interface) -> void { fn tint_symbol_5(tint_symbol_6 : Interface) -> void {
tint_symbol_6 = tint_symbol_5.value; tint_symbol_4 = tint_symbol_6.value;
} }
[[stage(vertex)]] [[stage(vertex)]]
fn vert_main() -> void { fn vert_main() -> void {
const tint_symbol_4 : Interface = Interface(tint_symbol_3); const tint_symbol_8 : Interface = Interface(tint_symbol_3);
tint_symbol_7(tint_symbol_4); tint_symbol_5(tint_symbol_8);
return; return;
} }
)"; )";
@ -361,15 +361,15 @@ struct Interface {
value : f32; value : f32;
}; };
[[location(1)]] var<out> tint_symbol_4 : f32; [[location(1)]] var<out> tint_symbol_3 : f32;
fn tint_symbol_5(tint_symbol_3 : Interface) -> void { fn tint_symbol_4(tint_symbol_5 : Interface) -> void {
tint_symbol_4 = tint_symbol_3.value; tint_symbol_3 = tint_symbol_5.value;
} }
[[stage(vertex)]] [[stage(vertex)]]
fn vert_main() -> void { fn vert_main() -> void {
tint_symbol_5(Interface(42.0)); tint_symbol_4(Interface(42.0));
return; return;
} }
@ -377,8 +377,8 @@ fn vert_main() -> void {
[[stage(fragment)]] [[stage(fragment)]]
fn frag_main() -> void { fn frag_main() -> void {
const tint_symbol_8 : Interface = Interface(tint_symbol_7); const tint_symbol_9 : Interface = Interface(tint_symbol_7);
var x : f32 = tint_symbol_8.value; var x : f32 = tint_symbol_9.value;
} }
)"; )";
@ -423,16 +423,16 @@ struct FragmentOutput {
[[builtin(frag_coord)]] var<in> tint_symbol_6 : vec4<f32>; [[builtin(frag_coord)]] var<in> tint_symbol_6 : vec4<f32>;
[[location(1)]] var<out> tint_symbol_9 : f32; [[location(1)]] var<out> tint_symbol_7 : f32;
fn tint_symbol_10(tint_symbol_8 : FragmentOutput) -> void { fn tint_symbol_8(tint_symbol_9 : FragmentOutput) -> void {
tint_symbol_9 = tint_symbol_8.value; tint_symbol_7 = tint_symbol_9.value;
} }
[[stage(fragment)]] [[stage(fragment)]]
fn frag_main() -> void { fn frag_main() -> void {
const tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6); const tint_symbol_11 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6);
tint_symbol_10(FragmentOutput((tint_symbol_7.coord.x * tint_symbol_7.value))); tint_symbol_8(FragmentOutput((tint_symbol_11.coord.x * tint_symbol_11.value)));
return; return;
} }
)"; )";
@ -467,8 +467,8 @@ struct VertexOutput {
[[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>; [[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>;
fn tint_symbol_5(tint_symbol_3 : VertexOutput) -> void { fn tint_symbol_5(tint_symbol_6 : VertexOutput) -> void {
tint_symbol_4 = tint_symbol_3.Position; tint_symbol_4 = tint_symbol_6.Position;
} }
[[stage(vertex)]] [[stage(vertex)]]
@ -585,19 +585,19 @@ fn main([[builtin(sample_index)]] sample_index : u32,
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(sample_index)]] var<in> tint_symbol_1 : u32; [[builtin(sample_index)]] var<in> tint_symbol_3 : u32;
[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>; [[builtin(sample_mask_in)]] var<in> tint_symbol_1 : array<u32, 1>;
[[builtin(sample_mask_out)]] var<out> tint_symbol_4 : array<u32, 1>; [[builtin(sample_mask_out)]] var<out> tint_symbol_2 : array<u32, 1>;
fn tint_symbol_5(tint_symbol_3 : u32) -> void { fn tint_symbol_4(tint_symbol_5 : u32) -> void {
tint_symbol_4[0] = tint_symbol_3; tint_symbol_2[0] = tint_symbol_5;
} }
[[stage(fragment)]] [[stage(fragment)]]
fn main() -> void { fn main() -> void {
tint_symbol_5(tint_symbol_2[0]); tint_symbol_4(tint_symbol_1[0]);
return; return;
} }
)"; )";

44
src/utils/get_or_create.h Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_UTILS_GET_OR_CREATE_H_
#define SRC_UTILS_GET_OR_CREATE_H_
#include <unordered_map>
namespace tint {
namespace utils {
/// GetOrCreate is a utility function for lazily adding to an unordered map.
/// If the map already contains the key `key` then this is returned, otherwise
/// `create()` is called and the result is added to the map and is returned.
/// @param map the unordered_map
/// @param key the map key of the item to query or add
/// @param create a callable function-like object with the signature `V()`
/// @return the value of the item with the given key, or the newly created item
template <typename K, typename V, typename CREATE, typename H>
V GetOrCreate(std::unordered_map<K, V, H>& map, K key, CREATE&& create) {
auto it = map.find(key);
if (it != map.end()) {
return it->second;
}
V value = create();
map.emplace(key, value);
return value;
}
} // namespace utils
} // namespace tint
#endif // SRC_UTILS_GET_OR_CREATE_H_

View File

@ -0,0 +1,49 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/utils/get_or_create.h"
#include <unordered_map>
#include "gtest/gtest.h"
namespace tint {
namespace utils {
namespace {
TEST(GetOrCreateTest, NewKey) {
std::unordered_map<int, int> map;
EXPECT_EQ(GetOrCreate(map, 1, [&] { return 2; }), 2);
EXPECT_EQ(map.size(), 1u);
EXPECT_EQ(map[1], 2);
}
TEST(GetOrCreateTest, ExistingKey) {
std::unordered_map<int, int> map;
map[1] = 2;
bool called = false;
EXPECT_EQ(GetOrCreate(map, 1,
[&] {
called = true;
return -2;
}),
2);
EXPECT_EQ(called, false);
EXPECT_EQ(map.size(), 1u);
EXPECT_EQ(map[1], 2);
}
} // namespace
} // namespace utils
} // namespace tint

View File

@ -124,16 +124,16 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error(); ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct tint_symbol_3 { EXPECT_EQ(result(), R"(struct tint_symbol_1 {
float foo : TEXCOORD0; float foo : TEXCOORD0;
}; };
struct tint_symbol_5 { struct tint_symbol_3 {
float value : SV_Target1; float value : SV_Target1;
}; };
tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) { tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
const float foo = tint_symbol_1.foo; const float foo = tint_symbol_6.foo;
return tint_symbol_5(foo); return tint_symbol_3(foo);
} }
)"); )");
@ -157,16 +157,16 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error(); ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct tint_symbol_3 { EXPECT_EQ(result(), R"(struct tint_symbol_1 {
float4 coord : SV_Position; float4 coord : SV_Position;
}; };
struct tint_symbol_5 { struct tint_symbol_3 {
float value : SV_Depth; float value : SV_Depth;
}; };
tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) { tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
const float4 coord = tint_symbol_1.coord; const float4 coord = tint_symbol_6.coord;
return tint_symbol_5(coord.x); return tint_symbol_3(coord.x);
} }
)"); )");
@ -217,18 +217,18 @@ struct tint_symbol_4 {
float col1 : TEXCOORD1; float col1 : TEXCOORD1;
float col2 : TEXCOORD2; float col2 : TEXCOORD2;
}; };
struct tint_symbol_9 { struct tint_symbol_7 {
float col1 : TEXCOORD1; float col1 : TEXCOORD1;
float col2 : TEXCOORD2; float col2 : TEXCOORD2;
}; };
tint_symbol_4 vert_main() { tint_symbol_4 vert_main() {
const Interface tint_symbol_5 = Interface(0.5f, 0.25f); const Interface tint_symbol_6 = Interface(0.5f, 0.25f);
return tint_symbol_4(tint_symbol_5.col1, tint_symbol_5.col2); return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2);
} }
void frag_main(tint_symbol_9 tint_symbol_7) { void frag_main(tint_symbol_7 tint_symbol_9) {
const Interface colors = Interface(tint_symbol_7.col1, tint_symbol_7.col2); const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2);
const float r = colors.col1; const float r = colors.col1;
const float g = colors.col2; const float g = colors.col2;
return; return;
@ -281,10 +281,10 @@ TEST_F(HlslGeneratorImplTest_Function,
EXPECT_EQ(result(), R"(struct VertexOutput { EXPECT_EQ(result(), R"(struct VertexOutput {
float4 pos; float4 pos;
}; };
struct tint_symbol_3 { struct tint_symbol_5 {
float4 pos : SV_Position; float4 pos : SV_Position;
}; };
struct tint_symbol_7 { struct tint_symbol_8 {
float4 pos : SV_Position; float4 pos : SV_Position;
}; };
@ -292,14 +292,14 @@ VertexOutput foo(float x) {
return VertexOutput(float4(x, x, x, 1.0f)); return VertexOutput(float4(x, x, x, 1.0f));
} }
tint_symbol_3 vert_main1() { tint_symbol_5 vert_main1() {
const VertexOutput tint_symbol_5 = VertexOutput(foo(0.5f)); const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f));
return tint_symbol_3(tint_symbol_5.pos); return tint_symbol_5(tint_symbol_7.pos);
} }
tint_symbol_7 vert_main2() { tint_symbol_8 vert_main2() {
const VertexOutput tint_symbol_8 = VertexOutput(foo(0.25f)); const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f));
return tint_symbol_7(tint_symbol_8.pos); return tint_symbol_8(tint_symbol_10.pos);
} }
)"); )");

View File

@ -134,9 +134,9 @@ OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %14 "frag_main" %1 %4 OpEntryPoint Fragment %14 "frag_main" %1 %4
OpExecutionMode %14 OriginUpperLeft OpExecutionMode %14 OriginUpperLeft
OpName %1 "tint_symbol_1" OpName %1 "tint_symbol_1"
OpName %4 "tint_symbol_3" OpName %4 "tint_symbol_2"
OpName %10 "tint_symbol_4" OpName %10 "tint_symbol_3"
OpName %11 "tint_symbol_2" OpName %11 "tint_symbol_4"
OpName %14 "frag_main" OpName %14 "frag_main"
OpDecorate %1 Location 0 OpDecorate %1 Location 0
OpDecorate %4 Location 0 OpDecorate %4 Location 0
@ -220,16 +220,16 @@ OpEntryPoint Vertex %16 "vert_main" %1
OpEntryPoint Fragment %25 "frag_main" %5 %7 OpEntryPoint Fragment %25 "frag_main" %5 %7
OpExecutionMode %25 OriginUpperLeft OpExecutionMode %25 OriginUpperLeft
OpExecutionMode %25 DepthReplacing OpExecutionMode %25 DepthReplacing
OpName %1 "tint_symbol_4" OpName %1 "tint_symbol_3"
OpName %5 "tint_symbol_7" OpName %5 "tint_symbol_7"
OpName %7 "tint_symbol_10" OpName %7 "tint_symbol_8"
OpName %10 "Interface" OpName %10 "Interface"
OpMemberName %10 0 "value" OpMemberName %10 0 "value"
OpName %11 "tint_symbol_5" OpName %11 "tint_symbol_4"
OpName %12 "tint_symbol_3" OpName %12 "tint_symbol_5"
OpName %16 "vert_main" OpName %16 "vert_main"
OpName %22 "tint_symbol_11" OpName %22 "tint_symbol_9"
OpName %23 "tint_symbol_9" OpName %23 "tint_symbol_10"
OpName %25 "frag_main" OpName %25 "frag_main"
OpDecorate %1 Location 1 OpDecorate %1 Location 1
OpDecorate %5 Location 1 OpDecorate %5 Location 1

View File

@ -220,6 +220,7 @@ source_set("tint_unittests_core_src") {
"../src/type/vector_type_test.cc", "../src/type/vector_type_test.cc",
"../src/utils/command.h", "../src/utils/command.h",
"../src/utils/command_test.cc", "../src/utils/command_test.cc",
"../src/utils/get_or_create_test.cc",
"../src/utils/hash_test.cc", "../src/utils/hash_test.cc",
"../src/utils/math_test.cc", "../src/utils/math_test.cc",
"../src/utils/tmpfile.h", "../src/utils/tmpfile.h",