tint: Have ast::CallExpression use ast::Identifier

Instead of ast::IdentifierExpression.
The name is not an expression, as it resolves to a function, builtin or
type.

Bug: tint:1257
Change-Id: I13143f2bbc208e9e2934dad20fe5c9aa59520b68
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118341
Kokoro: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2023-02-02 15:16:28 +00:00 committed by Dawn LUCI CQ
parent 6e31bc24b1
commit 999db74a24
18 changed files with 105 additions and 103 deletions

View File

@ -23,7 +23,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::CallExpression);
namespace tint::ast { namespace tint::ast {
namespace { namespace {
CallExpression::Target ToTarget(const IdentifierExpression* name) { CallExpression::Target ToTarget(const Identifier* name) {
CallExpression::Target target; CallExpression::Target target;
target.name = name; target.name = name;
return target; return target;
@ -38,7 +38,7 @@ CallExpression::Target ToTarget(const Type* type) {
CallExpression::CallExpression(ProgramID pid, CallExpression::CallExpression(ProgramID pid,
NodeID nid, NodeID nid,
const Source& src, const Source& src,
const IdentifierExpression* name, const Identifier* name,
utils::VectorRef<const Expression*> a) utils::VectorRef<const Expression*> a)
: Base(pid, nid, src), target(ToTarget(name)), args(std::move(a)) { : Base(pid, nid, src), target(ToTarget(name)), args(std::move(a)) {
TINT_ASSERT(AST, name); TINT_ASSERT(AST, name);

View File

@ -20,7 +20,7 @@
// Forward declarations // Forward declarations
namespace tint::ast { namespace tint::ast {
class Type; class Type;
class IdentifierExpression; class Identifier;
} // namespace tint::ast } // namespace tint::ast
namespace tint::ast { namespace tint::ast {
@ -41,7 +41,7 @@ class CallExpression final : public Castable<CallExpression, Expression> {
CallExpression(ProgramID pid, CallExpression(ProgramID pid,
NodeID nid, NodeID nid,
const Source& source, const Source& source,
const IdentifierExpression* name, const Identifier* name,
utils::VectorRef<const Expression*> args); utils::VectorRef<const Expression*> args);
/// Constructor /// Constructor
@ -71,7 +71,7 @@ class CallExpression final : public Castable<CallExpression, Expression> {
struct Target { struct Target {
/// name is a function or builtin to call, or type name to construct or /// name is a function or builtin to call, or type name to construct or
/// cast-to /// cast-to
const IdentifierExpression* name = nullptr; const Identifier* name = nullptr;
/// type to construct or cast-to /// type to construct or cast-to
const Type* type = nullptr; const Type* type = nullptr;
}; };

View File

@ -21,13 +21,13 @@ namespace {
using CallExpressionTest = TestHelper; using CallExpressionTest = TestHelper;
TEST_F(CallExpressionTest, CreationIdentifier) { TEST_F(CallExpressionTest, CreationIdentifier) {
auto* func = Expr("func"); auto* func = Ident("func");
utils::Vector params{ utils::Vector params{
Expr("param1"), Expr("param1"),
Expr("param2"), Expr("param2"),
}; };
auto* stmt = create<CallExpression>(func, params); auto* stmt = Call(func, params);
EXPECT_EQ(stmt->target.name, func); EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr); EXPECT_EQ(stmt->target.type, nullptr);
@ -38,8 +38,8 @@ TEST_F(CallExpressionTest, CreationIdentifier) {
} }
TEST_F(CallExpressionTest, CreationIdentifier_WithSource) { TEST_F(CallExpressionTest, CreationIdentifier_WithSource) {
auto* func = Expr("func"); auto* func = Ident("func");
auto* stmt = create<CallExpression>(Source{{20, 2}}, func, utils::Empty); auto* stmt = Call(Source{{20, 2}}, func);
EXPECT_EQ(stmt->target.name, func); EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr); EXPECT_EQ(stmt->target.type, nullptr);
@ -55,7 +55,7 @@ TEST_F(CallExpressionTest, CreationType) {
Expr("param2"), Expr("param2"),
}; };
auto* stmt = create<CallExpression>(type, params); auto* stmt = Construct(type, params);
EXPECT_EQ(stmt->target.name, nullptr); EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type); EXPECT_EQ(stmt->target.type, type);
@ -67,7 +67,7 @@ TEST_F(CallExpressionTest, CreationType) {
TEST_F(CallExpressionTest, CreationType_WithSource) { TEST_F(CallExpressionTest, CreationType_WithSource) {
auto* type = ty.f32(); auto* type = ty.f32();
auto* stmt = create<CallExpression>(Source{{20, 2}}, type, utils::Empty); auto* stmt = Construct(Source{{20, 2}}, type);
EXPECT_EQ(stmt->target.name, nullptr); EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type); EXPECT_EQ(stmt->target.type, type);
@ -77,8 +77,8 @@ TEST_F(CallExpressionTest, CreationType_WithSource) {
} }
TEST_F(CallExpressionTest, IsCall) { TEST_F(CallExpressionTest, IsCall) {
auto* func = Expr("func"); auto* func = Ident("func");
auto* stmt = create<CallExpression>(func, utils::Empty); auto* stmt = Call(func);
EXPECT_TRUE(stmt->Is<CallExpression>()); EXPECT_TRUE(stmt->Is<CallExpression>());
} }
@ -86,7 +86,7 @@ TEST_F(CallExpressionTest, Assert_Null_Identifier) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {
ProgramBuilder b; ProgramBuilder b;
b.create<CallExpression>(static_cast<IdentifierExpression*>(nullptr), utils::Empty); b.Call(static_cast<Identifier*>(nullptr));
}, },
"internal compiler error"); "internal compiler error");
} }
@ -95,7 +95,7 @@ TEST_F(CallExpressionTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {
ProgramBuilder b; ProgramBuilder b;
b.create<CallExpression>(static_cast<Type*>(nullptr), utils::Empty); b.Construct(static_cast<Type*>(nullptr));
}, },
"internal compiler error"); "internal compiler error");
} }
@ -104,11 +104,11 @@ TEST_F(CallExpressionTest, Assert_Null_Param) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {
ProgramBuilder b; ProgramBuilder b;
b.create<CallExpression>(b.Expr("func"), utils::Vector{ b.Call(b.Ident("func"), utils::Vector{
b.Expr("param1"), b.Expr("param1"),
nullptr, nullptr,
b.Expr("param2"), b.Expr("param2"),
}); });
}, },
"internal compiler error"); "internal compiler error");
} }
@ -118,7 +118,7 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Identifier) {
{ {
ProgramBuilder b1; ProgramBuilder b1;
ProgramBuilder b2; ProgramBuilder b2;
b1.create<CallExpression>(b2.Expr("func"), utils::Empty); b1.Call(b2.Ident("func"));
}, },
"internal compiler error"); "internal compiler error");
} }
@ -128,7 +128,7 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Type) {
{ {
ProgramBuilder b1; ProgramBuilder b1;
ProgramBuilder b2; ProgramBuilder b2;
b1.create<CallExpression>(b2.ty.f32(), utils::Empty); b1.Construct(b2.ty.f32());
}, },
"internal compiler error"); "internal compiler error");
} }
@ -138,7 +138,7 @@ TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) {
{ {
ProgramBuilder b1; ProgramBuilder b1;
ProgramBuilder b2; ProgramBuilder b2;
b1.create<CallExpression>(b1.Expr("func"), utils::Vector{b2.Expr("param1")}); b1.Call(b1.Ident("func"), b2.Expr("param1"));
}, },
"internal compiler error"); "internal compiler error");
} }

View File

@ -23,14 +23,14 @@ namespace {
using CallStatementTest = TestHelper; using CallStatementTest = TestHelper;
TEST_F(CallStatementTest, Creation) { TEST_F(CallStatementTest, Creation) {
auto* expr = create<CallExpression>(Expr("func"), utils::Empty); auto* expr = Call("func");
auto* c = create<CallStatement>(expr); auto* c = CallStmt(expr);
EXPECT_EQ(c->expr, expr); EXPECT_EQ(c->expr, expr);
} }
TEST_F(CallStatementTest, IsCall) { TEST_F(CallStatementTest, IsCall) {
auto* c = create<CallStatement>(Call("f")); auto* c = CallStmt(Call("f"));
EXPECT_TRUE(c->Is<CallStatement>()); EXPECT_TRUE(c->Is<CallStatement>());
} }
@ -38,7 +38,7 @@ TEST_F(CallStatementTest, Assert_Null_Call) {
EXPECT_FATAL_FAILURE( EXPECT_FATAL_FAILURE(
{ {
ProgramBuilder b; ProgramBuilder b;
b.create<CallStatement>(nullptr); b.CallStmt(nullptr);
}, },
"internal compiler error"); "internal compiler error");
} }
@ -48,7 +48,7 @@ TEST_F(CallStatementTest, Assert_DifferentProgramID_Call) {
{ {
ProgramBuilder b1; ProgramBuilder b1;
ProgramBuilder b2; ProgramBuilder b2;
b1.create<CallStatement>(b2.create<CallExpression>(b2.Expr("func"), utils::Empty)); b1.CallStmt(b2.Call("func"));
}, },
"internal compiler error"); "internal compiler error");
} }

View File

@ -1160,7 +1160,13 @@ class ProgramBuilder {
/// @return an ast::Identifier with the given symbol /// @return an ast::Identifier with the given symbol
template <typename IDENTIFIER> template <typename IDENTIFIER>
const ast::Identifier* Ident(IDENTIFIER&& identifier) { const ast::Identifier* Ident(IDENTIFIER&& identifier) {
return create<ast::Identifier>(Sym(std::forward<IDENTIFIER>(identifier))); if constexpr (traits::IsTypeOrDerived<
std::decay_t<std::remove_pointer_t<std::decay_t<IDENTIFIER>>>,
ast::Identifier>) {
return identifier; // Pass-through
} else {
return create<ast::Identifier>(Sym(std::forward<IDENTIFIER>(identifier)));
}
} }
/// @param expr the expression /// @param expr the expression
@ -2054,7 +2060,7 @@ class ProgramBuilder {
/// arguments of `args` converted to `ast::Expression`s using `Expr()`. /// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, typename... ARGS> template <typename NAME, typename... ARGS>
const ast::CallExpression* Call(const Source& source, NAME&& func, ARGS&&... args) { const ast::CallExpression* Call(const Source& source, NAME&& func, ARGS&&... args) {
return create<ast::CallExpression>(source, Expr(func), return create<ast::CallExpression>(source, Ident(func),
ExprList(std::forward<ARGS>(args)...)); ExprList(std::forward<ARGS>(args)...));
} }
@ -2064,7 +2070,7 @@ class ProgramBuilder {
/// arguments of `args` converted to `ast::Expression`s using `Expr()`. /// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>> template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>>
const ast::CallExpression* Call(NAME&& func, ARGS&&... args) { const ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
return create<ast::CallExpression>(Expr(func), ExprList(std::forward<ARGS>(args)...)); return create<ast::CallExpression>(Ident(func), ExprList(std::forward<ARGS>(args)...));
} }
/// @param source the source information /// @param source the source information

View File

@ -1316,10 +1316,10 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
// Call the inner function. It has no parameters. // Call the inner function. It has no parameters.
stmts.Push(create<ast::CallStatement>( stmts.Push(create<ast::CallStatement>(
source, source,
create<ast::CallExpression>(source, create<ast::CallExpression>(
create<ast::IdentifierExpression>( source,
source, builder_.Symbols().Register(ep_info_->inner_name)), create<ast::Identifier>(source, builder_.Symbols().Register(ep_info_->inner_name)),
utils::Empty))); utils::Empty)));
// Pipeline outputs are mapped to the return value. // Pipeline outputs are mapped to the return value.
if (ep_info_->outputs.IsEmpty()) { if (ep_info_->outputs.IsEmpty()) {
@ -3854,7 +3854,7 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
params.Push(MakeOperand(inst, 0).expr); params.Push(MakeOperand(inst, 0).expr);
return {ast_type, create<ast::CallExpression>( return {ast_type, create<ast::CallExpression>(
Source{}, Source{},
create<ast::IdentifierExpression>( create<ast::Identifier>(
Source{}, builder_.Symbols().Register(unary_builtin_name)), Source{}, builder_.Symbols().Register(unary_builtin_name)),
std::move(params))}; std::move(params))};
} }
@ -4106,7 +4106,7 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(const spvtools::opt::Inst
return {}; return {};
} }
auto* func = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name)); auto* func = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList operands; ExpressionList operands;
const Type* first_operand_type = nullptr; const Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs. // All parameters to GLSL.std.450 extended instructions are IDs.
@ -5212,7 +5212,7 @@ TypedExpression FunctionEmitter::MakeNumericConversion(const spvtools::opt::Inst
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) { bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
// We ignore function attributes such as Inline, DontInline, Pure, Const. // We ignore function attributes such as Inline, DontInline, Pure, Const.
auto name = namer_.Name(inst.GetSingleWordInOperand(0)); auto name = namer_.Name(inst.GetSingleWordInOperand(0));
auto* function = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name)); auto* function = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList args; ExpressionList args;
for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) { for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
@ -5302,7 +5302,7 @@ bool FunctionEmitter::EmitControlBarrier(const spvtools::opt::Instruction& inst)
TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instruction& inst) { TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instruction& inst) {
const auto builtin = GetBuiltin(opcode(inst)); const auto builtin = GetBuiltin(opcode(inst));
auto* name = sem::str(builtin); auto* name = sem::str(builtin);
auto* ident = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name)); auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList params; ExpressionList params;
const Type* first_operand_type = nullptr; const Type* first_operand_type = nullptr;
@ -5341,11 +5341,10 @@ TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instructi
params.Push(true_value.expr); params.Push(true_value.expr);
// The condition goes last. // The condition goes last.
params.Push(condition.expr); params.Push(condition.expr);
return {op_ty, return {op_ty, create<ast::CallExpression>(
create<ast::CallExpression>(Source{}, Source{},
create<ast::IdentifierExpression>( create<ast::Identifier>(Source{}, builder_.Symbols().Register("select")),
Source{}, builder_.Symbols().Register("select")), std::move(params))};
std::move(params))};
} }
return {}; return {};
} }
@ -5650,8 +5649,7 @@ bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
return false; return false;
} }
auto* ident = auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(builtin_name));
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(builtin_name));
auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(args)); auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(args));
if (inst.type_id() != 0) { if (inst.type_id() != 0) {
@ -5741,8 +5739,8 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
// Invoke textureDimensions. // Invoke textureDimensions.
// If the texture is arrayed, combine with the result from // If the texture is arrayed, combine with the result from
// textureNumLayers. // textureNumLayers.
auto* dims_ident = create<ast::IdentifierExpression>( auto* dims_ident =
Source{}, builder_.Symbols().Register("textureDimensions")); create<ast::Identifier>(Source{}, builder_.Symbols().Register("textureDimensions"));
ExpressionList dims_args{GetImageExpression(inst)}; ExpressionList dims_args{GetImageExpression(inst)};
if (op == spv::Op::OpImageQuerySizeLod) { if (op == spv::Op::OpImageQuerySizeLod) {
dims_args.Push(MakeOperand(inst, 1).expr); dims_args.Push(MakeOperand(inst, 1).expr);
@ -5758,7 +5756,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
} }
exprs.Push(dims_call); exprs.Push(dims_call);
if (ast::IsTextureArray(dims)) { if (ast::IsTextureArray(dims)) {
auto* layers_ident = create<ast::IdentifierExpression>( auto* layers_ident = create<ast::Identifier>(
Source{}, builder_.Symbols().Register("textureNumLayers")); Source{}, builder_.Symbols().Register("textureNumLayers"));
auto num_layers = create<ast::CallExpression>( auto num_layers = create<ast::CallExpression>(
Source{}, layers_ident, utils::Vector{GetImageExpression(inst)}); Source{}, layers_ident, utils::Vector{GetImageExpression(inst)});
@ -5789,7 +5787,7 @@ bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
const auto* name = const auto* name =
(op == spv::Op::OpImageQueryLevels) ? "textureNumLevels" : "textureNumSamples"; (op == spv::Op::OpImageQueryLevels) ? "textureNumLevels" : "textureNumSamples";
auto* levels_ident = auto* levels_ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name)); create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
const ast::Expression* ast_expr = create<ast::CallExpression>( const ast::Expression* ast_expr = create<ast::CallExpression>(
Source{}, levels_ident, utils::Vector{GetImageExpression(inst)}); Source{}, levels_ident, utils::Vector{GetImageExpression(inst)});
auto* result_type = parser_impl_.ConvertType(inst.type_id()); auto* result_type = parser_impl_.ConvertType(inst.type_id());

View File

@ -2436,7 +2436,7 @@ Maybe<const ast::CallStatement*> ParserImpl::func_call_statement() {
t.source(), t.source(),
create<ast::CallExpression>( create<ast::CallExpression>(
t.source(), t.source(),
create<ast::IdentifierExpression>(t.source(), builder_.Symbols().Register(t.to_str())), create<ast::Identifier>(t.source(), builder_.Symbols().Register(t.to_str())),
std::move(params.value))); std::move(params.value)));
} }
@ -2642,19 +2642,19 @@ Maybe<const ast::Expression*> ParserImpl::primary_expression() {
"in parentheses"); "in parentheses");
} }
auto* ident =
create<ast::IdentifierExpression>(t.source(), builder_.Symbols().Register(t.to_str()));
if (peek_is(Token::Type::kParenLeft)) { if (peek_is(Token::Type::kParenLeft)) {
auto params = expect_argument_expression_list("function call"); auto params = expect_argument_expression_list("function call");
if (params.errored) { if (params.errored) {
return Failure::kErrored; return Failure::kErrored;
} }
auto* ident =
create<ast::Identifier>(t.source(), builder_.Symbols().Register(t.to_str()));
return create<ast::CallExpression>(t.source(), ident, std::move(params.value)); return create<ast::CallExpression>(t.source(), ident, std::move(params.value));
} }
return ident; return create<ast::IdentifierExpression>(t.source(),
builder_.Symbols().Register(t.to_str()));
} }
if (t.Is(Token::Type::kParenLeft)) { if (t.Is(Token::Type::kParenLeft)) {

View File

@ -39,10 +39,7 @@ TEST_F(ResolverBuiltinValidationTest, FunctionTypeMustMatchReturnStatementType_v
TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageDirect) { TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageDirect) {
// @compute @workgroup_size(1) fn func { return dpdx(1.0); } // @compute @workgroup_size(1) fn func { return dpdx(1.0); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"), auto* dpdx = Call(Source{{3, 4}}, "dpdx", 1_f);
utils::Vector{
Expr(1_f),
});
Func(Source{{1, 2}}, "func", utils::Empty, ty.void_(), Func(Source{{1, 2}}, "func", utils::Empty, ty.void_(),
utils::Vector{ utils::Vector{
CallStmt(dpdx), CallStmt(dpdx),
@ -62,10 +59,7 @@ TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageIndirect) {
// fn f2 { f1(); } // fn f2 { f1(); }
// @compute @workgroup_size(1) fn main { return f2(); } // @compute @workgroup_size(1) fn main { return f2(); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"), auto* dpdx = Call(Source{{3, 4}}, "dpdx", 1_f);
utils::Vector{
Expr(1_f),
});
Func(Source{{1, 2}}, "f0", utils::Empty, ty.void_(), Func(Source{{1, 2}}, "f0", utils::Empty, ty.void_(),
utils::Vector{ utils::Vector{
CallStmt(dpdx), CallStmt(dpdx),
@ -138,7 +132,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsFunctionUsedAsType) {
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsFunction) { TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsFunction) {
GlobalConst(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i)); GlobalConst(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i));
WrapInFunction(Call(Expr(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f)); WrapInFunction(Call(Ident(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix' EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix'
@ -167,7 +161,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsType)
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalVarUsedAsFunction) { TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalVarUsedAsFunction) {
GlobalVar(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i), type::AddressSpace::kPrivate); GlobalVar(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i), type::AddressSpace::kPrivate);
WrapInFunction(Call(Expr(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f)); WrapInFunction(Call(Ident(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix' EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix'

View File

@ -39,6 +39,7 @@
#include "src/tint/ast/for_loop_statement.h" #include "src/tint/ast/for_loop_statement.h"
#include "src/tint/ast/i32.h" #include "src/tint/ast/i32.h"
#include "src/tint/ast/id_attribute.h" #include "src/tint/ast/id_attribute.h"
#include "src/tint/ast/identifier.h"
#include "src/tint/ast/if_statement.h" #include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h" #include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/ast/internal_attribute.h" #include "src/tint/ast/internal_attribute.h"

View File

@ -548,7 +548,7 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind, Symbol symbol, Source
return node; return node;
} }
case SymbolUseKind::CallFunction: { case SymbolUseKind::CallFunction: {
auto* node = b.Expr(source, symbol); auto* node = b.Ident(source, symbol);
statements.Push(b.CallStmt(b.Call(node))); statements.Push(b.CallStmt(b.Call(node)));
return node; return node;
} }
@ -651,7 +651,8 @@ TEST_F(ResolverDependencyGraphUsedBeforeDeclTest, FuncCall) {
// fn A() { B(); } // fn A() { B(); }
// fn B() {} // fn B() {}
Func("A", utils::Empty, ty.void_(), utils::Vector{CallStmt(Call(Expr(Source{{12, 34}}, "B")))}); Func("A", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Ident(Source{{12, 34}}, "B")))});
Func(Source{{56, 78}}, "B", utils::Empty, ty.void_(), utils::Vector{Return()}); Func(Source{{56, 78}}, "B", utils::Empty, ty.void_(), utils::Vector{Return()});
Build(); Build();
@ -812,7 +813,7 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, DirectCall) {
// fn main() { main(); } // fn main() { main(); }
Func(Source{{12, 34}}, "main", utils::Empty, ty.void_(), Func(Source{{12, 34}}, "main", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{56, 78}}, "main")))}); utils::Vector{CallStmt(Call(Ident(Source{{56, 78}}, "main")))});
Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main' Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main'
56:78 note: function 'main' calls function 'main' here)"); 56:78 note: function 'main' calls function 'main' here)");
@ -826,17 +827,17 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) {
// 5: fn b() { c(); } // 5: fn b() { c(); }
Func(Source{{1, 1}}, "a", utils::Empty, ty.void_(), Func(Source{{1, 1}}, "a", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{1, 10}}, "b")))}); utils::Vector{CallStmt(Call(Ident(Source{{1, 10}}, "b")))});
Func(Source{{2, 1}}, "e", utils::Empty, ty.void_(), utils::Empty); Func(Source{{2, 1}}, "e", utils::Empty, ty.void_(), utils::Empty);
Func(Source{{3, 1}}, "d", utils::Empty, ty.void_(), Func(Source{{3, 1}}, "d", utils::Empty, ty.void_(),
utils::Vector{ utils::Vector{
CallStmt(Call(Expr(Source{{3, 10}}, "e"))), CallStmt(Call(Ident(Source{{3, 10}}, "e"))),
CallStmt(Call(Expr(Source{{3, 10}}, "b"))), CallStmt(Call(Ident(Source{{3, 10}}, "b"))),
}); });
Func(Source{{4, 1}}, "c", utils::Empty, ty.void_(), Func(Source{{4, 1}}, "c", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{4, 10}}, "d")))}); utils::Vector{CallStmt(Call(Ident(Source{{4, 10}}, "d")))});
Func(Source{{5, 1}}, "b", utils::Empty, ty.void_(), Func(Source{{5, 1}}, "b", utils::Empty, ty.void_(),
utils::Vector{CallStmt(Call(Expr(Source{{5, 10}}, "c")))}); utils::Vector{CallStmt(Call(Ident(Source{{5, 10}}, "c")))});
Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b' Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b'
5:10 note: function 'b' calls function 'c' here 5:10 note: function 'b' calls function 'c' here
@ -1232,7 +1233,7 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
}; };
#define V add_use(value_decl, Expr(value_sym), __LINE__, "V()") #define V add_use(value_decl, Expr(value_sym), __LINE__, "V()")
#define T add_use(type_decl, ty.type_name(type_sym), __LINE__, "T()") #define T add_use(type_decl, ty.type_name(type_sym), __LINE__, "T()")
#define F add_use(func_decl, Expr(func_sym), __LINE__, "F()") #define F add_use(func_decl, Ident(func_sym), __LINE__, "F()")
Alias(Sym(), T); Alias(Sym(), T);
Structure(Sym(), // Structure(Sym(), //

View File

@ -5330,7 +5330,7 @@ TEST_F(UniformityAnalysisTest, MaximumNumberOfPointerParameters) {
args.Push(b.AddressOf(name)); args.Push(b.AddressOf(name));
} }
main_body.Push(b.Assign("v0", "non_uniform_global")); main_body.Push(b.Assign("v0", "non_uniform_global"));
main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Expr("foo"), args))); main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Ident("foo"), args)));
main_body.Push(b.If(b.Equal("v254", 0_i), b.Block(b.CallStmt(b.Call("workgroupBarrier"))))); main_body.Push(b.If(b.Equal("v254", 0_i), b.Block(b.CallStmt(b.Call("workgroupBarrier")))));
b.Func("main", utils::Empty, ty.void_(), main_body); b.Func("main", utils::Empty, ty.void_(), main_body);

View File

@ -434,14 +434,13 @@ struct MultiplanarExternalTexture::State {
buildTextureBuiltinBody(sem::BuiltinType::kTextureSampleBaseClampToEdge)); buildTextureBuiltinBody(sem::BuiltinType::kTextureSampleBaseClampToEdge));
} }
const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym); return b.Call(texture_sample_external_sym, utils::Vector{
return b.Call(exp, utils::Vector{ plane_0_binding_param,
plane_0_binding_param, b.Expr(syms.plane_1),
b.Expr(syms.plane_1), ctx.Clone(expr->args[1]),
ctx.Clone(expr->args[1]), ctx.Clone(expr->args[2]),
ctx.Clone(expr->args[2]), b.Expr(syms.params),
b.Expr(syms.params), });
});
} }
/// Creates the textureLoadExternal function if needed and returns a call expression to it. /// Creates the textureLoadExternal function if needed and returns a call expression to it.

View File

@ -512,9 +512,6 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return clone_maybe_hoisted(bitcast); return clone_maybe_hoisted(bitcast);
}, },
[&](const ast::CallExpression* call) { [&](const ast::CallExpression* call) {
if (call->target.name) {
ctx.Replace(call->target.name, decompose(call->target.name));
}
for (auto* a : call->args) { for (auto* a : call->args) {
ctx.Replace(a, decompose(a)); ctx.Replace(a, decompose(a));
} }

View File

@ -1262,7 +1262,9 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ false}; CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
// Identifiers that need to keep their symbols preserved. // Identifiers that need to keep their symbols preserved.
utils::Hashset<const ast::IdentifierExpression*, 8> preserved_identifiers; utils::Hashset<const ast::Identifier*, 8> preserved_identifiers;
// Identifiers expressions that need to keep their symbols preserved.
utils::Hashset<const ast::IdentifierExpression*, 8> preserved_identifiers_expressions;
// Type names that need to keep their symbols preserved. // Type names that need to keep their symbols preserved.
utils::Hashset<const ast::TypeName*, 8> preserved_type_names; utils::Hashset<const ast::TypeName*, 8> preserved_type_names;
@ -1287,11 +1289,11 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
[&](const ast::MemberAccessorExpression* accessor) { [&](const ast::MemberAccessorExpression* accessor) {
auto* sem = src->Sem().Get(accessor)->UnwrapLoad(); auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
if (sem->Is<sem::Swizzle>()) { if (sem->Is<sem::Swizzle>()) {
preserved_identifiers.Add(accessor->member); preserved_identifiers_expressions.Add(accessor->member);
} else if (auto* str_expr = src->Sem().Get(accessor->structure)) { } else if (auto* str_expr = src->Sem().Get(accessor->structure)) {
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) { if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
if (ty->Declaration() == nullptr) { // Builtin structure if (ty->Declaration() == nullptr) { // Builtin structure
preserved_identifiers.Add(accessor->member); preserved_identifiers_expressions.Add(accessor->member);
} }
} }
} }
@ -1314,7 +1316,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
} }
}, },
[&](const ast::DiagnosticControl* diagnostic) { [&](const ast::DiagnosticControl* diagnostic) {
preserved_identifiers.Add(diagnostic->rule_name); preserved_identifiers_expressions.Add(diagnostic->rule_name);
}, },
[&](const ast::TypeName* type_name) { [&](const ast::TypeName* type_name) {
if (is_type_short_name(type_name->name)) { if (is_type_short_name(type_name->name)) {
@ -1376,8 +1378,18 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
return sym_out; return sym_out;
}); });
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* { ctx.ReplaceAll([&](const ast::Identifier* ident) -> const ast::Identifier* {
if (preserved_identifiers.Contains(ident)) { if (preserved_identifiers.Contains(ident)) {
auto sym_in = ident->symbol;
auto str = src->Symbols().NameFor(sym_in);
auto sym_out = b.Symbols().Register(str);
return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), sym_out);
}
return nullptr; // Clone ident. Uses the symbol remapping above.
});
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
if (preserved_identifiers_expressions.Contains(ident)) {
auto sym_in = ident->symbol; auto sym_in = ident->symbol;
auto str = src->Symbols().NameFor(sym_in); auto str = src->Symbols().NameFor(sym_in);
auto sym_out = b.Symbols().Register(str); auto sym_out = b.Symbols().Register(str);

View File

@ -34,8 +34,7 @@ using GlslImportData_SingleParamTest = TestParamHelper<GlslImportData>;
TEST_P(GlslImportData_SingleParamTest, FloatScalar) { TEST_P(GlslImportData_SingleParamTest, FloatScalar) {
auto param = GetParam(); auto param = GetParam();
auto* ident = Expr(param.name); auto* expr = Call(param.name, 1_f);
auto* expr = Call(ident, 1_f);
WrapInFunction(expr); WrapInFunction(expr);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -91,8 +90,7 @@ using GlslImportData_SingleVectorParamTest = TestParamHelper<GlslImportData>;
TEST_P(GlslImportData_SingleVectorParamTest, FloatVector) { TEST_P(GlslImportData_SingleVectorParamTest, FloatVector) {
auto param = GetParam(); auto param = GetParam();
auto* ident = Expr(param.name); auto* expr = Call(param.name, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
auto* expr = Call(ident, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
WrapInFunction(expr); WrapInFunction(expr);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -34,8 +34,7 @@ using HlslImportData_SingleParamTest = TestParamHelper<HlslImportData>;
TEST_P(HlslImportData_SingleParamTest, FloatScalar) { TEST_P(HlslImportData_SingleParamTest, FloatScalar) {
auto param = GetParam(); auto param = GetParam();
auto* ident = Expr(param.name); auto* expr = Call(param.name, 1_f);
auto* expr = Call(ident, 1_f);
WrapInFunction(expr); WrapInFunction(expr);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -90,8 +89,7 @@ using HlslImportData_SingleVectorParamTest = TestParamHelper<HlslImportData>;
TEST_P(HlslImportData_SingleVectorParamTest, FloatVector) { TEST_P(HlslImportData_SingleVectorParamTest, FloatVector) {
auto param = GetParam(); auto param = GetParam();
auto* ident = Expr(param.name); auto* expr = Call(param.name, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
auto* expr = Call(ident, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
WrapInFunction(expr); WrapInFunction(expr);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -276,7 +276,7 @@ TEST_P(MslGeneratorBuiltinTextureTest, Call) {
param.BuildTextureVariable(this); param.BuildTextureVariable(this);
param.BuildSamplerVariable(this); param.BuildSamplerVariable(this);
auto* call = Call(Expr(param.function), param.args(this)); auto* call = Call(Ident(param.function), param.args(this));
auto* stmt = CallStmt(call); auto* stmt = CallStmt(call);
Func("main", utils::Empty, ty.void_(), utils::Vector{stmt}, Func("main", utils::Empty, ty.void_(), utils::Vector{stmt},

View File

@ -239,9 +239,7 @@ bool GeneratorImpl::EmitBitcast(std::ostream& out, const ast::BitcastExpression*
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) { bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
if (expr->target.name) { if (expr->target.name) {
if (!EmitExpression(out, expr->target.name)) { out << program_->Symbols().NameFor(expr->target.name->symbol);
return false;
}
} else if (TINT_LIKELY(expr->target.type)) { } else if (TINT_LIKELY(expr->target.type)) {
if (!EmitType(out, expr->target.type)) { if (!EmitType(out, expr->target.type)) {
return false; return false;