From b6d524380ef9c3311ef999e59884bb040273360f Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Wed, 31 Aug 2022 22:59:08 +0000 Subject: [PATCH] tint: Improve resolver test helper to specify more than one expression arg Bug: tint:1581 Change-Id: Ie77c56c15a5965b20008036cd99a81cbabd86988 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/100340 Kokoro: Kokoro Reviewed-by: Ben Clayton Commit-Queue: Antonio Maiorano --- src/tint/resolver/bitcast_validation_test.cc | 4 +- src/tint/resolver/call_test.cc | 4 +- src/tint/resolver/const_eval_test.cc | 218 ++++++++--- src/tint/resolver/inferred_type_test.cc | 6 +- src/tint/resolver/materialize_test.cc | 44 +-- src/tint/resolver/resolver_test_helper.h | 367 ++++++++++++++---- .../type_constructor_validation_test.cc | 25 +- 7 files changed, 511 insertions(+), 157 deletions(-) diff --git a/src/tint/resolver/bitcast_validation_test.cc b/src/tint/resolver/bitcast_validation_test.cc index d341808d5b..c6d4cb4913 100644 --- a/src/tint/resolver/bitcast_validation_test.cc +++ b/src/tint/resolver/bitcast_validation_test.cc @@ -25,12 +25,12 @@ struct Type { template static constexpr Type Create() { return Type{builder::DataType::AST, builder::DataType::Sem, - builder::DataType::Expr}; + builder::DataType::ExprFromDouble}; } builder::ast_type_func_ptr ast; builder::sem_type_func_ptr sem; - builder::ast_expr_func_ptr expr; + builder::ast_expr_from_double_func_ptr expr; }; static constexpr Type kNumericScalars[] = { diff --git a/src/tint/resolver/call_test.cc b/src/tint/resolver/call_test.cc index 37aaffa09c..e5b245c1f6 100644 --- a/src/tint/resolver/call_test.cc +++ b/src/tint/resolver/call_test.cc @@ -57,13 +57,13 @@ using alias3 = builder::alias3; using ResolverCallTest = ResolverTest; struct Params { - builder::ast_expr_func_ptr create_value; + builder::ast_expr_from_double_func_ptr create_value; builder::ast_type_func_ptr create_type; }; template constexpr Params ParamsFor() { - return Params{DataType::Expr, DataType::AST}; + return Params{DataType::ExprFromDouble, DataType::AST}; } static constexpr Params all_param_types[] = { diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc index 243b3b9c65..23e621f01c 100644 --- a/src/tint/resolver/const_eval_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -3155,29 +3155,127 @@ TEST_F(ResolverConstEvalTest, UnaryNegateLowestAbstract) { //////////////////////////////////////////////////////////////////////////////////////////////////// // Binary op //////////////////////////////////////////////////////////////////////////////////////////////////// + namespace binary_op { -using Types = std::variant; +using builder::IsValue; +using builder::Mat; +using builder::Val; +using builder::Value; +using builder::Vec; + +using Types = std::variant, + Value, + Value, + Value, + Value, + Value, + + Value>, + Value>, + Value>, + Value>, + Value>, + Value>, + + Value>, + Value>, + Value>, + Value>, + Value>, + Value>, + + Value>, + Value>, + Value>, + Value>, + Value>, + Value>, + + Value>, + Value>, + Value>, + Value>, + + Value>, + Value>, + Value>, + Value>, + + Value>, + Value>, + Value>, + Value> + // + >; struct Case { Types lhs; Types rhs; Types expected; - bool is_overflow; + bool overflow; }; +/// Creates a Case with Values of any type +template +Case C(Value lhs, Value rhs, Value expected, bool overflow = false) { + return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow}; +} + +/// Convenience overload to creates a Case with just scalars +template >> +Case C(T lhs, U rhs, V expected, bool overflow = false) { + return Case{Val(lhs), Val(rhs), Val(expected), overflow}; +} + static std::ostream& operator<<(std::ostream& o, const Case& c) { - std::visit( - [&](auto&& lhs, auto&& rhs, auto&& expected) { - o << "lhs: " << lhs << ", rhs: " << rhs << ", expected: " << expected; - }, - c.lhs, c.rhs, c.expected); + auto print_value = [&](auto&& value) { + std::visit( + [&](auto&& v) { + using ValueType = std::decay_t; + o << ValueType::DataType::Name() << "("; + for (auto& a : v.args.values) { + o << std::get(a); + if (&a != &v.args.values.Back()) { + o << ", "; + } + } + o << ")"; + }, + value); + }; + o << "lhs: "; + print_value(c.lhs); + o << ", rhs: "; + print_value(c.rhs); + o << ", expected: "; + print_value(c.expected); + o << ", overflow: " << c.overflow; return o; } -template -Case C(T lhs, U rhs, V expected, bool is_overflow = false) { - return Case{lhs, rhs, expected, is_overflow}; +// Calls `f` on deepest elements of both `a` and `b`. If function returns false, it stops +// traversing, and return false, otherwise it continues and returns true. +// TODO(amaiorano): Move to Constant.h? +template +bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) { + EXPECT_EQ(a->Type(), b->Type()); + size_t i = 0; + while (true) { + auto* a_elem = a->Index(i); + if (!a_elem) { + break; + } + auto* b_elem = b->Index(i); + if (!ForEachElemPair(a_elem, b_elem, f)) { + return false; + } + i++; + } + if (i == 0) { + return f(a, b); + } + return true; } using ResolverConstEvalBinaryOpTest = ResolverTestWithParam>; @@ -3185,35 +3283,51 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { Enable(ast::Extension::kF16); auto op = std::get<0>(GetParam()); - auto c = std::get<1>(GetParam()); - std::visit( - [&](auto&& lhs, auto&& rhs, auto&& expected) { - using T = std::decay_t; + auto& c = std::get<1>(GetParam()); + std::visit( + [&](auto&& expected) { + using T = typename std::decay_t::ElementType; if constexpr (std::is_same_v || std::is_same_v) { - if (c.is_overflow) { + if (c.overflow) { + // Overflow is not allowed for abstract types. This is tested separately. return; } } - auto* expr = create(op, Expr(lhs), Expr(rhs)); + auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs); + auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs); + auto* expr = create(op, lhs_expr, rhs_expr); + GlobalConst("C", expr); + auto* expected_expr = expected.Expr(*this); + GlobalConst("E", expected_expr); + EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); const sem::Constant* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); - EXPECT_EQ(value->As(), expected); - if constexpr (IsInteger>) { - // Check that the constant's integer doesn't contain unexpected data in the MSBs - // that are outside of the bit-width of T. - EXPECT_EQ(value->As(), AInt(expected)); - } + auto* expected_sem = Sem().Get(expected_expr); + const sem::Constant* expected_value = expected_sem->ConstantValue(); + ASSERT_NE(expected_value, nullptr); + EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); + + ForEachElemPair(value, expected_value, + [&](const sem::Constant* a, const sem::Constant* b) { + EXPECT_EQ(a->As(), b->As()); + if constexpr (IsInteger>) { + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(a->As(), b->As()); + } + return !HasFailure(); + }); }, - c.lhs, c.rhs, c.expected); + c.expected); } INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs, @@ -3325,33 +3439,37 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) { EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'"); } -TEST_F(ResolverConstEvalTest, BinaryAbstractMixed_ScalarScalar) { - auto* a = Const("a", Expr(1_a)); // AInt - auto* b = Const("b", Expr(2.3_a)); // AFloat - auto* c = Add(Expr("a"), Expr("b")); - WrapInFunction(a, b, c); - EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto* sem = Sem().Get(c); - ASSERT_TRUE(sem); - ASSERT_TRUE(sem->ConstantValue()); - auto result = sem->ConstantValue()->As(); - EXPECT_EQ(result, 3.3f); -} - -TEST_F(ResolverConstEvalTest, BinaryAbstractMixed_ScalarVector) { - auto* a = Const("a", Expr(1_a)); // AInt - auto* b = Const("b", Construct(ty.vec(nullptr, 3), Expr(2.3_a))); // AFloat - auto* c = Add(Expr("a"), Expr("b")); - WrapInFunction(a, b, c); - EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto* sem = Sem().Get(c); - ASSERT_TRUE(sem); - ASSERT_TRUE(sem->ConstantValue()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 3.3f); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 3.3f); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 3.3f); -} - +// Mixed AInt and AFloat args to test implicit conversion to AFloat +INSTANTIATE_TEST_SUITE_P( + AbstractMixed, + ResolverConstEvalBinaryOpTest, + testing::Combine( + testing::Values(ast::BinaryOp::kAdd), + testing::Values(C(Val(1_a), Val(2.3_a), Val(3.3_a)), + C(Val(2.3_a), Val(1_a), Val(3.3_a)), + C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)), + C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)), + C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)), + C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)), + C(Mat({1_a, 2_a}, // + {1_a, 2_a}, // + {1_a, 2_a}), // + Mat({1.2_a, 2.3_a}, // + {1.2_a, 2.3_a}, // + {1.2_a, 2.3_a}), // + Mat({2.2_a, 4.3_a}, // + {2.2_a, 4.3_a}, // + {2.2_a, 4.3_a})), // + C(Mat({1.2_a, 2.3_a}, // + {1.2_a, 2.3_a}, // + {1.2_a, 2.3_a}), // + Mat({1_a, 2_a}, // + {1_a, 2_a}, // + {1_a, 2_a}), // + Mat({2.2_a, 4.3_a}, // + {2.2_a, 4.3_a}, // + {2.2_a, 4.3_a})) // + ))); } // namespace binary_op //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc index 8ce611ee8a..83ca0bd752 100644 --- a/src/tint/resolver/inferred_type_test.cc +++ b/src/tint/resolver/inferred_type_test.cc @@ -43,13 +43,15 @@ using alias = builder::alias; struct ResolverInferredTypeTest : public resolver::TestHelper, public testing::Test {}; struct Params { - builder::ast_expr_func_ptr create_value; + // builder::ast_expr_func_ptr_default_arg create_value; + builder::ast_expr_from_double_func_ptr create_value; builder::sem_type_func_ptr create_expected_type; }; template constexpr Params ParamsFor() { - return Params{DataType::Expr, DataType::Sem}; + // return Params{builder::CreateExprWithDefaultArg(), DataType::Sem}; + return Params{DataType::ExprFromDouble, DataType::Sem}; } Params all_cases[] = { diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index 01ff0d5afa..239efd59f0 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -255,9 +255,9 @@ struct Data { std::string target_element_type_name; builder::ast_type_func_ptr target_ast_ty; builder::sem_type_func_ptr target_sem_ty; - builder::ast_expr_func_ptr target_expr; + builder::ast_expr_from_double_func_ptr target_expr; std::string abstract_type_name; - builder::ast_expr_func_ptr abstract_expr; + builder::ast_expr_from_double_func_ptr abstract_expr; std::variant materialized_value; double literal_value; }; @@ -268,13 +268,13 @@ Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) { using AbstractDataType = builder::DataType; using TargetElementDataType = builder::DataType; return { - TargetDataType::Name(), // target_type_name - TargetElementDataType::Name(), // target_element_type_name - TargetDataType::AST, // target_ast_ty - TargetDataType::Sem, // target_sem_ty - TargetDataType::Expr, // target_expr - AbstractDataType::Name(), // abstract_type_name - AbstractDataType::Expr, // abstract_expr + TargetDataType::Name(), // target_type_name + TargetElementDataType::Name(), // target_element_type_name + TargetDataType::AST, // target_ast_ty + TargetDataType::Sem, // target_sem_ty + TargetDataType::ExprFromDouble, // target_expr + AbstractDataType::Name(), // abstract_type_name + AbstractDataType::ExprFromDouble, // abstract_expr materialized_value, literal_value, }; @@ -286,13 +286,13 @@ Data Types() { using AbstractDataType = builder::DataType; using TargetElementDataType = builder::DataType; return { - TargetDataType::Name(), // target_type_name - TargetElementDataType::Name(), // target_element_type_name - TargetDataType::AST, // target_ast_ty - TargetDataType::Sem, // target_sem_ty - TargetDataType::Expr, // target_expr - AbstractDataType::Name(), // abstract_type_name - AbstractDataType::Expr, // abstract_expr + TargetDataType::Name(), // target_type_name + TargetElementDataType::Name(), // target_element_type_name + TargetDataType::AST, // target_ast_ty + TargetDataType::Sem, // target_sem_ty + TargetDataType::ExprFromDouble, // target_expr + AbstractDataType::Name(), // abstract_type_name + AbstractDataType::ExprFromDouble, // abstract_expr 0_a, 0.0, }; @@ -826,7 +826,7 @@ struct Data { std::string expected_element_type_name; builder::sem_type_func_ptr expected_sem_ty; std::string abstract_type_name; - builder::ast_expr_func_ptr abstract_expr; + builder::ast_expr_from_double_func_ptr abstract_expr; std::variant materialized_value; double literal_value; }; @@ -837,11 +837,11 @@ Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) { using AbstractDataType = builder::DataType; using TargetElementDataType = builder::DataType; return { - ExpectedDataType::Name(), // expected_type_name - TargetElementDataType::Name(), // expected_element_type_name - ExpectedDataType::Sem, // expected_sem_ty - AbstractDataType::Name(), // abstract_type_name - AbstractDataType::Expr, // abstract_expr + ExpectedDataType::Name(), // expected_type_name + TargetElementDataType::Name(), // expected_element_type_name + ExpectedDataType::Sem, // expected_sem_ty + AbstractDataType::Name(), // abstract_type_name + AbstractDataType::ExprFromDouble, // abstract_expr materialized_value, literal_value, }; diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index bab1d843aa..66411766b0 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -18,6 +18,9 @@ #include #include #include +#include +#include +#include #include "gtest/gtest.h" #include "src/tint/program_builder.h" @@ -170,8 +173,35 @@ using alias3 = alias; template struct ptr {}; +/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for +/// composite types, or all values requried by the composite type. +struct ScalarArgs { + /// Constructor + /// @param single_value single value to initialize with + template + ScalarArgs(T single_value) // NOLINT: implicit on purpose + : values(utils::Vector{single_value}) {} + + /// Constructor + /// @param all_values all values to initialize the composite type with + template + ScalarArgs(utils::VectorRef all_values) // NOLINT: implicit on purpose + { + for (auto& v : all_values) { + values.Push(v); + } + } + + /// Valid scalar types for args + using Storage = std::variant; + + /// The vector of values + utils::Vector values; +}; + using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); -using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double elem_value); +using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args); +using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v); using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b); template @@ -202,10 +232,16 @@ struct DataType { /// @return the semantic bool type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the b + /// @param args args of size 1 with the boolean value to init with /// @return a new AST expression of the bool type - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(std::equal_to()(elem_value, 0)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to bool. + /// @return a new AST expression of the bool type + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "bool"; } @@ -227,10 +263,16 @@ struct DataType { /// @return the semantic i32 type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the value i32 will be initialized with + /// @param args args of size 1 with the i32 value to init with /// @return a new AST i32 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(static_cast(elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to i32. + /// @return a new AST i32 literal value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "i32"; } @@ -252,10 +294,16 @@ struct DataType { /// @return the semantic u32 type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the value u32 will be initialized with + /// @param args args of size 1 with the u32 value to init with /// @return a new AST u32 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(static_cast(elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to u32. + /// @return a new AST u32 literal value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "u32"; } @@ -277,10 +325,16 @@ struct DataType { /// @return the semantic f32 type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the value f32 will be initialized with + /// @param args args of size 1 with the f32 value to init with /// @return a new AST f32 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(static_cast(elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to f32. + /// @return a new AST f32 literal value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "f32"; } @@ -302,10 +356,16 @@ struct DataType { /// @return the semantic f16 type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the value f16 will be initialized with + /// @param args args of size 1 with the f16 value to init with /// @return a new AST f16 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(static_cast(elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to f16. + /// @return a new AST f16 literal value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "f16"; } @@ -326,10 +386,16 @@ struct DataType { /// @return the semantic abstract-float type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the value the abstract-float literal will be constructed with + /// @param args args of size 1 with the abstract-float value to init with /// @return a new AST abstract-float literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(AFloat(elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to AFloat. + /// @return a new AST abstract-float literal value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "abstract-float"; } @@ -350,10 +416,16 @@ struct DataType { /// @return the semantic abstract-int type static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create(); } /// @param b the ProgramBuilder - /// @param elem_value the value the abstract-int literal will be constructed with + /// @param args args of size 1 with the abstract-int value to init with /// @return a new AST abstract-int literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Expr(AInt(elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Expr(std::get(args.values[0])); + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to AInt. + /// @return a new AST abstract-int literal value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { return "abstract-int"; } @@ -379,22 +451,27 @@ struct DataType> { return b.create(DataType::Sem(b), N); } /// @param b the ProgramBuilder - /// @param elem_value the value each element in the vector will be initialized - /// with + /// @param args args of size 1 or N with values of type T to initialize with /// @return a new AST vector value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Construct(AST(b), ExprArgs(b, elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Construct(AST(b), ExprArgs(b, std::move(args))); } - /// @param b the ProgramBuilder - /// @param elem_value the value each element will be initialized with + /// @param args args of size 1 or N with values of type T to initialize with /// @return the list of expressions that are used to construct the vector - static inline auto ExprArgs(ProgramBuilder& b, double elem_value) { - utils::Vector args; - for (uint32_t i = 0; i < N; i++) { - args.Push(DataType::Expr(b, elem_value)); + static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) { + const bool one_value = args.values.Length() == 1; + utils::Vector r; + for (size_t i = 0; i < N; ++i) { + r.Push(DataType::Expr(b, one_value ? args.values[0] : args.values[i])); } - return args; + return r; + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to ElementType + /// @return a new AST vector value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -423,22 +500,36 @@ struct DataType> { return b.create(column_type, N); } /// @param b the ProgramBuilder - /// @param elem_value the value each element in the matrix will be initialized - /// with + /// @param args args of size 1 or N*M with values of type T to initialize with /// @return a new AST matrix value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Construct(AST(b), ExprArgs(b, elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Construct(AST(b), ExprArgs(b, std::move(args))); } - /// @param b the ProgramBuilder - /// @param elem_value the value each element will be initialized with - /// @return the list of expressions that are used to construct the matrix - static inline auto ExprArgs(ProgramBuilder& b, double elem_value) { - utils::Vector args; - for (uint32_t i = 0; i < N; i++) { - args.Push(DataType>::Expr(b, elem_value)); + /// @param args args of size 1 or N*M with values of type T to initialize with + /// @return a new AST matrix value expression + static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) { + const bool one_value = args.values.Length() == 1; + size_t next = 0; + utils::Vector r; + for (uint32_t i = 0; i < N; ++i) { + if (one_value) { + r.Push(DataType>::Expr(b, args.values[0])); + } else { + utils::Vector v; + for (size_t j = 0; j < M; ++j) { + v.Push(std::get(args.values[next++])); + } + r.Push(DataType>::Expr(b, utils::VectorRef{v})); + } } - return args; + return r; + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to ElementType + /// @return a new AST matrix value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -451,7 +542,7 @@ struct DataType> { template struct DataType> { /// The element type - using ElementType = T; + using ElementType = typename DataType::ElementType; /// true if the aliased type is a composite type static constexpr bool is_composite = DataType::is_composite; @@ -471,24 +562,32 @@ struct DataType> { static inline const sem::Type* Sem(ProgramBuilder& b) { return DataType::Sem(b); } /// @param b the ProgramBuilder - /// @param elem_value the value nested elements will be initialized with + /// @param args the value nested elements will be initialized with /// @return a new AST expression of the alias type template static inline traits::EnableIf Expr(ProgramBuilder& b, - double elem_value) { + ScalarArgs args) { // Cast - return b.Construct(AST(b), DataType::Expr(b, elem_value)); + return b.Construct(AST(b), DataType::Expr(b, std::move(args))); } /// @param b the ProgramBuilder - /// @param elem_value the value nested elements will be initialized with + /// @param args the value nested elements will be initialized with /// @return a new AST expression of the alias type template static inline traits::EnableIf Expr(ProgramBuilder& b, - double elem_value) { + ScalarArgs args) { // Construct - return b.Construct(AST(b), DataType::ExprArgs(b, elem_value)); + return b.Construct(AST(b), DataType::ExprArgs(b, std::move(args))); } + + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to ElementType + /// @return a new AST expression of the alias type + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); + } + /// @returns the WGSL name for the type static inline std::string Name() { return "alias_" + std::to_string(ID); } }; @@ -497,7 +596,7 @@ struct DataType> { template struct DataType> { /// The element type - using ElementType = T; + using ElementType = typename DataType::ElementType; /// true if the pointer type is a composite type static constexpr bool is_composite = false; @@ -516,12 +615,20 @@ struct DataType> { } /// @param b the ProgramBuilder - /// @return a new AST expression of the alias type - static inline const ast::Expression* Expr(ProgramBuilder& b, double /*unused*/) { + /// @return a new AST expression of the pointer type + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) { auto sym = b.Symbols().New("global_for_ptr"); b.GlobalVar(sym, DataType::AST(b), ast::StorageClass::kPrivate); return b.AddressOf(sym); } + + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to ElementType + /// @return a new AST expression of the pointer type + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); + } + /// @returns the WGSL name for the type static inline std::string Name() { return "ptr<" + DataType::Name() + ">"; } }; @@ -530,7 +637,7 @@ struct DataType> { template struct DataType> { /// The element type - using ElementType = T; + using ElementType = typename DataType::ElementType; /// true as arrays are a composite type static constexpr bool is_composite = true; @@ -556,22 +663,28 @@ struct DataType> { /* implicit_stride */ el->Align()); } /// @param b the ProgramBuilder - /// @param elem_value the value each element in the array will be initialized + /// @param args args of size 1 or N with values of type T to initialize with /// with /// @return a new AST array value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) { - return b.Construct(AST(b), ExprArgs(b, elem_value)); + static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + return b.Construct(AST(b), ExprArgs(b, std::move(args))); } - /// @param b the ProgramBuilder - /// @param elem_value the value each element will be initialized with + /// @param args args of size 1 or N with values of type T to initialize with /// @return the list of expressions that are used to construct the array - static inline auto ExprArgs(ProgramBuilder& b, double elem_value) { - utils::Vector args; + static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) { + const bool one_value = args.values.Length() == 1; + utils::Vector r; for (uint32_t i = 0; i < N; i++) { - args.Push(DataType::Expr(b, elem_value)); + r.Push(DataType::Expr(b, one_value ? args.values[0] : args.values[i])); } - return args; + return r; + } + /// @param b the ProgramBuilder + /// @param v arg of type double that will be cast to ElementType + /// @return a new AST array value expression + static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { + return Expr(b, static_cast(v)); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -585,6 +698,8 @@ struct CreatePtrs { ast_type_func_ptr ast; /// ast expression type create function ast_expr_func_ptr expr; + /// ast expression type create function from double arg + ast_expr_from_double_func_ptr expr_from_double; /// sem type create function sem_type_func_ptr sem; }; @@ -593,11 +708,129 @@ struct CreatePtrs { /// type `T` template constexpr CreatePtrs CreatePtrsFor() { - return {DataType::AST, DataType::Expr, DataType::Sem}; + return {DataType::AST, DataType::Expr, DataType::ExprFromDouble, DataType::Sem}; +} + +/// Value is an instance of a value of type DataType. Useful for storing values to create +/// expressions with. +template +struct Value { + /// Alias to T + using Type = T; + /// Alias to DataType + using DataType = builder::DataType; + /// Alias to DataType::ElementType + using ElementType = typename DataType::ElementType; + + /// Creates a Value with `args` + /// @param args the args that will be passed to the expression + /// @returns a Value + static Value Create(ScalarArgs args) { return Value{DataType::Expr, std::move(args)}; } + + /// Creates an `ast::Expression` for the type T passing in previously stored args + /// @param b the ProgramBuilder + /// @returns an expression node + const ast::Expression* Expr(ProgramBuilder& b) const { return (*expr)(b, args); } + + /// ast expression type create function + ast_expr_func_ptr expr; + /// args to create expression with + ScalarArgs args; +}; + +namespace detail { +/// Base template for IsValue +template +struct IsValue : std::false_type {}; +/// Specialization for IsValue +template +struct IsValue> : std::true_type {}; +} // namespace detail + +/// True if T is of type Value +template +constexpr bool IsValue = detail::IsValue::value; + +/// Creates a `Value` from a scalar `v` +template +auto Val(T v) { + return Value::Create(v); +} + +/// Creates a `Value>` from N scalar `args` +template +auto Vec(T&&... args) { + constexpr size_t N = sizeof...(args); + using FirstT = std::tuple_element_t<0, std::tuple>; + utils::Vector v{args...}; + using VT = vec; + return Value::Create(utils::VectorRef{v}); +} + +/// Creates a `Value` from C*R scalar `args` +template +auto Mat(const T (&m_in)[C][R]) { + utils::Vector m; + for (uint32_t i = 0; i < C; ++i) { + for (size_t j = 0; j < R; ++j) { + m.Push(m_in[i][j]); + } + } + return Value>::Create(utils::VectorRef{m}); +} + +/// Creates a `Value` from column vectors `c0` and `c1` +template +auto Mat(const T (&c0)[R], const T (&c1)[R]) { + constexpr size_t C = 2; + utils::Vector m; + for (auto v : c0) { + m.Push(v); + } + for (auto v : c1) { + m.Push(v); + } + return Value>::Create(utils::VectorRef{m}); +} + +/// Creates a `Value` from column vectors `c0`, `c1`, and `c2` +template +auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) { + constexpr size_t C = 3; + utils::Vector m; + for (auto v : c0) { + m.Push(v); + } + for (auto v : c1) { + m.Push(v); + } + for (auto v : c2) { + m.Push(v); + } + return Value>::Create(utils::VectorRef{m}); +} + +/// Creates a `Value` from column vectors `c0`, `c1`, `c2`, and `c3` +template +auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) { + constexpr size_t C = 4; + utils::Vector m; + for (auto v : c0) { + m.Push(v); + } + for (auto v : c1) { + m.Push(v); + } + for (auto v : c2) { + m.Push(v); + } + for (auto v : c3) { + m.Push(v); + } + return Value>::Create(utils::VectorRef{m}); } } // namespace builder - } // namespace tint::resolver #endif // SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_ diff --git a/src/tint/resolver/type_constructor_validation_test.cc b/src/tint/resolver/type_constructor_validation_test.cc index 6d771f2ebd..de64e9a7b2 100644 --- a/src/tint/resolver/type_constructor_validation_test.cc +++ b/src/tint/resolver/type_constructor_validation_test.cc @@ -47,13 +47,13 @@ class ResolverTypeConstructorValidationTest : public resolver::TestHelper, publi namespace InferTypeTest { struct Params { builder::ast_type_func_ptr create_rhs_ast_type; - builder::ast_expr_func_ptr create_rhs_ast_value; + builder::ast_expr_from_double_func_ptr create_rhs_ast_value; builder::sem_type_func_ptr create_rhs_sem_type; }; template constexpr Params ParamsFor() { - return Params{DataType::AST, DataType::Expr, DataType::Sem}; + return Params{DataType::AST, DataType::ExprFromDouble, DataType::Sem}; } TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) { @@ -242,12 +242,13 @@ struct Params { Kind kind; builder::ast_type_func_ptr lhs_type; builder::ast_type_func_ptr rhs_type; - builder::ast_expr_func_ptr rhs_value_expr; + builder::ast_expr_from_double_func_ptr rhs_value_expr; }; template constexpr Params ParamsFor(Kind kind) { - return Params{kind, DataType::AST, DataType::AST, DataType::Expr}; + return Params{kind, DataType::AST, DataType::AST, + DataType::ExprFromDouble}; } static constexpr Params valid_cases[] = { @@ -426,7 +427,7 @@ TEST_P(ConversionConstructorInvalidTest, All) { // Skip test for valid cases for (auto& v : valid_cases) { if (v.lhs_type == lhs_params.ast && v.rhs_type == rhs_params.ast && - v.rhs_value_expr == rhs_params.expr) { + v.rhs_value_expr == rhs_params.expr_from_double) { return; } } @@ -439,7 +440,7 @@ TEST_P(ConversionConstructorInvalidTest, All) { auto* lhs_type1 = lhs_params.ast(*this); auto* lhs_type2 = lhs_params.ast(*this); auto* rhs_type = rhs_params.ast(*this); - auto* rhs_value_expr = rhs_params.expr(*this, 0); + auto* rhs_value_expr = rhs_params.expr_from_double(*this, 0); std::stringstream ss; ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "(" @@ -2437,7 +2438,7 @@ struct MatrixParams { uint32_t columns; name_func_ptr get_element_type_name; builder::ast_type_func_ptr create_element_ast_type; - builder::ast_expr_func_ptr create_element_ast_value; + builder::ast_expr_from_double_func_ptr create_element_ast_value; builder::ast_type_func_ptr create_column_ast_type; builder::ast_type_func_ptr create_mat_ast_type; }; @@ -2449,7 +2450,7 @@ constexpr MatrixParams MatrixParamsFor() { C, DataType::Name, DataType::AST, - DataType::Expr, + DataType::ExprFromDouble, DataType>::AST, DataType>::AST, }; @@ -3058,7 +3059,7 @@ TEST_P(StructConstructorInputsTest, TooFew) { auto* struct_type = str_params.ast(*this); members.Push(Member("member_" + std::to_string(i), struct_type)); if (i < N - 1) { - auto* ctor_value_expr = str_params.expr(*this, 0); + auto* ctor_value_expr = str_params.expr_from_double(*this, 0); values.Push(ctor_value_expr); } } @@ -3084,7 +3085,7 @@ TEST_P(StructConstructorInputsTest, TooMany) { auto* struct_type = str_params.ast(*this); members.Push(Member("member_" + std::to_string(i), struct_type)); } - auto* ctor_value_expr = str_params.expr(*this, 0); + auto* ctor_value_expr = str_params.expr_from_double(*this, 0); values.Push(ctor_value_expr); } auto* s = Structure("s", members); @@ -3122,8 +3123,8 @@ TEST_P(StructConstructorTypeTest, AllTypes) { auto* struct_type = str_params.ast(*this); members.Push(Member("member_" + std::to_string(i), struct_type)); auto* ctor_value_expr = (i == constructor_value_with_different_type) - ? ctor_params.expr(*this, 0) - : str_params.expr(*this, 0); + ? ctor_params.expr_from_double(*this, 0) + : str_params.expr_from_double(*this, 0); values.Push(ctor_value_expr); } auto* s = Structure("s", members);