From 14b340314808ab3c528239ae6b4b130eecc08125 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Wed, 9 Jun 2021 20:17:59 +0000 Subject: [PATCH] Validate function call arguments - Add resolver/call_test.cc for new unit tests, and move a couple that were in resolver/validation_test.cc to it - Fix CalculateArrayLength transform so that it passes the address of the u32 it creates to the internal function - Fix tests broken as a result of this change Bug: tint:664 Change-Id: If713f9828790cd51224d2392d42c01c0057cb652 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53920 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Antonio Maiorano --- src/CMakeLists.txt | 1 + src/program_builder.h | 13 ++ src/resolver/call_test.cc | 194 ++++++++++++++++++ src/resolver/resolver.cc | 121 +++++++---- src/resolver/resolver.h | 8 +- src/resolver/resolver_test.cc | 3 +- src/resolver/validation_test.cc | 50 ----- src/sem/call.cc | 2 +- src/sem/call.h | 2 +- src/sem/expression.cc | 2 +- src/sem/expression.h | 8 +- src/traits.h | 27 ++- src/transform/calculate_array_length.cc | 5 +- src/transform/calculate_array_length_test.cc | 14 +- src/writer/hlsl/generator_impl_binary_test.cc | 9 +- src/writer/hlsl/generator_impl_call_test.cc | 16 +- src/writer/msl/generator_impl_call_test.cc | 16 +- src/writer/wgsl/generator_impl_call_test.cc | 16 +- test/BUILD.gn | 3 +- 19 files changed, 371 insertions(+), 139 deletions(-) create mode 100644 src/resolver/call_test.cc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 25c90edac3..5a9494edbc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -589,6 +589,7 @@ if(${TINT_BUILD_TESTS}) resolver/assignment_validation_test.cc resolver/block_test.cc resolver/builtins_validation_test.cc + resolver/call_test.cc resolver/control_block_validation_test.cc resolver/decoration_validation_test.cc resolver/entry_point_validation_test.cc diff --git a/src/program_builder.h b/src/program_builder.h index 02eae062c4..b04a5adde5 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -1403,11 +1403,24 @@ class ProgramBuilder { Expr(std::forward(expr))); } + /// @param source the source information /// @param func the function name /// @param args the function call arguments /// @returns a `ast::CallExpression` to the function `func`, with the /// arguments of `args` converted to `ast::Expression`s using `Expr()`. template + ast::CallExpression* Call(const Source& source, NAME&& func, ARGS&&... args) { + return create(source, Expr(func), + ExprList(std::forward(args)...)); + } + + /// @param func the function name + /// @param args the function call arguments + /// @returns a `ast::CallExpression` to the function `func`, with the + /// arguments of `args` converted to `ast::Expression`s using `Expr()`. + template , Source>* = nullptr> ast::CallExpression* Call(NAME&& func, ARGS&&... args) { return create(Expr(func), ExprList(std::forward(args)...)); diff --git a/src/resolver/call_test.cc b/src/resolver/call_test.cc new file mode 100644 index 0000000000..f9e2af5823 --- /dev/null +++ b/src/resolver/call_test.cc @@ -0,0 +1,194 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/resolver/resolver.h" + +#include "gmock/gmock.h" +#include "src/ast/call_statement.h" +#include "src/resolver/resolver_test_helper.h" + +namespace tint { +namespace resolver { +namespace { +// Helpers and typedefs +template +using DataType = builder::DataType; +template +using vec = builder::vec; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat = builder::mat; +template +using mat2x2 = builder::mat2x2; +template +using mat2x3 = builder::mat2x3; +template +using mat3x2 = builder::mat3x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +template +using alias1 = builder::alias1; +template +using alias2 = builder::alias2; +template +using alias3 = builder::alias3; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; + +using ResolverCallTest = ResolverTest; + +TEST_F(ResolverCallTest, Recursive_Invalid) { + // fn main() {main(); } + + SetSource(Source::Location{12, 34}); + auto* call_expr = Call("main"); + ast::VariableList params0; + + Func("main", params0, ty.void_(), + ast::StatementList{ + create(call_expr), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kVertex), + }); + + EXPECT_FALSE(r()->Resolve()); + + EXPECT_EQ(r()->error(), + "12:34 error v-0004: recursion is not permitted. 'main' attempted " + "to call " + "itself."); +} + +TEST_F(ResolverCallTest, Undeclared_Invalid) { + // fn main() {func(); return; } + // fn func() { return; } + + SetSource(Source::Location{12, 34}); + auto* call_expr = Call("func"); + ast::VariableList params0; + + Func("main", params0, ty.f32(), + ast::StatementList{ + create(call_expr), + Return(), + }, + ast::DecorationList{}); + + Func("func", params0, ty.f32(), + ast::StatementList{ + Return(), + }, + ast::DecorationList{}); + + EXPECT_FALSE(r()->Resolve()); + + EXPECT_EQ(r()->error(), + "12:34 error: v-0006: unable to find called function: func"); +} + +struct Params { + builder::ast_expr_func_ptr create_value; + builder::ast_type_func_ptr create_type; +}; + +template +constexpr Params ParamsFor() { + return Params{DataType::Expr, DataType::AST}; +} + +static constexpr Params all_param_types[] = { + ParamsFor(), // + ParamsFor(), // + ParamsFor(), // + ParamsFor(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>() // +}; + +TEST_F(ResolverCallTest, Valid) { + ast::VariableList params; + ast::ExpressionList args; + for (auto& p : all_param_types) { + params.push_back(Param(Sym(), p.create_type(*this))); + args.push_back(p.create_value(*this, 0)); + } + + Func("foo", std::move(params), ty.void_(), {Return()}); + auto* call = Call("foo", std::move(args)); + WrapInFunction(call); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverCallTest, TooFewArgs) { + Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(), + {Return()}); + auto* call = Call(Source{{12, 34}}, "foo", 1); + WrapInFunction(call); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: too few arguments in call to 'foo', expected 2, got 1"); +} + +TEST_F(ResolverCallTest, TooManyArgs) { + Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(), + {Return()}); + auto* call = Call(Source{{12, 34}}, "foo", 1, 1.0f, 1.0f); + WrapInFunction(call); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: too many arguments in call to 'foo', expected 2, got 3"); +} + +TEST_F(ResolverCallTest, MismatchedArgs) { + Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(), + {Return()}); + auto* call = Call("foo", Expr(Source{{12, 34}}, true), 1.0f); + WrapInFunction(call); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: type mismatch for argument 1 in call to 'foo', " + "expected 'i32', got 'bool'"); +} + +} // namespace +} // namespace resolver +} // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 27fbd1b2a0..41c411423e 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1701,51 +1701,9 @@ bool Resolver::Call(ast::CallExpression* call) { return false; } } else { - if (current_function_) { - auto callee_func_it = symbol_to_function_.find(ident->symbol()); - if (callee_func_it == symbol_to_function_.end()) { - if (current_function_->declaration->symbol() == ident->symbol()) { - diagnostics_.add_error("v-0004", - "recursion is not permitted. '" + name + - "' attempted to call itself.", - call->source()); - } else { - diagnostics_.add_error( - "v-0006: unable to find called function: " + name, - call->source()); - } - return false; - } - auto* callee_func = callee_func_it->second; - - callee_func->callsites.push_back(call); - - // Note: Requires called functions to be resolved first. - // This is currently guaranteed as functions must be declared before - // use. - current_function_->transitive_calls.add(callee_func); - for (auto* transitive_call : callee_func->transitive_calls) { - current_function_->transitive_calls.add(transitive_call); - } - - // We inherit any referenced variables from the callee. - for (auto* var : callee_func->referenced_module_vars) { - set_referenced_from_function_if_needed(var, false); - } - } - - auto iter = symbol_to_function_.find(ident->symbol()); - if (iter == symbol_to_function_.end()) { - diagnostics_.add_error( - "v-0005: function must be declared before use: '" + name + "'", - call->source()); + if (!FunctionCall(call)) { return false; } - - auto* function = iter->second; - function_calls_.emplace(call, - FunctionCallInfo{function, current_statement_}); - SetType(call, function->return_type, function->return_type_name); } return true; @@ -1775,6 +1733,79 @@ bool Resolver::IntrinsicCall(ast::CallExpression* call, return true; } +bool Resolver::FunctionCall(const ast::CallExpression* call) { + auto* ident = call->func(); + auto name = builder_->Symbols().NameFor(ident->symbol()); + + auto callee_func_it = symbol_to_function_.find(ident->symbol()); + if (callee_func_it == symbol_to_function_.end()) { + if (current_function_ && + current_function_->declaration->symbol() == ident->symbol()) { + diagnostics_.add_error("v-0004", + "recursion is not permitted. '" + name + + "' attempted to call itself.", + call->source()); + } else { + diagnostics_.add_error("v-0006: unable to find called function: " + name, + call->source()); + } + return false; + } + auto* callee_func = callee_func_it->second; + + if (current_function_) { + callee_func->callsites.push_back(call); + + // Note: Requires called functions to be resolved first. + // This is currently guaranteed as functions must be declared before + // use. + current_function_->transitive_calls.add(callee_func); + for (auto* transitive_call : callee_func->transitive_calls) { + current_function_->transitive_calls.add(transitive_call); + } + + // We inherit any referenced variables from the callee. + for (auto* var : callee_func->referenced_module_vars) { + set_referenced_from_function_if_needed(var, false); + } + } + + // Validate number of arguments match number of parameters + if (call->params().size() != callee_func->parameters.size()) { + bool more = call->params().size() > callee_func->parameters.size(); + diagnostics_.add_error( + "too " + (more ? std::string("many") : std::string("few")) + + " arguments in call to '" + name + "', expected " + + std::to_string(callee_func->parameters.size()) + ", got " + + std::to_string(call->params().size()), + call->source()); + return false; + } + + // Validate arguments match parameter types + for (size_t i = 0; i < call->params().size(); ++i) { + const VariableInfo* param = callee_func->parameters[i]; + const ast::Expression* arg_expr = call->params()[i]; + auto* arg_type = TypeOf(arg_expr)->UnwrapRef(); + + if (param->type != arg_type) { + diagnostics_.add_error( + "type mismatch for argument " + std::to_string(i + 1) + + " in call to '" + name + "', expected '" + + param->type->FriendlyName(builder_->Symbols()) + "', got '" + + arg_type->FriendlyName(builder_->Symbols()) + "'", + arg_expr->source()); + return false; + } + } + + function_calls_.emplace(call, + FunctionCallInfo{callee_func, current_statement_}); + SetType(call, callee_func->return_type, callee_func->return_type_name); + + return true; +} + bool Resolver::Constructor(ast::ConstructorExpression* expr) { if (auto* type_ctor = expr->As()) { for (auto* value : type_ctor->values()) { @@ -2514,11 +2545,11 @@ sem::Type* Resolver::TypeOf(const ast::Literal* lit) { return nullptr; } -void Resolver::SetType(ast::Expression* expr, const sem::Type* type) { +void Resolver::SetType(const ast::Expression* expr, const sem::Type* type) { SetType(expr, type, type->FriendlyName(builder_->Symbols())); } -void Resolver::SetType(ast::Expression* expr, +void Resolver::SetType(const ast::Expression* expr, const sem::Type* type, const std::string& type_name) { if (expr_info_.count(expr)) { diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 03715fd41b..9f9102729e 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -234,6 +234,7 @@ class Resolver { bool Identifier(ast::IdentifierExpression*); bool IfStatement(ast::IfStatement*); bool IntrinsicCall(ast::CallExpression*, sem::IntrinsicType); + bool FunctionCall(const ast::CallExpression* call); bool LoopStatement(ast::LoopStatement*); bool MemberAccessor(ast::MemberAccessorExpression*); bool Parameter(ast::Variable* param); @@ -345,7 +346,7 @@ class Resolver { /// assigns this semantic node to the expression `expr`. /// @param expr the expression /// @param type the resolved type - void SetType(ast::Expression* expr, const sem::Type* type); + void SetType(const ast::Expression* expr, const sem::Type* type); /// Creates a sem::Expression node with the resolved type `type`, the declared /// type name `type_name` and assigns this semantic node to the expression @@ -353,7 +354,7 @@ class Resolver { /// @param expr the expression /// @param type the resolved type /// @param type_name the declared type name - void SetType(ast::Expression* expr, + void SetType(const ast::Expression* expr, const sem::Type* type, const std::string& type_name); @@ -396,7 +397,8 @@ class Resolver { std::vector entry_points_; std::unordered_map function_to_info_; std::unordered_map variable_to_info_; - std::unordered_map function_calls_; + std::unordered_map + function_calls_; std::unordered_map expr_info_; std::unordered_map named_type_info_; diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index d225fbb976..e060682c01 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -560,8 +560,7 @@ TEST_F(ResolverTest, Expr_Call_InBinaryOp) { } TEST_F(ResolverTest, Expr_Call_WithParams) { - ast::VariableList params; - Func("my_func", params, ty.f32(), + Func("my_func", {Param(Sym(), ty.f32())}, ty.f32(), { Return(1.2f), }); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index a0a5f231ae..1c62b894b8 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -83,56 +83,6 @@ TEST_F(ResolverValidationTest, Stmt_Error_Unknown) { "2:30 error: unknown statement type for type determination: Fake"); } -TEST_F(ResolverValidationTest, Stmt_Call_undeclared) { - // fn main() {func(); return; } - // fn func() { return; } - - SetSource(Source::Location{12, 34}); - auto* call_expr = Call("func"); - ast::VariableList params0; - - Func("main", params0, ty.f32(), - ast::StatementList{ - create(call_expr), - Return(), - }, - ast::DecorationList{}); - - Func("func", params0, ty.f32(), - ast::StatementList{ - Return(), - }, - ast::DecorationList{}); - - EXPECT_FALSE(r()->Resolve()); - - EXPECT_EQ(r()->error(), - "12:34 error: v-0006: unable to find called function: func"); -} - -TEST_F(ResolverValidationTest, Stmt_Call_recursive) { - // fn main() {main(); } - - SetSource(Source::Location{12, 34}); - auto* call_expr = Call("main"); - ast::VariableList params0; - - Func("main", params0, ty.void_(), - ast::StatementList{ - create(call_expr), - }, - ast::DecorationList{ - Stage(ast::PipelineStage::kVertex), - }); - - EXPECT_FALSE(r()->Resolve()); - - EXPECT_EQ(r()->error(), - "12:34 error v-0004: recursion is not permitted. 'main' attempted " - "to call " - "itself."); -} - TEST_F(ResolverValidationTest, Stmt_If_NonBool) { // if (1.23f) {} diff --git a/src/sem/call.cc b/src/sem/call.cc index baa4425ad1..3abb91ec5a 100644 --- a/src/sem/call.cc +++ b/src/sem/call.cc @@ -19,7 +19,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::sem::Call); namespace tint { namespace sem { -Call::Call(ast::Expression* declaration, +Call::Call(const ast::Expression* declaration, const CallTarget* target, Statement* statement) : Base(declaration, target->ReturnType(), statement), target_(target) {} diff --git a/src/sem/call.h b/src/sem/call.h index a3e3da7367..d2fdb314b4 100644 --- a/src/sem/call.h +++ b/src/sem/call.h @@ -29,7 +29,7 @@ class Call : public Castable { /// @param declaration the AST node /// @param target the call target /// @param statement the statement that owns this expression - Call(ast::Expression* declaration, + Call(const ast::Expression* declaration, const CallTarget* target, Statement* statement); diff --git a/src/sem/expression.cc b/src/sem/expression.cc index 74fb0c29d6..7286dc610f 100644 --- a/src/sem/expression.cc +++ b/src/sem/expression.cc @@ -19,7 +19,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::sem::Expression); namespace tint { namespace sem { -Expression::Expression(ast::Expression* declaration, +Expression::Expression(const ast::Expression* declaration, const sem::Type* type, Statement* statement) : declaration_(declaration), type_(type), statement_(statement) { diff --git a/src/sem/expression.h b/src/sem/expression.h index 73d971f23b..8c2e304e61 100644 --- a/src/sem/expression.h +++ b/src/sem/expression.h @@ -31,7 +31,7 @@ class Expression : public Castable { /// @param declaration the AST node /// @param type the resolved type of the expression /// @param statement the statement that owns this expression - Expression(ast::Expression* declaration, + Expression(const ast::Expression* declaration, const sem::Type* type, Statement* statement); @@ -42,10 +42,12 @@ class Expression : public Castable { Statement* Stmt() const { return statement_; } /// @returns the AST node - ast::Expression* Declaration() const { return declaration_; } + ast::Expression* Declaration() const { + return const_cast(declaration_); + } private: - ast::Expression* declaration_; + const ast::Expression* declaration_; const sem::Type* const type_; Statement* const statement_; }; diff --git a/src/traits.h b/src/traits.h index f35fe0f026..1fda4f0039 100644 --- a/src/traits.h +++ b/src/traits.h @@ -20,6 +20,10 @@ namespace tint { namespace traits { +/// Convience type definition for std::decay::type +template +using Decay = typename std::decay::type; + /// NthTypeOf returns the `N`th type in `Types` template using NthTypeOf = typename std::tuple_element>::type; @@ -38,7 +42,7 @@ struct ParamType { /// Arg is the raw type of the `N`th parameter of the function using Arg = NthTypeOf; /// The type of the `N`th parameter of the function - using type = typename std::decay::type; + using type = Decay; }; /// ParamType specialization for a non-static method. @@ -47,7 +51,7 @@ struct ParamType { /// Arg is the raw type of the `N`th parameter of the function using Arg = NthTypeOf; /// The type of the `N`th parameter of the function - using type = typename std::decay::type; + using type = Decay; }; /// ParamType specialization for a non-static, const method. @@ -56,7 +60,7 @@ struct ParamType { /// Arg is the raw type of the `N`th parameter of the function using Arg = NthTypeOf; /// The type of the `N`th parameter of the function - using type = typename std::decay::type; + using type = Decay; }; /// ParamTypeT is an alias to `typename ParamType::type`. @@ -66,21 +70,26 @@ using ParamTypeT = typename ParamType::type; /// `IsTypeOrDerived::value` is true iff `T` is of type `BASE`, or /// derives from `BASE`. template -using IsTypeOrDerived = std::integral_constant< - bool, - std::is_base_of::type>::value || - std::is_same::type>::value>; +using IsTypeOrDerived = + std::integral_constant>::value || + std::is_same>::value>; /// If `CONDITION` is true then EnableIf resolves to type T, otherwise an /// invalid type. template using EnableIf = typename std::enable_if::type; -/// If T is a base of BASE then EnableIfIsType resolves to type T, otherwise an -/// invalid type. +/// If `T` is of type `BASE`, or derives from `BASE`, then EnableIfIsType +/// resolves to type `T`, otherwise an invalid type. template using EnableIfIsType = EnableIf::value, T>; +/// If `T` is not of type `BASE`, or does not derive from `BASE`, then +/// EnableIfIsNotType resolves to type `T`, otherwise an invalid type. +template +using EnableIfIsNotType = EnableIf::value, T>; + } // namespace traits } // namespace tint diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc index 0feadc160c..74962fbc6b 100644 --- a/src/transform/calculate_array_length.cc +++ b/src/transform/calculate_array_length.cc @@ -182,14 +182,15 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) { ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(), ast::StorageClass::kNone, ctx.dst->Expr(0u))); - // Call storage_buffer.GetDimensions(buffer_size_result) + // Call storage_buffer.GetDimensions(&buffer_size_result) auto* call_get_dims = ctx.dst->create(ctx.dst->Call( // BufferSizeIntrinsic(X, ARGS...) is // translated to: // X.GetDimensions(ARGS..) by the writer buffer_size, ctx.Clone(storage_buffer_expr), - buffer_size_result->variable()->symbol())); + ctx.dst->AddressOf(ctx.dst->Expr( + buffer_size_result->variable()->symbol())))); // Calculate actual array length // total_storage_buffer_size - array_offset diff --git a/src/transform/calculate_array_length_test.cc b/src/transform/calculate_array_length_test.cc index 557454b719..4821c58465 100644 --- a/src/transform/calculate_array_length_test.cc +++ b/src/transform/calculate_array_length_test.cc @@ -53,7 +53,7 @@ fn tint_symbol(buffer : SB, result : ptr) [[stage(compute)]] fn main() { var tint_symbol_1 : u32 = 0u; - tint_symbol(sb, tint_symbol_1); + tint_symbol(sb, &(tint_symbol_1)); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); var len : u32 = tint_symbol_2; } @@ -97,7 +97,7 @@ fn tint_symbol(buffer : SB, result : ptr) [[stage(compute)]] fn main() { var tint_symbol_1 : u32 = 0u; - tint_symbol(sb, tint_symbol_1); + tint_symbol(sb, &(tint_symbol_1)); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); var a : u32 = tint_symbol_2; var b : u32 = tint_symbol_2; @@ -143,7 +143,7 @@ fn tint_symbol(buffer : SB, result : ptr) [[stage(compute)]] fn main() { var tint_symbol_1 : u32 = 0u; - tint_symbol(sb, tint_symbol_1); + tint_symbol(sb, &(tint_symbol_1)); let tint_symbol_2 : u32 = ((tint_symbol_1 - 8u) / 64u); var len : u32 = tint_symbol_2; } @@ -192,13 +192,13 @@ fn tint_symbol(buffer : SB, result : ptr) fn main() { if (true) { var tint_symbol_1 : u32 = 0u; - tint_symbol(sb, tint_symbol_1); + tint_symbol(sb, &(tint_symbol_1)); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); var len : u32 = tint_symbol_2; } else { if (true) { var tint_symbol_3 : u32 = 0u; - tint_symbol(sb, tint_symbol_3); + tint_symbol(sb, &(tint_symbol_3)); let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u); var len : u32 = tint_symbol_4; } @@ -263,10 +263,10 @@ fn tint_symbol_3(buffer : SB2, result : ptr) [[stage(compute)]] fn main() { var tint_symbol_1 : u32 = 0u; - tint_symbol(sb1, tint_symbol_1); + tint_symbol(sb1, &(tint_symbol_1)); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); var tint_symbol_4 : u32 = 0u; - tint_symbol_3(sb2, tint_symbol_4); + tint_symbol_3(sb2, &(tint_symbol_4)); let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u); var len1 : u32 = tint_symbol_2; var len2 : u32 = tint_symbol_5; diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index 610e517ce5..743bb2bf8b 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc @@ -483,8 +483,13 @@ if (tint_tmp) { TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) { // foo(a && b, c || d, (a || c) && (b || d)) - Func("foo", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("foo", + { + Param(Sym(), ty.bool_()), + Param(Sym(), ty.bool_()), + Param(Sym(), ty.bool_()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("a", ty.bool_(), ast::StorageClass::kPrivate); Global("b", ty.bool_(), ast::StorageClass::kPrivate); Global("c", ty.bool_(), ast::StorageClass::kPrivate); diff --git a/src/writer/hlsl/generator_impl_call_test.cc b/src/writer/hlsl/generator_impl_call_test.cc index b5bb0895e0..eadfcd406e 100644 --- a/src/writer/hlsl/generator_impl_call_test.cc +++ b/src/writer/hlsl/generator_impl_call_test.cc @@ -36,8 +36,12 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithoutParams) { } TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) { - Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("my_func", + { + Param(Sym(), ty.f32()), + Param(Sym(), ty.f32()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("param1", ty.f32(), ast::StorageClass::kPrivate); Global("param2", ty.f32(), ast::StorageClass::kPrivate); @@ -51,8 +55,12 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) { } TEST_F(HlslGeneratorImplTest_Call, EmitStatement_Call) { - Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("my_func", + { + Param(Sym(), ty.f32()), + Param(Sym(), ty.f32()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("param1", ty.f32(), ast::StorageClass::kPrivate); Global("param2", ty.f32(), ast::StorageClass::kPrivate); diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc index bd015b67e0..c165579cee 100644 --- a/src/writer/msl/generator_impl_call_test.cc +++ b/src/writer/msl/generator_impl_call_test.cc @@ -36,8 +36,12 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) { } TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) { - Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("my_func", + { + Param(Sym(), ty.f32()), + Param(Sym(), ty.f32()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("param1", ty.f32(), ast::StorageClass::kInput); Global("param2", ty.f32(), ast::StorageClass::kInput); @@ -51,8 +55,12 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) { } TEST_F(MslGeneratorImplTest, EmitStatement_Call) { - Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("my_func", + { + Param(Sym(), ty.f32()), + Param(Sym(), ty.f32()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("param1", ty.f32(), ast::StorageClass::kInput); Global("param2", ty.f32(), ast::StorageClass::kInput); diff --git a/src/writer/wgsl/generator_impl_call_test.cc b/src/writer/wgsl/generator_impl_call_test.cc index bf701cd511..ecce2726a8 100644 --- a/src/writer/wgsl/generator_impl_call_test.cc +++ b/src/writer/wgsl/generator_impl_call_test.cc @@ -36,8 +36,12 @@ TEST_F(WgslGeneratorImplTest, EmitExpression_Call_WithoutParams) { } TEST_F(WgslGeneratorImplTest, EmitExpression_Call_WithParams) { - Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("my_func", + { + Param(Sym(), ty.f32()), + Param(Sym(), ty.f32()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("param1", ty.f32(), ast::StorageClass::kPrivate); Global("param2", ty.f32(), ast::StorageClass::kPrivate); @@ -51,8 +55,12 @@ TEST_F(WgslGeneratorImplTest, EmitExpression_Call_WithParams) { } TEST_F(WgslGeneratorImplTest, EmitStatement_Call) { - Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{}); + Func("my_func", + { + Param(Sym(), ty.f32()), + Param(Sym(), ty.f32()), + }, + ty.void_(), ast::StatementList{}, ast::DecorationList{}); Global("param1", ty.f32(), ast::StorageClass::kPrivate); Global("param2", ty.f32(), ast::StorageClass::kPrivate); diff --git a/test/BUILD.gn b/test/BUILD.gn index dd982644f6..9c503cd4a8 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -225,6 +225,7 @@ tint_unittests_source_set("tint_unittests_core_src") { "../src/resolver/assignment_validation_test.cc", "../src/resolver/block_test.cc", "../src/resolver/builtins_validation_test.cc", + "../src/resolver/call_test.cc", "../src/resolver/control_block_validation_test.cc", "../src/resolver/decoration_validation_test.cc", "../src/resolver/entry_point_validation_test.cc", @@ -465,8 +466,8 @@ tint_unittests_source_set("tint_unittests_wgsl_reader_src") { "../src/reader/wgsl/parser_impl_variable_decoration_list_test.cc", "../src/reader/wgsl/parser_impl_variable_decoration_test.cc", "../src/reader/wgsl/parser_impl_variable_ident_decl_test.cc", - "../src/reader/wgsl/parser_impl_variable_stmt_test.cc", "../src/reader/wgsl/parser_impl_variable_qualifier_test.cc", + "../src/reader/wgsl/parser_impl_variable_stmt_test.cc", "../src/reader/wgsl/parser_test.cc", "../src/reader/wgsl/token_test.cc", ]