tint: Have ast::CallExpression use ast::Identifier

Instead of ast::IdentifierExpression.
The name is not an expression, as it resolves to a function, builtin or
type.

Bug: tint:1257
Change-Id: I13143f2bbc208e9e2934dad20fe5c9aa59520b68
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118341
Kokoro: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2023-02-02 15:16:28 +00:00 committed by Dawn LUCI CQ
parent 6e31bc24b1
commit 999db74a24
18 changed files with 105 additions and 103 deletions

View File

@ -23,7 +23,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::CallExpression);
namespace tint::ast {
namespace {
CallExpression::Target ToTarget(const IdentifierExpression* name) {
CallExpression::Target ToTarget(const Identifier* name) {
CallExpression::Target target;
target.name = name;
return target;
@ -38,7 +38,7 @@ CallExpression::Target ToTarget(const Type* type) {
CallExpression::CallExpression(ProgramID pid,
NodeID nid,
const Source& src,
const IdentifierExpression* name,
const Identifier* name,
utils::VectorRef<const Expression*> a)
: Base(pid, nid, src), target(ToTarget(name)), args(std::move(a)) {
TINT_ASSERT(AST, name);

View File

@ -20,7 +20,7 @@
// Forward declarations
namespace tint::ast {
class Type;
class IdentifierExpression;
class Identifier;
} // namespace tint::ast
namespace tint::ast {
@ -41,7 +41,7 @@ class CallExpression final : public Castable<CallExpression, Expression> {
CallExpression(ProgramID pid,
NodeID nid,
const Source& source,
const IdentifierExpression* name,
const Identifier* name,
utils::VectorRef<const Expression*> args);
/// Constructor
@ -71,7 +71,7 @@ class CallExpression final : public Castable<CallExpression, Expression> {
struct Target {
/// name is a function or builtin to call, or type name to construct or
/// cast-to
const IdentifierExpression* name = nullptr;
const Identifier* name = nullptr;
/// type to construct or cast-to
const Type* type = nullptr;
};

View File

@ -21,13 +21,13 @@ namespace {
using CallExpressionTest = TestHelper;
TEST_F(CallExpressionTest, CreationIdentifier) {
auto* func = Expr("func");
auto* func = Ident("func");
utils::Vector params{
Expr("param1"),
Expr("param2"),
};
auto* stmt = create<CallExpression>(func, params);
auto* stmt = Call(func, params);
EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr);
@ -38,8 +38,8 @@ TEST_F(CallExpressionTest, CreationIdentifier) {
}
TEST_F(CallExpressionTest, CreationIdentifier_WithSource) {
auto* func = Expr("func");
auto* stmt = create<CallExpression>(Source{{20, 2}}, func, utils::Empty);
auto* func = Ident("func");
auto* stmt = Call(Source{{20, 2}}, func);
EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr);
@ -55,7 +55,7 @@ TEST_F(CallExpressionTest, CreationType) {
Expr("param2"),
};
auto* stmt = create<CallExpression>(type, params);
auto* stmt = Construct(type, params);
EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type);
@ -67,7 +67,7 @@ TEST_F(CallExpressionTest, CreationType) {
TEST_F(CallExpressionTest, CreationType_WithSource) {
auto* type = ty.f32();
auto* stmt = create<CallExpression>(Source{{20, 2}}, type, utils::Empty);
auto* stmt = Construct(Source{{20, 2}}, type);
EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type);
@ -77,8 +77,8 @@ TEST_F(CallExpressionTest, CreationType_WithSource) {
}
TEST_F(CallExpressionTest, IsCall) {
auto* func = Expr("func");
auto* stmt = create<CallExpression>(func, utils::Empty);
auto* func = Ident("func");
auto* stmt = Call(func);
EXPECT_TRUE(stmt->Is<CallExpression>());
}
@ -86,7 +86,7 @@ TEST_F(CallExpressionTest, Assert_Null_Identifier) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<CallExpression>(static_cast<IdentifierExpression*>(nullptr), utils::Empty);
b.Call(static_cast<Identifier*>(nullptr));
},
"internal compiler error");
}
@ -95,7 +95,7 @@ TEST_F(CallExpressionTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<CallExpression>(static_cast<Type*>(nullptr), utils::Empty);
b.Construct(static_cast<Type*>(nullptr));
},
"internal compiler error");
}
@ -104,11 +104,11 @@ TEST_F(CallExpressionTest, Assert_Null_Param) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<CallExpression>(b.Expr("func"), utils::Vector{
b.Expr("param1"),
nullptr,
b.Expr("param2"),
});
b.Call(b.Ident("func"), utils::Vector{
b.Expr("param1"),
nullptr,
b.Expr("param2"),
});
},
"internal compiler error");
}
@ -118,7 +118,7 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Identifier) {
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.create<CallExpression>(b2.Expr("func"), utils::Empty);
b1.Call(b2.Ident("func"));
},
"internal compiler error");
}
@ -128,7 +128,7 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Type) {
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.create<CallExpression>(b2.ty.f32(), utils::Empty);
b1.Construct(b2.ty.f32());
},
"internal compiler error");
}
@ -138,7 +138,7 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) {
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.create<CallExpression>(b1.Expr("func"), utils::Vector{b2.Expr("param1")});
b1.Call(b1.Ident("func"), b2.Expr("param1"));
},
"internal compiler error");
}

View File

@ -23,14 +23,14 @@ namespace {
using CallStatementTest = TestHelper;
TEST_F(CallStatementTest, Creation) {
auto* expr = create<CallExpression>(Expr("func"), utils::Empty);
auto* expr = Call("func");
auto* c = create<CallStatement>(expr);
auto* c = CallStmt(expr);
EXPECT_EQ(c->expr, expr);
}
TEST_F(CallStatementTest, IsCall) {
auto* c = create<CallStatement>(Call("f"));
auto* c = CallStmt(Call("f"));
EXPECT_TRUE(c->Is<CallStatement>());
}
@ -38,7 +38,7 @@ TEST_F(CallStatementTest, Assert_Null_Call) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.create<CallStatement>(nullptr);
b.CallStmt(nullptr);
},
"internal compiler error");
}
@ -48,7 +48,7 @@ TEST_F(CallStatementTest, Assert_DifferentProgramID_Call) {
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.create<CallStatement>(b2.create<CallExpression>(b2.Expr("func"), utils::Empty));
b1.CallStmt(b2.Call("func"));
},
"internal compiler error");
}

View File

@ -1160,7 +1160,13 @@ class ProgramBuilder {
/// @return an ast::Identifier with the given symbol
template <typename IDENTIFIER>
const ast::Identifier* Ident(IDENTIFIER&& identifier) {
return create<ast::Identifier>(Sym(std::forward<IDENTIFIER>(identifier)));
if constexpr (traits::IsTypeOrDerived<
std::decay_t<std::remove_pointer_t<std::decay_t<IDENTIFIER>>>,
ast::Identifier>) {
return identifier; // Pass-through
} else {
return create<ast::Identifier>(Sym(std::forward<IDENTIFIER>(identifier)));
}
}
/// @param expr the expression
@ -2054,7 +2060,7 @@ class ProgramBuilder {
/// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, typename... ARGS>
const ast::CallExpression* Call(const Source& source, NAME&& func, ARGS&&... args) {
return create<ast::CallExpression>(source, Expr(func),
return create<ast::CallExpression>(source, Ident(func),
ExprList(std::forward<ARGS>(args)...));
}
@ -2064,7 +2070,7 @@ class ProgramBuilder {
/// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>>
const ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
return create<ast::CallExpression>(Expr(func), ExprList(std::forward<ARGS>(args)...));
return create<ast::CallExpression>(Ident(func), ExprList(std::forward<ARGS>(args)...));
}
/// @param source the source information

View File

@ -1316,10 +1316,10 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
// Call the inner function. It has no parameters.
stmts.Push(create<ast::CallStatement>(
source,
create<ast::CallExpression>(source,
create<ast::IdentifierExpression>(
source, builder_.Symbols().Register(ep_info_->inner_name)),
utils::Empty)));
create<ast::CallExpression>(
source,
create<ast::Identifier>(source, builder_.Symbols().Register(ep_info_->inner_name)),
utils::Empty)));
// Pipeline outputs are mapped to the return value.
if (ep_info_->outputs.IsEmpty()) {
@ -3854,7 +3854,7 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
params.Push(MakeOperand(inst, 0).expr);
return {ast_type, create<ast::CallExpression>(
Source{},
create<ast::IdentifierExpression>(
create<ast::Identifier>(
Source{}, builder_.Symbols().Register(unary_builtin_name)),
std::move(params))};
}
@ -4106,7 +4106,7 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(const spvtools::opt::Inst
return {};
}
auto* func = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
auto* func = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList operands;
const Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs.
@ -5212,7 +5212,7 @@ TypedExpression FunctionEmitter::MakeNumericConversion(const spvtools::opt::Inst
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
// We ignore function attributes such as Inline, DontInline, Pure, Const.
auto name = namer_.Name(inst.GetSingleWordInOperand(0));
auto* function = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
auto* function = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList args;
for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
@ -5302,7 +5302,7 @@ bool FunctionEmitter::EmitControlBarrier(const spvtools::opt::Instruction& inst)
TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instruction& inst) {
const auto builtin = GetBuiltin(opcode(inst));
auto* name = sem::str(builtin);
auto* ident = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList params;
const Type* first_operand_type = nullptr;
@ -5341,11 +5341,10 @@ TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instructi
params.Push(true_value.expr);
// The condition goes last.
params.Push(condition.expr);
return {op_ty,
create<ast::CallExpression>(Source{},
create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register("select")),
std::move(params))};
return {op_ty, create<ast::CallExpression>(
Source{},
create<ast::Identifier>(Source{}, builder_.Symbols().Register("select")),
std::move(params))};
}
return {};
}
@ -5650,8 +5649,7 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
return false;
}
auto* ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(builtin_name));
auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(builtin_name));
auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(args));
if (inst.type_id() != 0) {
@ -5741,8 +5739,8 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
// Invoke textureDimensions.
// If the texture is arrayed, combine with the result from
// textureNumLayers.
auto* dims_ident = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register("textureDimensions"));
auto* dims_ident =
create<ast::Identifier>(Source{}, builder_.Symbols().Register("textureDimensions"));
ExpressionList dims_args{GetImageExpression(inst)};
if (op == spv::Op::OpImageQuerySizeLod) {
dims_args.Push(MakeOperand(inst, 1).expr);
@ -5758,7 +5756,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
}
exprs.Push(dims_call);
if (ast::IsTextureArray(dims)) {
auto* layers_ident = create<ast::IdentifierExpression>(
auto* layers_ident = create<ast::Identifier>(
Source{}, builder_.Symbols().Register("textureNumLayers"));
auto num_layers = create<ast::CallExpression>(
Source{}, layers_ident, utils::Vector{GetImageExpression(inst)});
@ -5789,7 +5787,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
const auto* name =
(op == spv::Op::OpImageQueryLevels) ? "textureNumLevels" : "textureNumSamples";
auto* levels_ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
const ast::Expression* ast_expr = create<ast::CallExpression>(
Source{}, levels_ident, utils::Vector{GetImageExpression(inst)});
auto* result_type = parser_impl_.ConvertType(inst.type_id());

View File

@ -2436,7 +2436,7 @@ Maybe<const ast::CallStatement*> ParserImpl::func_call_statement() {
t.source(),
create<ast::CallExpression>(
t.source(),
create<ast::IdentifierExpression>(t.source(), builder_.Symbols().Register(t.to_str())),
create<ast::Identifier>(t.source(), builder_.Symbols().Register(t.to_str())),
std::move(params.value)));
}
@ -2642,19 +2642,19 @@ Maybe<const ast::Expression*> ParserImpl::primary_expression() {
"in parentheses");
}
auto* ident =
create<ast::IdentifierExpression>(t.source(), builder_.Symbols().Register(t.to_str()));
if (peek_is(Token::Type::kParenLeft)) {
auto params = expect_argument_expression_list("function call");
if (params.errored) {
return Failure::kErrored;
}
auto* ident =
create<ast::Identifier>(t.source(), builder_.Symbols().Register(t.to_str()));
return create<ast::CallExpression>(t.source(), ident, std::move(params.value));
}
return ident;
return create<ast::IdentifierExpression>(t.source(),
builder_.Symbols().Register(t.to_str()));
}
if (t.Is(Token::Type::kParenLeft)) {

View File

@ -39,10 +39,7 @@ TEST_F(ResolverBuiltinValidationTest, FunctionTypeMustMatchReturnStatementType_v
TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageDirect) {
// @compute @workgroup_size(1) fn func { return dpdx(1.0); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
utils::Vector{
Expr(1_f),
});
auto* dpdx = Call(Source{{3, 4}}, "dpdx", 1_f);
Func(Source{{1, 2}}, "func", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(dpdx),
@ -62,10 +59,7 @@ TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageIndirect) {
// fn f2 { f1(); }
// @compute @workgroup_size(1) fn main { return f2(); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
utils::Vector{
Expr(1_f),
});
auto* dpdx = Call(Source{{3, 4}}, "dpdx", 1_f);
Func(Source{{1, 2}}, "f0", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(dpdx),
@ -138,7 +132,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsFunctionUsedAsType) {
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsFunction) {
GlobalConst(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i));
WrapInFunction(Call(Expr(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
WrapInFunction(Call(Ident(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix'
@ -167,7 +161,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsType)
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalVarUsedAsFunction) {
GlobalVar(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i), type::AddressSpace::kPrivate);
WrapInFunction(Call(Expr(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
WrapInFunction(Call(Ident(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix'

View File

@ -39,6 +39,7 @@
#include "src/tint/ast/for_loop_statement.h"
#include "src/tint/ast/i32.h"
#include "src/tint/ast/id_attribute.h"
#include "src/tint/ast/identifier.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/ast/internal_attribute.h"

View File

@ -548,7 +548,7 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind, Symbol symbol, Source
return node;
}
case SymbolUseKind::CallFunction: {
auto* node = b.Expr(source, symbol);
auto* node = b.Ident(source, symbol);
statements.Push(b.CallStmt(b.Call(node)));
return node;
}
@ -651,7 +651,8 @@ TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, FuncCall) {
// fn A() { B(); }
// fn B() {}
Func("A", utils::Empty, ty.void_(), utils::Vector{CallStmt(Call(Expr(Source{{12, 34}}, "B")))});
Func("A", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Ident(Source{{12, 34}}, "B")))});
Func(Source{{56, 78}}, "B", utils::Empty, ty.void_(), utils::Vector{Return()});
Build();
@ -812,7 +813,7 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, DirectCall) {
// fn main() { main(); }
Func(Source{{12, 34}}, "main", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{56, 78}}, "main")))});
utils::Vector{CallStmt(Call(Ident(Source{{56, 78}}, "main")))});
Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main'
56:78 note: function 'main' calls function 'main' here)");
@ -826,17 +827,17 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) {
// 5: fn b() { c(); }
Func(Source{{1, 1}}, "a", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{1, 10}}, "b")))});
utils::Vector{CallStmt(Call(Ident(Source{{1, 10}}, "b")))});
Func(Source{{2, 1}}, "e", utils::Empty, ty.void_(), utils::Empty);
Func(Source{{3, 1}}, "d", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(Call(Expr(Source{{3, 10}}, "e"))),
CallStmt(Call(Expr(Source{{3, 10}}, "b"))),
CallStmt(Call(Ident(Source{{3, 10}}, "e"))),
CallStmt(Call(Ident(Source{{3, 10}}, "b"))),
});
Func(Source{{4, 1}}, "c", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{4, 10}}, "d")))});
utils::Vector{CallStmt(Call(Ident(Source{{4, 10}}, "d")))});
Func(Source{{5, 1}}, "b", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{5, 10}}, "c")))});
utils::Vector{CallStmt(Call(Ident(Source{{5, 10}}, "c")))});
Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b'
5:10 note: function 'b' calls function 'c' here
@ -1232,7 +1233,7 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
};
#define V add_use(value_decl, Expr(value_sym), __LINE__, "V()")
#define T add_use(type_decl, ty.type_name(type_sym), __LINE__, "T()")
#define F add_use(func_decl, Expr(func_sym), __LINE__, "F()")
#define F add_use(func_decl, Ident(func_sym), __LINE__, "F()")
Alias(Sym(), T);
Structure(Sym(), //

View File

@ -5330,7 +5330,7 @@ TEST_F(UniformityAnalysisTest, MaximumNumberOfPointerParameters) {
args.Push(b.AddressOf(name));
}
main_body.Push(b.Assign("v0", "non_uniform_global"));
main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Expr("foo"), args)));
main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Ident("foo"), args)));
main_body.Push(b.If(b.Equal("v254", 0_i), b.Block(b.CallStmt(b.Call("workgroupBarrier")))));
b.Func("main", utils::Empty, ty.void_(), main_body);

View File

@ -434,14 +434,13 @@ struct MultiplanarExternalTexture::State {
buildTextureBuiltinBody(sem::BuiltinType::kTextureSampleBaseClampToEdge));
}
const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym);
return b.Call(exp, utils::Vector{
plane_0_binding_param,
b.Expr(syms.plane_1),
ctx.Clone(expr->args[1]),
ctx.Clone(expr->args[2]),
b.Expr(syms.params),
});
return b.Call(texture_sample_external_sym, utils::Vector{
plane_0_binding_param,
b.Expr(syms.plane_1),
ctx.Clone(expr->args[1]),
ctx.Clone(expr->args[2]),
b.Expr(syms.params),
});
}
/// Creates the textureLoadExternal function if needed and returns a call expression to it.

View File

@ -512,9 +512,6 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return clone_maybe_hoisted(bitcast);
},
[&](const ast::CallExpression* call) {
if (call->target.name) {
ctx.Replace(call->target.name, decompose(call->target.name));
}
for (auto* a : call->args) {
ctx.Replace(a, decompose(a));
}

View File

@ -1262,7 +1262,9 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
// Identifiers that need to keep their symbols preserved.
utils::Hashset<const ast::IdentifierExpression*, 8> preserved_identifiers;
utils::Hashset<const ast::Identifier*, 8> preserved_identifiers;
// Identifiers expressions that need to keep their symbols preserved.
utils::Hashset<const ast::IdentifierExpression*, 8> preserved_identifiers_expressions;
// Type names that need to keep their symbols preserved.
utils::Hashset<const ast::TypeName*, 8> preserved_type_names;
@ -1287,11 +1289,11 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
[&](const ast::MemberAccessorExpression* accessor) {
auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
if (sem->Is<sem::Swizzle>()) {
preserved_identifiers.Add(accessor->member);
preserved_identifiers_expressions.Add(accessor->member);
} else if (auto* str_expr = src->Sem().Get(accessor->structure)) {
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
if (ty->Declaration() == nullptr) { // Builtin structure
preserved_identifiers.Add(accessor->member);
preserved_identifiers_expressions.Add(accessor->member);
}
}
}
@ -1314,7 +1316,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
}
},
[&](const ast::DiagnosticControl* diagnostic) {
preserved_identifiers.Add(diagnostic->rule_name);
preserved_identifiers_expressions.Add(diagnostic->rule_name);
},
[&](const ast::TypeName* type_name) {
if (is_type_short_name(type_name->name)) {
@ -1376,8 +1378,18 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
return sym_out;
});
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
ctx.ReplaceAll([&](const ast::Identifier* ident) -> const ast::Identifier* {
if (preserved_identifiers.Contains(ident)) {
auto sym_in = ident->symbol;
auto str = src->Symbols().NameFor(sym_in);
auto sym_out = b.Symbols().Register(str);
return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), sym_out);
}
return nullptr; // Clone ident. Uses the symbol remapping above.
});
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
if (preserved_identifiers_expressions.Contains(ident)) {
auto sym_in = ident->symbol;
auto str = src->Symbols().NameFor(sym_in);
auto sym_out = b.Symbols().Register(str);

View File

@ -34,8 +34,7 @@ using GlslImportData_SingleParamTest = TestParamHelper<GlslImportData>;
TEST_P(GlslImportData_SingleParamTest, FloatScalar) {
auto param = GetParam();
auto* ident = Expr(param.name);
auto* expr = Call(ident, 1_f);
auto* expr = Call(param.name, 1_f);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
@ -91,8 +90,7 @@ using GlslImportData_SingleVectorParamTest = TestParamHelper<GlslImportData>;
TEST_P(GlslImportData_SingleVectorParamTest, FloatVector) {
auto param = GetParam();
auto* ident = Expr(param.name);
auto* expr = Call(ident, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
auto* expr = Call(param.name, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
WrapInFunction(expr);
GeneratorImpl& gen = Build();

View File

@ -34,8 +34,7 @@ using HlslImportData_SingleParamTest = TestParamHelper<HlslImportData>;
TEST_P(HlslImportData_SingleParamTest, FloatScalar) {
auto param = GetParam();
auto* ident = Expr(param.name);
auto* expr = Call(ident, 1_f);
auto* expr = Call(param.name, 1_f);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
@ -90,8 +89,7 @@ using HlslImportData_SingleVectorParamTest = TestParamHelper<HlslImportData>;
TEST_P(HlslImportData_SingleVectorParamTest, FloatVector) {
auto param = GetParam();
auto* ident = Expr(param.name);
auto* expr = Call(ident, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
auto* expr = Call(param.name, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
WrapInFunction(expr);
GeneratorImpl& gen = Build();

View File

@ -276,7 +276,7 @@ TEST_P(MslGeneratorBuiltinTextureTest, Call) {
param.BuildTextureVariable(this);
param.BuildSamplerVariable(this);
auto* call = Call(Expr(param.function), param.args(this));
auto* call = Call(Ident(param.function), param.args(this));
auto* stmt = CallStmt(call);
Func("main", utils::Empty, ty.void_(), utils::Vector{stmt},

View File

@ -239,9 +239,7 @@ bool GeneratorImpl::EmitBitcast(std::ostream& out, const ast::BitcastExpression*
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
if (expr->target.name) {
if (!EmitExpression(out, expr->target.name)) {
return false;
}
out << program_->Symbols().NameFor(expr->target.name->symbol);
} else if (TINT_LIKELY(expr->target.type)) {
if (!EmitType(out, expr->target.type)) {
return false;