CloneContext: Don't create named symbols from unnamed

Registering a new Symbol with the NameFor() of the source symbol creates
a new *named* symbol. When mixing these with unnamed symbols we can have
collisions.

Update CloneContext::Clone(Symbol) to properly clone unnamed symbols.

Update (most) the transforms to ctx.Clone() the symbols instead of
registering the names directly.

Fix up the tests where the symbol IDs have changed.

Note: We can still have symbol collisions if a program is authored with
identifiers like 'tint_symbol_3'. This will be fixed up in a later
change.

Change-Id: I0ce559644da3d60e1060f2eef185fa55ae284521
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46866
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-04-07 11:16:01 +00:00 committed by Commit Bot service account
parent 3bfb6817df
commit 1b8d9f227b
16 changed files with 295 additions and 193 deletions

View File

@ -457,6 +457,7 @@ source_set("libtint_core_src") {
"type/vector_type.h",
"type/void_type.cc",
"type/void_type.h",
"utils/get_or_create.h",
"utils/hash.h",
"utils/math.h",
"utils/unique_vector.h",

View File

@ -272,6 +272,7 @@ set(TINT_LIB_SRCS
type/vector_type.h
type/void_type.cc
type/void_type.h
utils/get_or_create.h
utils/hash.h
utils/math.h
utils/unique_vector.h
@ -519,6 +520,7 @@ if(${TINT_BUILD_TESTS})
type/vector_type_test.cc
utils/command_test.cc
utils/command.h
utils/get_or_create_test.cc
utils/hash_test.cc
utils/math_test.cc
utils/tmpfile_test.cc

View File

@ -15,6 +15,7 @@
#include "src/clone_context.h"
#include "src/program_builder.h"
#include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::Cloneable);
@ -27,11 +28,16 @@ CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
: dst(to), src(from) {}
CloneContext::~CloneContext() = default;
Symbol CloneContext::Clone(const Symbol& s) const {
if (symbol_transform_) {
return symbol_transform_(s);
}
return dst->Symbols().Register(src->Symbols().NameFor(s));
Symbol CloneContext::Clone(Symbol s) {
return utils::GetOrCreate(cloned_symbols_, s, [&]() -> Symbol {
if (symbol_transform_) {
return symbol_transform_(s);
}
if (!src->Symbols().HasName(s)) {
return dst->Symbols().New();
}
return dst->Symbols().Register(src->Symbols().NameFor(s));
});
}
void CloneContext::Clone() {

View File

@ -148,7 +148,7 @@ class CloneContext {
///
/// @param s the Symbol to clone
/// @return the cloned source
Symbol Clone(const Symbol& s) const;
Symbol Clone(Symbol s);
/// Clones each of the elements of the vector `v` into the ProgramBuilder
/// #dst.
@ -448,6 +448,9 @@ class CloneContext {
/// A map of object in #src to their cloned equivalent in #dst
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
/// A map of symbol in #src to their cloned equivalent in #dst
std::unordered_map<Symbol, Symbol> cloned_symbols_;
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;

View File

@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest-spi.h"
#include <unordered_set>
#include "gtest/gtest-spi.h"
#include "src/program_builder.h"
namespace tint {
@ -416,6 +417,27 @@ TEST(CloneContext, CloneWithReplace_WithNotANode) {
"internal compiler error");
}
TEST(CloneContext, CloneUnnamedSymbols) {
ProgramBuilder builder;
Symbol old_a = builder.Symbols().New();
Symbol old_b = builder.Symbols().New();
Symbol old_c = builder.Symbols().New();
Program original(std::move(builder));
ProgramBuilder cloned;
CloneContext ctx(&cloned, &original);
Symbol new_a = ctx.Clone(old_a);
Symbol new_x = cloned.Symbols().New();
Symbol new_b = ctx.Clone(old_b);
Symbol new_y = cloned.Symbols().New();
Symbol new_c = ctx.Clone(old_c);
Symbol new_z = cloned.Symbols().New();
std::unordered_set<Symbol> all{new_a, new_x, new_b, new_y, new_c, new_z};
EXPECT_EQ(all.size(), 6u);
}
} // namespace
TINT_INSTANTIATE_TYPEINFO(Node);

View File

@ -413,7 +413,7 @@ class ProgramBuilder {
/// @param subtype the array element type
/// @param n the array size. 0 represents a runtime-array.
/// @return the tint AST type for a array of size `n` of type `T`
type::Array* array(type::Type* subtype, uint32_t n) const {
type::Array* array(type::Type* subtype, uint32_t n = 0) const {
return builder->create<type::Array>(subtype, n, ast::DecorationList{});
}
@ -490,6 +490,14 @@ class ProgramBuilder {
// AST helper methods
//////////////////////////////////////////////////////////////////////////////
/// @param name the symbol string
/// @return a Symbol with the given name
Symbol Sym(const std::string& name) { return Symbols().Register(name); }
/// @param sym the symbol
/// @return `sym`
Symbol Sym(Symbol sym) { return sym; }
/// @param expr the expression
/// @return expr
template <typename T>
@ -775,13 +783,14 @@ class ProgramBuilder {
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given name, storage and type
ast::Variable* Var(const std::string& name,
template <typename NAME>
ast::Variable* Var(NAME&& name,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(Symbols().Register(name), storage, type, false,
constructor, decorations);
return create<ast::Variable>(Sym(std::forward<NAME>(name)), storage, type,
false, constructor, decorations);
}
/// @param source the variable source
@ -791,58 +800,28 @@ class ProgramBuilder {
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given name, storage and type
template <typename NAME>
ast::Variable* Var(const Source& source,
const std::string& name,
NAME&& name,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, Symbols().Register(name), storage,
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), storage,
type, false, constructor, decorations);
}
/// @param symbol the variable symbol
/// @param type the variable type
/// @param storage the variable storage class
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given symbol, storage and type
ast::Variable* Var(Symbol symbol,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(symbol, storage, type, false, constructor,
decorations);
}
/// @param source the variable source
/// @param symbol the variable symbol
/// @param type the variable type
/// @param storage the variable storage class
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given symbol, storage and type
ast::Variable* Var(const Source& source,
Symbol symbol,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, symbol, storage, type, false,
constructor, decorations);
}
/// @param name the variable name
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given name, storage and type
ast::Variable* Const(const std::string& name,
template <typename NAME>
ast::Variable* Const(NAME&& name,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(Symbols().Register(name),
return create<ast::Variable>(Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, type, true,
constructor, decorations);
}
@ -853,46 +832,17 @@ class ProgramBuilder {
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given name, storage and type
template <typename NAME>
ast::Variable* Const(const Source& source,
const std::string& name,
NAME&& name,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, Symbols().Register(name),
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, type, true,
constructor, decorations);
}
/// @param symbol the variable symbol
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given symbol, storage and
/// type
ast::Variable* Const(Symbol symbol,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(symbol, ast::StorageClass::kNone, type, true,
constructor, decorations);
}
/// @param source the variable source
/// @param symbol the variable symbol
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given symbol, storage and
/// type
ast::Variable* Const(const Source& source,
Symbol symbol,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
return create<ast::Variable>(source, symbol, ast::StorageClass::kNone, type,
true, constructor, decorations);
}
/// @param args the arguments to pass to Var()
/// @returns a `ast::Variable` constructed by calling Var() with the arguments
/// of `args`, which is automatically registered as a global variable with the
@ -966,6 +916,16 @@ class ProgramBuilder {
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the division operation
/// @param rhs the right hand argument to the division operation
/// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs`
template <typename LHS, typename RHS>
ast::Expression* Div(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kDivide,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param arr the array argument for the array accessor expression
/// @param idx the index argument for the array accessor expression
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
@ -1027,19 +987,22 @@ class ProgramBuilder {
/// @param params the function parameters
/// @param type the function return type
/// @param body the function body
/// @param decorations the function decorations
/// @param return_type_decorations the function return type decorations
/// @param decorations the optional function decorations
/// @param return_type_decorations the optional function return type
/// decorations
/// @returns the function pointer
template <typename NAME>
ast::Function* Func(Source source,
std::string name,
NAME&& name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) {
auto* func = create<ast::Function>(source, Symbols().Register(name), params,
type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
auto* func =
create<ast::Function>(source, Sym(std::forward<NAME>(name)), params,
type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func);
return func;
}
@ -1049,17 +1012,19 @@ class ProgramBuilder {
/// @param params the function parameters
/// @param type the function return type
/// @param body the function body
/// @param decorations the function decorations
/// @param return_type_decorations the function return type decorations
/// @param decorations the optional function decorations
/// @param return_type_decorations the optional function return type
/// decorations
/// @returns the function pointer
ast::Function* Func(std::string name,
template <typename NAME>
ast::Function* Func(NAME&& name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) {
auto* func = create<ast::Function>(Symbols().Register(name), params, type,
create<ast::BlockStatement>(body),
auto* func = create<ast::Function>(Sym(std::forward<NAME>(name)), params,
type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func);
return func;
@ -1113,12 +1078,13 @@ class ProgramBuilder {
/// @param type the struct member type
/// @param decorations the optional struct member decorations
/// @returns the struct member pointer
template <typename NAME>
ast::StructMember* Member(const Source& source,
const std::string& name,
NAME&& name,
type::Type* type,
ast::DecorationList decorations = {}) {
return create<ast::StructMember>(source, Symbols().Register(name), type,
std::move(decorations));
return create<ast::StructMember>(source, Sym(std::forward<NAME>(name)),
type, std::move(decorations));
}
/// Creates a ast::StructMember
@ -1126,11 +1092,12 @@ class ProgramBuilder {
/// @param type the struct member type
/// @param decorations the optional struct member decorations
/// @returns the struct member pointer
ast::StructMember* Member(const std::string& name,
template <typename NAME>
ast::StructMember* Member(NAME&& name,
type::Type* type,
ast::DecorationList decorations = {}) {
return create<ast::StructMember>(source_, Symbols().Register(name), type,
std::move(decorations));
return create<ast::StructMember>(source_, Sym(std::forward<NAME>(name)),
type, std::move(decorations));
}
/// Creates a ast::StructMember with the given byte offset
@ -1138,11 +1105,10 @@ class ProgramBuilder {
/// @param name the struct member name
/// @param type the struct member type
/// @returns the struct member pointer
ast::StructMember* Member(uint32_t offset,
const std::string& name,
type::Type* type) {
template <typename NAME>
ast::StructMember* Member(uint32_t offset, NAME&& name, type::Type* type) {
return create<ast::StructMember>(
source_, Symbols().Register(name), type,
source_, Sym(std::forward<NAME>(name)), type,
ast::DecorationList{
create<ast::StructMemberOffsetDecoration>(offset),
});

View File

@ -50,6 +50,11 @@ Symbol SymbolTable::Get(const std::string& name) const {
return it != name_to_symbol_.end() ? it->second : Symbol();
}
bool SymbolTable::HasName(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol);
return it != symbol_to_name_.end();
}
std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol);
if (it == symbol_to_name_.end()) {

View File

@ -53,6 +53,10 @@ class SymbolTable {
/// @returns the symbol for the name or symbol::kInvalid if not found.
Symbol Get(const std::string& name) const;
/// @returns true if the symbol has a name
/// @param symbol the symbol to query
bool HasName(const Symbol symbol) const;
/// Returns the name for the given symbol
/// @param symbol the symbol to retrieve the name for
/// @returns the symbol name or "" if not found

View File

@ -45,7 +45,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
});
new_struct_members.push_back(
ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()),
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
}
@ -70,10 +70,9 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
auto new_struct_param_symbol = ctx.dst->Symbols().New();
ast::StructMemberList new_struct_members;
for (auto* param : func->params()) {
auto param_name = ctx.src->Symbols().NameFor(param->symbol());
auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type();
auto func_const_symbol = ctx.dst->Symbols().Register(param_name);
ast::Expression* func_const_initializer = nullptr;
if (auto* struct_ty =
@ -90,7 +89,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
auto member_name = ctx.src->Symbols().NameFor(member->symbol());
auto member_name = ctx.Clone(member->symbol());
new_struct_members.push_back(ctx.dst->Member(
member_name, ctx.Clone(member->type()), new_decorations));
init_values.push_back(
@ -118,15 +117,15 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
// Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(
func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
func_const_initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(func_const_symbol));
ctx.dst->Expr(param_name));
}
}
@ -163,9 +162,9 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
auto member_name = ctx.src->Symbols().NameFor(member->symbol());
new_struct_members.push_back(ctx.dst->Member(
member_name, ctx.Clone(member->type()), new_decorations));
new_struct_members.push_back(
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
}
} else {
new_struct_members.push_back(

View File

@ -121,7 +121,7 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
});
new_struct_members.push_back(
ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()),
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
}
@ -215,7 +215,7 @@ void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
}
// Use the same name as the old variable.
std::string var_name = ctx.src->Symbols().NameFor(var->symbol());
auto var_name = ctx.Clone(var->symbol());
// Use `array<u32, 1>` for the new variable.
auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
// Create the new variable.

View File

@ -74,7 +74,7 @@ fn frag_main([[location(1)]] loc1 : myf32) -> void {
auto* expect = R"(
type myf32 = f32;
[[location(1)]] var<in> tint_symbol_1 : myf32;
[[location(1)]] var<in> tint_symbol_2 : myf32;
[[stage(fragment)]]
fn frag_main() -> void {
@ -95,15 +95,15 @@ fn vert_main() -> [[builtin(position)]] vec4<f32> {
)";
auto* expect = R"(
[[builtin(position)]] var<out> tint_symbol_2 : vec4<f32>;
[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
fn tint_symbol_3(tint_symbol_1 : vec4<f32>) -> void {
tint_symbol_2 = tint_symbol_1;
fn tint_symbol_2(tint_symbol_3 : vec4<f32>) -> void {
tint_symbol_1 = tint_symbol_3;
}
[[stage(vertex)]]
fn vert_main() -> void {
tint_symbol_3(vec4<f32>(1.0, 2.0, 3.0, 0.0));
tint_symbol_2(vec4<f32>(1.0, 2.0, 3.0, 0.0));
return;
}
)";
@ -127,19 +127,19 @@ fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
auto* expect = R"(
[[location(0)]] var<in> tint_symbol_1 : u32;
[[location(0)]] var<out> tint_symbol_3 : f32;
[[location(0)]] var<out> tint_symbol_2 : f32;
fn tint_symbol_4(tint_symbol_2 : f32) -> void {
tint_symbol_3 = tint_symbol_2;
fn tint_symbol_3(tint_symbol_4 : f32) -> void {
tint_symbol_2 = tint_symbol_4;
}
[[stage(fragment)]]
fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) {
tint_symbol_4(0.5);
tint_symbol_3(0.5);
return;
}
tint_symbol_4(1.0);
tint_symbol_3(1.0);
return;
}
)";
@ -165,21 +165,21 @@ fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
auto* expect = R"(
type myf32 = f32;
[[location(0)]] var<in> tint_symbol_1 : u32;
[[location(0)]] var<in> tint_symbol_2 : u32;
[[location(0)]] var<out> tint_symbol_3 : myf32;
fn tint_symbol_5(tint_symbol_2 : myf32) -> void {
tint_symbol_3 = tint_symbol_2;
fn tint_symbol_4(tint_symbol_5 : myf32) -> void {
tint_symbol_3 = tint_symbol_5;
}
[[stage(fragment)]]
fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) {
tint_symbol_5(0.5);
if ((tint_symbol_2 > 10u)) {
tint_symbol_4(0.5);
return;
}
tint_symbol_5(1.0);
tint_symbol_4(1.0);
return;
}
)";
@ -214,8 +214,8 @@ struct FragmentInput {
[[stage(fragment)]]
fn frag_main() -> void {
const tint_symbol_6 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5);
var col : f32 = (tint_symbol_6.coord.x * tint_symbol_6.value);
const tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5);
var col : f32 = (tint_symbol_7.coord.x * tint_symbol_7.value);
}
)";
@ -275,23 +275,23 @@ struct VertexOutput {
value : f32;
};
[[builtin(position)]] var<out> tint_symbol_5 : vec4<f32>;
[[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>;
[[location(1)]] var<out> tint_symbol_6 : f32;
[[location(1)]] var<out> tint_symbol_5 : f32;
fn tint_symbol_7(tint_symbol_4 : VertexOutput) -> void {
tint_symbol_5 = tint_symbol_4.pos;
tint_symbol_6 = tint_symbol_4.value;
fn tint_symbol_6(tint_symbol_7 : VertexOutput) -> void {
tint_symbol_4 = tint_symbol_7.pos;
tint_symbol_5 = tint_symbol_7.value;
}
[[stage(vertex)]]
fn vert_main() -> void {
if (false) {
tint_symbol_7(VertexOutput());
tint_symbol_6(VertexOutput());
return;
}
var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
tint_symbol_7(VertexOutput(pos, 2.0));
tint_symbol_6(VertexOutput(pos, 2.0));
return;
}
)";
@ -320,16 +320,16 @@ struct Interface {
[[location(1)]] var<in> tint_symbol_3 : f32;
[[location(1)]] var<out> tint_symbol_6 : f32;
[[location(1)]] var<out> tint_symbol_4 : f32;
fn tint_symbol_7(tint_symbol_5 : Interface) -> void {
tint_symbol_6 = tint_symbol_5.value;
fn tint_symbol_5(tint_symbol_6 : Interface) -> void {
tint_symbol_4 = tint_symbol_6.value;
}
[[stage(vertex)]]
fn vert_main() -> void {
const tint_symbol_4 : Interface = Interface(tint_symbol_3);
tint_symbol_7(tint_symbol_4);
const tint_symbol_8 : Interface = Interface(tint_symbol_3);
tint_symbol_5(tint_symbol_8);
return;
}
)";
@ -361,15 +361,15 @@ struct Interface {
value : f32;
};
[[location(1)]] var<out> tint_symbol_4 : f32;
[[location(1)]] var<out> tint_symbol_3 : f32;
fn tint_symbol_5(tint_symbol_3 : Interface) -> void {
tint_symbol_4 = tint_symbol_3.value;
fn tint_symbol_4(tint_symbol_5 : Interface) -> void {
tint_symbol_3 = tint_symbol_5.value;
}
[[stage(vertex)]]
fn vert_main() -> void {
tint_symbol_5(Interface(42.0));
tint_symbol_4(Interface(42.0));
return;
}
@ -377,8 +377,8 @@ fn vert_main() -> void {
[[stage(fragment)]]
fn frag_main() -> void {
const tint_symbol_8 : Interface = Interface(tint_symbol_7);
var x : f32 = tint_symbol_8.value;
const tint_symbol_9 : Interface = Interface(tint_symbol_7);
var x : f32 = tint_symbol_9.value;
}
)";
@ -423,16 +423,16 @@ struct FragmentOutput {
[[builtin(frag_coord)]] var<in> tint_symbol_6 : vec4<f32>;
[[location(1)]] var<out> tint_symbol_9 : f32;
[[location(1)]] var<out> tint_symbol_7 : f32;
fn tint_symbol_10(tint_symbol_8 : FragmentOutput) -> void {
tint_symbol_9 = tint_symbol_8.value;
fn tint_symbol_8(tint_symbol_9 : FragmentOutput) -> void {
tint_symbol_7 = tint_symbol_9.value;
}
[[stage(fragment)]]
fn frag_main() -> void {
const tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6);
tint_symbol_10(FragmentOutput((tint_symbol_7.coord.x * tint_symbol_7.value)));
const tint_symbol_11 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6);
tint_symbol_8(FragmentOutput((tint_symbol_11.coord.x * tint_symbol_11.value)));
return;
}
)";
@ -467,8 +467,8 @@ struct VertexOutput {
[[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>;
fn tint_symbol_5(tint_symbol_3 : VertexOutput) -> void {
tint_symbol_4 = tint_symbol_3.Position;
fn tint_symbol_5(tint_symbol_6 : VertexOutput) -> void {
tint_symbol_4 = tint_symbol_6.Position;
}
[[stage(vertex)]]
@ -585,19 +585,19 @@ fn main([[builtin(sample_index)]] sample_index : u32,
)";
auto* expect = R"(
[[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
[[builtin(sample_index)]] var<in> tint_symbol_3 : u32;
[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
[[builtin(sample_mask_in)]] var<in> tint_symbol_1 : array<u32, 1>;
[[builtin(sample_mask_out)]] var<out> tint_symbol_4 : array<u32, 1>;
[[builtin(sample_mask_out)]] var<out> tint_symbol_2 : array<u32, 1>;
fn tint_symbol_5(tint_symbol_3 : u32) -> void {
tint_symbol_4[0] = tint_symbol_3;
fn tint_symbol_4(tint_symbol_5 : u32) -> void {
tint_symbol_2[0] = tint_symbol_5;
}
[[stage(fragment)]]
fn main() -> void {
tint_symbol_5(tint_symbol_2[0]);
tint_symbol_4(tint_symbol_1[0]);
return;
}
)";

44
src/utils/get_or_create.h Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_UTILS_GET_OR_CREATE_H_
#define SRC_UTILS_GET_OR_CREATE_H_
#include <unordered_map>
namespace tint {
namespace utils {
/// GetOrCreate is a utility function for lazily adding to an unordered map.
/// If the map already contains the key `key` then this is returned, otherwise
/// `create()` is called and the result is added to the map and is returned.
/// @param map the unordered_map
/// @param key the map key of the item to query or add
/// @param create a callable function-like object with the signature `V()`
/// @return the value of the item with the given key, or the newly created item
template <typename K, typename V, typename CREATE, typename H>
V GetOrCreate(std::unordered_map<K, V, H>& map, K key, CREATE&& create) {
auto it = map.find(key);
if (it != map.end()) {
return it->second;
}
V value = create();
map.emplace(key, value);
return value;
}
} // namespace utils
} // namespace tint
#endif // SRC_UTILS_GET_OR_CREATE_H_

View File

@ -0,0 +1,49 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/utils/get_or_create.h"
#include <unordered_map>
#include "gtest/gtest.h"
namespace tint {
namespace utils {
namespace {
TEST(GetOrCreateTest, NewKey) {
std::unordered_map<int, int> map;
EXPECT_EQ(GetOrCreate(map, 1, [&] { return 2; }), 2);
EXPECT_EQ(map.size(), 1u);
EXPECT_EQ(map[1], 2);
}
TEST(GetOrCreateTest, ExistingKey) {
std::unordered_map<int, int> map;
map[1] = 2;
bool called = false;
EXPECT_EQ(GetOrCreate(map, 1,
[&] {
called = true;
return -2;
}),
2);
EXPECT_EQ(called, false);
EXPECT_EQ(map.size(), 1u);
EXPECT_EQ(map[1], 2);
}
} // namespace
} // namespace utils
} // namespace tint

View File

@ -124,16 +124,16 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct tint_symbol_3 {
EXPECT_EQ(result(), R"(struct tint_symbol_1 {
float foo : TEXCOORD0;
};
struct tint_symbol_5 {
struct tint_symbol_3 {
float value : SV_Target1;
};
tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) {
const float foo = tint_symbol_1.foo;
return tint_symbol_5(foo);
tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
const float foo = tint_symbol_6.foo;
return tint_symbol_3(foo);
}
)");
@ -157,16 +157,16 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct tint_symbol_3 {
EXPECT_EQ(result(), R"(struct tint_symbol_1 {
float4 coord : SV_Position;
};
struct tint_symbol_5 {
struct tint_symbol_3 {
float value : SV_Depth;
};
tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) {
const float4 coord = tint_symbol_1.coord;
return tint_symbol_5(coord.x);
tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
const float4 coord = tint_symbol_6.coord;
return tint_symbol_3(coord.x);
}
)");
@ -217,18 +217,18 @@ struct tint_symbol_4 {
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
};
struct tint_symbol_9 {
struct tint_symbol_7 {
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
};
tint_symbol_4 vert_main() {
const Interface tint_symbol_5 = Interface(0.5f, 0.25f);
return tint_symbol_4(tint_symbol_5.col1, tint_symbol_5.col2);
const Interface tint_symbol_6 = Interface(0.5f, 0.25f);
return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2);
}
void frag_main(tint_symbol_9 tint_symbol_7) {
const Interface colors = Interface(tint_symbol_7.col1, tint_symbol_7.col2);
void frag_main(tint_symbol_7 tint_symbol_9) {
const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2);
const float r = colors.col1;
const float g = colors.col2;
return;
@ -281,10 +281,10 @@ TEST_F(HlslGeneratorImplTest_Function,
EXPECT_EQ(result(), R"(struct VertexOutput {
float4 pos;
};
struct tint_symbol_3 {
struct tint_symbol_5 {
float4 pos : SV_Position;
};
struct tint_symbol_7 {
struct tint_symbol_8 {
float4 pos : SV_Position;
};
@ -292,14 +292,14 @@ VertexOutput foo(float x) {
return VertexOutput(float4(x, x, x, 1.0f));
}
tint_symbol_3 vert_main1() {
const VertexOutput tint_symbol_5 = VertexOutput(foo(0.5f));
return tint_symbol_3(tint_symbol_5.pos);
tint_symbol_5 vert_main1() {
const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f));
return tint_symbol_5(tint_symbol_7.pos);
}
tint_symbol_7 vert_main2() {
const VertexOutput tint_symbol_8 = VertexOutput(foo(0.25f));
return tint_symbol_7(tint_symbol_8.pos);
tint_symbol_8 vert_main2() {
const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f));
return tint_symbol_8(tint_symbol_10.pos);
}
)");

View File

@ -134,9 +134,9 @@ OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %14 "frag_main" %1 %4
OpExecutionMode %14 OriginUpperLeft
OpName %1 "tint_symbol_1"
OpName %4 "tint_symbol_3"
OpName %10 "tint_symbol_4"
OpName %11 "tint_symbol_2"
OpName %4 "tint_symbol_2"
OpName %10 "tint_symbol_3"
OpName %11 "tint_symbol_4"
OpName %14 "frag_main"
OpDecorate %1 Location 0
OpDecorate %4 Location 0
@ -220,16 +220,16 @@ OpEntryPoint Vertex %16 "vert_main" %1
OpEntryPoint Fragment %25 "frag_main" %5 %7
OpExecutionMode %25 OriginUpperLeft
OpExecutionMode %25 DepthReplacing
OpName %1 "tint_symbol_4"
OpName %1 "tint_symbol_3"
OpName %5 "tint_symbol_7"
OpName %7 "tint_symbol_10"
OpName %7 "tint_symbol_8"
OpName %10 "Interface"
OpMemberName %10 0 "value"
OpName %11 "tint_symbol_5"
OpName %12 "tint_symbol_3"
OpName %11 "tint_symbol_4"
OpName %12 "tint_symbol_5"
OpName %16 "vert_main"
OpName %22 "tint_symbol_11"
OpName %23 "tint_symbol_9"
OpName %22 "tint_symbol_9"
OpName %23 "tint_symbol_10"
OpName %25 "frag_main"
OpDecorate %1 Location 1
OpDecorate %5 Location 1

View File

@ -220,6 +220,7 @@ source_set("tint_unittests_core_src") {
"../src/type/vector_type_test.cc",
"../src/utils/command.h",
"../src/utils/command_test.cc",
"../src/utils/get_or_create_test.cc",
"../src/utils/hash_test.cc",
"../src/utils/math_test.cc",
"../src/utils/tmpfile.h",