tint: Improve resolver test helper to specify more than one expression arg

Bug: tint:1581
Change-Id: Ie77c56c15a5965b20008036cd99a81cbabd86988
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/100340
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-08-31 22:59:08 +00:00 committed by Dawn LUCI CQ
parent 329ddd813d
commit b6d524380e
7 changed files with 511 additions and 157 deletions

View File

@ -25,12 +25,12 @@ struct Type {
template <typename T>
static constexpr Type Create() {
return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
builder::DataType<T>::Expr};
builder::DataType<T>::ExprFromDouble};
}
builder::ast_type_func_ptr ast;
builder::sem_type_func_ptr sem;
builder::ast_expr_func_ptr expr;
builder::ast_expr_from_double_func_ptr expr;
};
static constexpr Type kNumericScalars[] = {

View File

@ -57,13 +57,13 @@ using alias3 = builder::alias3<T>;
using ResolverCallTest = ResolverTest;
struct Params {
builder::ast_expr_func_ptr create_value;
builder::ast_expr_from_double_func_ptr create_value;
builder::ast_type_func_ptr create_type;
};
template <typename T>
constexpr Params ParamsFor() {
return Params{DataType<T>::Expr, DataType<T>::AST};
return Params{DataType<T>::ExprFromDouble, DataType<T>::AST};
}
static constexpr Params all_param_types[] = {

View File

@ -3155,29 +3155,127 @@ TEST_F(ResolverConstEvalTest, UnaryNegateLowestAbstract) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary op
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace binary_op {
using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>;
using builder::IsValue;
using builder::Mat;
using builder::Val;
using builder::Value;
using builder::Vec;
using Types = std::variant<Value<AInt>,
Value<AFloat>,
Value<u32>,
Value<i32>,
Value<f32>,
Value<f16>,
Value<builder::vec2<AInt>>,
Value<builder::vec2<AFloat>>,
Value<builder::vec2<u32>>,
Value<builder::vec2<i32>>,
Value<builder::vec2<f32>>,
Value<builder::vec2<f16>>,
Value<builder::vec3<AInt>>,
Value<builder::vec3<AFloat>>,
Value<builder::vec3<u32>>,
Value<builder::vec3<i32>>,
Value<builder::vec3<f32>>,
Value<builder::vec3<f16>>,
Value<builder::vec4<AInt>>,
Value<builder::vec4<AFloat>>,
Value<builder::vec4<u32>>,
Value<builder::vec4<i32>>,
Value<builder::vec4<f32>>,
Value<builder::vec4<f16>>,
Value<builder::mat2x2<AInt>>,
Value<builder::mat2x2<AFloat>>,
Value<builder::mat2x2<f32>>,
Value<builder::mat2x2<f16>>,
Value<builder::mat2x3<AInt>>,
Value<builder::mat2x3<AFloat>>,
Value<builder::mat2x3<f32>>,
Value<builder::mat2x3<f16>>,
Value<builder::mat3x2<AInt>>,
Value<builder::mat3x2<AFloat>>,
Value<builder::mat3x2<f32>>,
Value<builder::mat3x2<f16>>
//
>;
struct Case {
Types lhs;
Types rhs;
Types expected;
bool is_overflow;
bool overflow;
};
/// Creates a Case with Values of any type
template <typename T, typename U, typename V>
Case C(Value<T> lhs, Value<U> rhs, Value<V> expected, bool overflow = false) {
return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow};
}
/// Convenience overload to creates a Case with just scalars
template <typename T, typename U, typename V, typename = std::enable_if_t<!IsValue<T>>>
Case C(T lhs, U rhs, V expected, bool overflow = false) {
return Case{Val(lhs), Val(rhs), Val(expected), overflow};
}
static std::ostream& operator<<(std::ostream& o, const Case& c) {
std::visit(
[&](auto&& lhs, auto&& rhs, auto&& expected) {
o << "lhs: " << lhs << ", rhs: " << rhs << ", expected: " << expected;
},
c.lhs, c.rhs, c.expected);
auto print_value = [&](auto&& value) {
std::visit(
[&](auto&& v) {
using ValueType = std::decay_t<decltype(v)>;
o << ValueType::DataType::Name() << "(";
for (auto& a : v.args.values) {
o << std::get<typename ValueType::ElementType>(a);
if (&a != &v.args.values.Back()) {
o << ", ";
}
}
o << ")";
},
value);
};
o << "lhs: ";
print_value(c.lhs);
o << ", rhs: ";
print_value(c.rhs);
o << ", expected: ";
print_value(c.expected);
o << ", overflow: " << c.overflow;
return o;
}
template <typename T, typename U, typename V>
Case C(T lhs, U rhs, V expected, bool is_overflow = false) {
return Case{lhs, rhs, expected, is_overflow};
// Calls `f` on deepest elements of both `a` and `b`. If function returns false, it stops
// traversing, and return false, otherwise it continues and returns true.
// TODO(amaiorano): Move to Constant.h?
template <typename Func>
bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
EXPECT_EQ(a->Type(), b->Type());
size_t i = 0;
while (true) {
auto* a_elem = a->Index(i);
if (!a_elem) {
break;
}
auto* b_elem = b->Index(i);
if (!ForEachElemPair(a_elem, b_elem, f)) {
return false;
}
i++;
}
if (i == 0) {
return f(a, b);
}
return true;
}
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
@ -3185,35 +3283,51 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
Enable(ast::Extension::kF16);
auto op = std::get<0>(GetParam());
auto c = std::get<1>(GetParam());
std::visit(
[&](auto&& lhs, auto&& rhs, auto&& expected) {
using T = std::decay_t<decltype(expected)>;
auto& c = std::get<1>(GetParam());
std::visit(
[&](auto&& expected) {
using T = typename std::decay_t<decltype(expected)>::ElementType;
if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
if (c.is_overflow) {
if (c.overflow) {
// Overflow is not allowed for abstract types. This is tested separately.
return;
}
}
auto* expr = create<ast::BinaryExpression>(op, Expr(lhs), Expr(rhs));
auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
GlobalConst("C", expr);
auto* expected_expr = expected.Expr(*this);
GlobalConst("E", expected_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
const sem::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
EXPECT_EQ(value->As<T>(), expected);
if constexpr (IsInteger<UnwrapNumber<T>>) {
// Check that the constant's integer doesn't contain unexpected data in the MSBs
// that are outside of the bit-width of T.
EXPECT_EQ(value->As<AInt>(), AInt(expected));
}
auto* expected_sem = Sem().Get(expected_expr);
const sem::Constant* expected_value = expected_sem->ConstantValue();
ASSERT_NE(expected_value, nullptr);
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
ForEachElemPair(value, expected_value,
[&](const sem::Constant* a, const sem::Constant* b) {
EXPECT_EQ(a->As<T>(), b->As<T>());
if constexpr (IsInteger<UnwrapNumber<T>>) {
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
}
return !HasFailure();
});
},
c.lhs, c.rhs, c.expected);
c.expected);
}
INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
@ -3325,33 +3439,37 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) {
EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'");
}
TEST_F(ResolverConstEvalTest, BinaryAbstractMixed_ScalarScalar) {
auto* a = Const("a", Expr(1_a)); // AInt
auto* b = Const("b", Expr(2.3_a)); // AFloat
auto* c = Add(Expr("a"), Expr("b"));
WrapInFunction(a, b, c);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(c);
ASSERT_TRUE(sem);
ASSERT_TRUE(sem->ConstantValue());
auto result = sem->ConstantValue()->As<AFloat>();
EXPECT_EQ(result, 3.3f);
}
TEST_F(ResolverConstEvalTest, BinaryAbstractMixed_ScalarVector) {
auto* a = Const("a", Expr(1_a)); // AInt
auto* b = Const("b", Construct(ty.vec(nullptr, 3), Expr(2.3_a))); // AFloat
auto* c = Add(Expr("a"), Expr("b"));
WrapInFunction(a, b, c);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(c);
ASSERT_TRUE(sem);
ASSERT_TRUE(sem->ConstantValue());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 3.3f);
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), 3.3f);
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 3.3f);
}
// Mixed AInt and AFloat args to test implicit conversion to AFloat
INSTANTIATE_TEST_SUITE_P(
AbstractMixed,
ResolverConstEvalBinaryOpTest,
testing::Combine(
testing::Values(ast::BinaryOp::kAdd),
testing::Values(C(Val(1_a), Val(2.3_a), Val(3.3_a)),
C(Val(2.3_a), Val(1_a), Val(3.3_a)),
C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)),
C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)),
C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)),
C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)),
C(Mat({1_a, 2_a}, //
{1_a, 2_a}, //
{1_a, 2_a}), //
Mat({1.2_a, 2.3_a}, //
{1.2_a, 2.3_a}, //
{1.2_a, 2.3_a}), //
Mat({2.2_a, 4.3_a}, //
{2.2_a, 4.3_a}, //
{2.2_a, 4.3_a})), //
C(Mat({1.2_a, 2.3_a}, //
{1.2_a, 2.3_a}, //
{1.2_a, 2.3_a}), //
Mat({1_a, 2_a}, //
{1_a, 2_a}, //
{1_a, 2_a}), //
Mat({2.2_a, 4.3_a}, //
{2.2_a, 4.3_a}, //
{2.2_a, 4.3_a})) //
)));
} // namespace binary_op
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -43,13 +43,15 @@ using alias = builder::alias<T>;
struct ResolverInferredTypeTest : public resolver::TestHelper, public testing::Test {};
struct Params {
builder::ast_expr_func_ptr create_value;
// builder::ast_expr_func_ptr_default_arg create_value;
builder::ast_expr_from_double_func_ptr create_value;
builder::sem_type_func_ptr create_expected_type;
};
template <typename T>
constexpr Params ParamsFor() {
return Params{DataType<T>::Expr, DataType<T>::Sem};
// return Params{builder::CreateExprWithDefaultArg<T>(), DataType<T>::Sem};
return Params{DataType<T>::ExprFromDouble, DataType<T>::Sem};
}
Params all_cases[] = {

View File

@ -255,9 +255,9 @@ struct Data {
std::string target_element_type_name;
builder::ast_type_func_ptr target_ast_ty;
builder::sem_type_func_ptr target_sem_ty;
builder::ast_expr_func_ptr target_expr;
builder::ast_expr_from_double_func_ptr target_expr;
std::string abstract_type_name;
builder::ast_expr_func_ptr abstract_expr;
builder::ast_expr_from_double_func_ptr abstract_expr;
std::variant<AInt, AFloat> materialized_value;
double literal_value;
};
@ -268,13 +268,13 @@ Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) {
using AbstractDataType = builder::DataType<ABSTRACT_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::Expr, // target_expr
AbstractDataType::Name(), // abstract_type_name
AbstractDataType::Expr, // abstract_expr
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::ExprFromDouble, // target_expr
AbstractDataType::Name(), // abstract_type_name
AbstractDataType::ExprFromDouble, // abstract_expr
materialized_value,
literal_value,
};
@ -286,13 +286,13 @@ Data Types() {
using AbstractDataType = builder::DataType<ABSTRACT_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::Expr, // target_expr
AbstractDataType::Name(), // abstract_type_name
AbstractDataType::Expr, // abstract_expr
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::ExprFromDouble, // target_expr
AbstractDataType::Name(), // abstract_type_name
AbstractDataType::ExprFromDouble, // abstract_expr
0_a,
0.0,
};
@ -826,7 +826,7 @@ struct Data {
std::string expected_element_type_name;
builder::sem_type_func_ptr expected_sem_ty;
std::string abstract_type_name;
builder::ast_expr_func_ptr abstract_expr;
builder::ast_expr_from_double_func_ptr abstract_expr;
std::variant<AInt, AFloat> materialized_value;
double literal_value;
};
@ -837,11 +837,11 @@ Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) {
using AbstractDataType = builder::DataType<ABSTRACT_TYPE>;
using TargetElementDataType = builder::DataType<typename ExpectedDataType::ElementType>;
return {
ExpectedDataType::Name(), // expected_type_name
TargetElementDataType::Name(), // expected_element_type_name
ExpectedDataType::Sem, // expected_sem_ty
AbstractDataType::Name(), // abstract_type_name
AbstractDataType::Expr, // abstract_expr
ExpectedDataType::Name(), // expected_type_name
TargetElementDataType::Name(), // expected_element_type_name
ExpectedDataType::Sem, // expected_sem_ty
AbstractDataType::Name(), // abstract_type_name
AbstractDataType::ExprFromDouble, // abstract_expr
materialized_value,
literal_value,
};

View File

@ -18,6 +18,9 @@
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include "gtest/gtest.h"
#include "src/tint/program_builder.h"
@ -170,8 +173,35 @@ using alias3 = alias<TO, 3>;
template <typename TO>
struct ptr {};
/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
/// composite types, or all values requried by the composite type.
struct ScalarArgs {
/// Constructor
/// @param single_value single value to initialize with
template <typename T>
ScalarArgs(T single_value) // NOLINT: implicit on purpose
: values(utils::Vector<Storage, 1>{single_value}) {}
/// Constructor
/// @param all_values all values to initialize the composite type with
template <typename T>
ScalarArgs(utils::VectorRef<T> all_values) // NOLINT: implicit on purpose
{
for (auto& v : all_values) {
values.Push(v);
}
}
/// Valid scalar types for args
using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
/// The vector of values
utils::Vector<Storage, 16> values;
};
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double elem_value);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
template <typename T>
@ -202,10 +232,16 @@ struct DataType<bool> {
/// @return the semantic bool type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::Bool>(); }
/// @param b the ProgramBuilder
/// @param elem_value the b
/// @param args args of size 1 with the boolean value to init with
/// @return a new AST expression of the bool type
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(std::equal_to<double>()(elem_value, 0));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<bool>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to bool.
/// @return a new AST expression of the bool type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; }
@ -227,10 +263,16 @@ struct DataType<i32> {
/// @return the semantic i32 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::I32>(); }
/// @param b the ProgramBuilder
/// @param elem_value the value i32 will be initialized with
/// @param args args of size 1 with the i32 value to init with
/// @return a new AST i32 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<i32>(elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<i32>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to i32.
/// @return a new AST i32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "i32"; }
@ -252,10 +294,16 @@ struct DataType<u32> {
/// @return the semantic u32 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::U32>(); }
/// @param b the ProgramBuilder
/// @param elem_value the value u32 will be initialized with
/// @param args args of size 1 with the u32 value to init with
/// @return a new AST u32 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<u32>(elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<u32>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to u32.
/// @return a new AST u32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "u32"; }
@ -277,10 +325,16 @@ struct DataType<f32> {
/// @return the semantic f32 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::F32>(); }
/// @param b the ProgramBuilder
/// @param elem_value the value f32 will be initialized with
/// @param args args of size 1 with the f32 value to init with
/// @return a new AST f32 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<f32>(elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<f32>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to f32.
/// @return a new AST f32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<f32>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f32"; }
@ -302,10 +356,16 @@ struct DataType<f16> {
/// @return the semantic f16 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::F16>(); }
/// @param b the ProgramBuilder
/// @param elem_value the value f16 will be initialized with
/// @param args args of size 1 with the f16 value to init with
/// @return a new AST f16 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(static_cast<f16>(elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<f16>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to f16.
/// @return a new AST f16 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f16"; }
@ -326,10 +386,16 @@ struct DataType<AFloat> {
/// @return the semantic abstract-float type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::AbstractFloat>(); }
/// @param b the ProgramBuilder
/// @param elem_value the value the abstract-float literal will be constructed with
/// @param args args of size 1 with the abstract-float value to init with
/// @return a new AST abstract-float literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(AFloat(elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<AFloat>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to AFloat.
/// @return a new AST abstract-float literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-float"; }
@ -350,10 +416,16 @@ struct DataType<AInt> {
/// @return the semantic abstract-int type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::AbstractInt>(); }
/// @param b the ProgramBuilder
/// @param elem_value the value the abstract-int literal will be constructed with
/// @param args args of size 1 with the abstract-int value to init with
/// @return a new AST abstract-int literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Expr(AInt(elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<AInt>(args.values[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to AInt.
/// @return a new AST abstract-int literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-int"; }
@ -379,22 +451,27 @@ struct DataType<vec<N, T>> {
return b.create<sem::Vector>(DataType<T>::Sem(b), N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the vector will be initialized
/// with
/// @param args args of size 1 or N with values of type T to initialize with
/// @return a new AST vector value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the vector
static inline auto ExprArgs(ProgramBuilder& b, double elem_value) {
utils::Vector<const ast::Expression*, N> args;
for (uint32_t i = 0; i < N; i++) {
args.Push(DataType<T>::Expr(b, elem_value));
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
const bool one_value = args.values.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (size_t i = 0; i < N; ++i) {
r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i]));
}
return args;
return r;
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST vector value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@ -423,22 +500,36 @@ struct DataType<mat<N, M, T>> {
return b.create<sem::Matrix>(column_type, N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the matrix will be initialized
/// with
/// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @return the list of expressions that are used to construct the matrix
static inline auto ExprArgs(ProgramBuilder& b, double elem_value) {
utils::Vector<const ast::Expression*, N> args;
for (uint32_t i = 0; i < N; i++) {
args.Push(DataType<vec<M, T>>::Expr(b, elem_value));
/// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
const bool one_value = args.values.Length() == 1;
size_t next = 0;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; ++i) {
if (one_value) {
r.Push(DataType<vec<M, T>>::Expr(b, args.values[0]));
} else {
utils::Vector<T, M> v;
for (size_t j = 0; j < M; ++j) {
v.Push(std::get<T>(args.values[next++]));
}
r.Push(DataType<vec<M, T>>::Expr(b, utils::VectorRef<T>{v}));
}
}
return args;
return r;
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST matrix value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@ -451,7 +542,7 @@ struct DataType<mat<N, M, T>> {
template <typename T, int ID>
struct DataType<alias<T, ID>> {
/// The element type
using ElementType = T;
using ElementType = typename DataType<T>::ElementType;
/// true if the aliased type is a composite type
static constexpr bool is_composite = DataType<T>::is_composite;
@ -471,24 +562,32 @@ struct DataType<alias<T, ID>> {
static inline const sem::Type* Sem(ProgramBuilder& b) { return DataType<T>::Sem(b); }
/// @param b the ProgramBuilder
/// @param elem_value the value nested elements will be initialized with
/// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
double elem_value) {
ScalarArgs args) {
// Cast
return b.Construct(AST(b), DataType<T>::Expr(b, elem_value));
return b.Construct(AST(b), DataType<T>::Expr(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param elem_value the value nested elements will be initialized with
/// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
double elem_value) {
ScalarArgs args) {
// Construct
return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value));
return b.Construct(AST(b), DataType<T>::ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the alias type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "alias_" + std::to_string(ID); }
};
@ -497,7 +596,7 @@ struct DataType<alias<T, ID>> {
template <typename T>
struct DataType<ptr<T>> {
/// The element type
using ElementType = T;
using ElementType = typename DataType<T>::ElementType;
/// true if the pointer type is a composite type
static constexpr bool is_composite = false;
@ -516,12 +615,20 @@ struct DataType<ptr<T>> {
}
/// @param b the ProgramBuilder
/// @return a new AST expression of the alias type
static inline const ast::Expression* Expr(ProgramBuilder& b, double /*unused*/) {
/// @return a new AST expression of the pointer type
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) {
auto sym = b.Symbols().New("global_for_ptr");
b.GlobalVar(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
return b.AddressOf(sym);
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the pointer type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "ptr<" + DataType<T>::Name() + ">"; }
};
@ -530,7 +637,7 @@ struct DataType<ptr<T>> {
template <uint32_t N, typename T>
struct DataType<array<N, T>> {
/// The element type
using ElementType = T;
using ElementType = typename DataType<T>::ElementType;
/// true as arrays are a composite type
static constexpr bool is_composite = true;
@ -556,22 +663,28 @@ struct DataType<array<N, T>> {
/* implicit_stride */ el->Align());
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the array will be initialized
/// @param args args of size 1 or N with values of type T to initialize with
/// with
/// @return a new AST array value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the array
static inline auto ExprArgs(ProgramBuilder& b, double elem_value) {
utils::Vector<const ast::Expression*, N> args;
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
const bool one_value = args.values.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; i++) {
args.Push(DataType<T>::Expr(b, elem_value));
r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i]));
}
return args;
return r;
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST array value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@ -585,6 +698,8 @@ struct CreatePtrs {
ast_type_func_ptr ast;
/// ast expression type create function
ast_expr_func_ptr expr;
/// ast expression type create function from double arg
ast_expr_from_double_func_ptr expr_from_double;
/// sem type create function
sem_type_func_ptr sem;
};
@ -593,11 +708,129 @@ struct CreatePtrs {
/// type `T`
template <typename T>
constexpr CreatePtrs CreatePtrsFor() {
return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::ExprFromDouble, DataType<T>::Sem};
}
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
/// expressions with.
template <typename T>
struct Value {
/// Alias to T
using Type = T;
/// Alias to DataType<T>
using DataType = builder::DataType<T>;
/// Alias to DataType::ElementType
using ElementType = typename DataType::ElementType;
/// Creates a Value<T> with `args`
/// @param args the args that will be passed to the expression
/// @returns a Value<T>
static Value Create(ScalarArgs args) { return Value{DataType::Expr, std::move(args)}; }
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const { return (*expr)(b, args); }
/// ast expression type create function
ast_expr_func_ptr expr;
/// args to create expression with
ScalarArgs args;
};
namespace detail {
/// Base template for IsValue
template <typename T>
struct IsValue : std::false_type {};
/// Specialization for IsValue
template <typename T>
struct IsValue<Value<T>> : std::true_type {};
} // namespace detail
/// True if T is of type Value
template <typename T>
constexpr bool IsValue = detail::IsValue<T>::value;
/// Creates a `Value<T>` from a scalar `v`
template <typename T>
auto Val(T v) {
return Value<T>::Create(v);
}
/// Creates a `Value<vec<N, T>>` from N scalar `args`
template <typename... T>
auto Vec(T&&... args) {
constexpr size_t N = sizeof...(args);
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
utils::Vector v{args...};
using VT = vec<N, FirstT>;
return Value<VT>::Create(utils::VectorRef<FirstT>{v});
}
/// Creates a `Value<mat<C,R,T>` from C*R scalar `args`
template <size_t C, size_t R, typename T>
auto Mat(const T (&m_in)[C][R]) {
utils::Vector<T, C * R> m;
for (uint32_t i = 0; i < C; ++i) {
for (size_t j = 0; j < R; ++j) {
m.Push(m_in[i][j]);
}
}
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
}
/// Creates a `Value<mat<2,R,T>` from column vectors `c0` and `c1`
template <typename T, size_t R>
auto Mat(const T (&c0)[R], const T (&c1)[R]) {
constexpr size_t C = 2;
utils::Vector<T, C * R> m;
for (auto v : c0) {
m.Push(v);
}
for (auto v : c1) {
m.Push(v);
}
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
}
/// Creates a `Value<mat<3,R,T>` from column vectors `c0`, `c1`, and `c2`
template <typename T, size_t R>
auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
constexpr size_t C = 3;
utils::Vector<T, C * R> m;
for (auto v : c0) {
m.Push(v);
}
for (auto v : c1) {
m.Push(v);
}
for (auto v : c2) {
m.Push(v);
}
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
}
/// Creates a `Value<mat<4,R,T>` from column vectors `c0`, `c1`, `c2`, and `c3`
template <typename T, size_t R>
auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
constexpr size_t C = 4;
utils::Vector<T, C * R> m;
for (auto v : c0) {
m.Push(v);
}
for (auto v : c1) {
m.Push(v);
}
for (auto v : c2) {
m.Push(v);
}
for (auto v : c3) {
m.Push(v);
}
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
}
} // namespace builder
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_

View File

@ -47,13 +47,13 @@ class ResolverTypeConstructorValidationTest : public resolver::TestHelper, publi
namespace InferTypeTest {
struct Params {
builder::ast_type_func_ptr create_rhs_ast_type;
builder::ast_expr_func_ptr create_rhs_ast_value;
builder::ast_expr_from_double_func_ptr create_rhs_ast_value;
builder::sem_type_func_ptr create_rhs_sem_type;
};
template <typename T>
constexpr Params ParamsFor() {
return Params{DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
return Params{DataType<T>::AST, DataType<T>::ExprFromDouble, DataType<T>::Sem};
}
TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) {
@ -242,12 +242,13 @@ struct Params {
Kind kind;
builder::ast_type_func_ptr lhs_type;
builder::ast_type_func_ptr rhs_type;
builder::ast_expr_func_ptr rhs_value_expr;
builder::ast_expr_from_double_func_ptr rhs_value_expr;
};
template <typename LhsType, typename RhsType>
constexpr Params ParamsFor(Kind kind) {
return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST, DataType<RhsType>::Expr};
return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST,
DataType<RhsType>::ExprFromDouble};
}
static constexpr Params valid_cases[] = {
@ -426,7 +427,7 @@ TEST_P(ConversionConstructorInvalidTest, All) {
// Skip test for valid cases
for (auto& v : valid_cases) {
if (v.lhs_type == lhs_params.ast && v.rhs_type == rhs_params.ast &&
v.rhs_value_expr == rhs_params.expr) {
v.rhs_value_expr == rhs_params.expr_from_double) {
return;
}
}
@ -439,7 +440,7 @@ TEST_P(ConversionConstructorInvalidTest, All) {
auto* lhs_type1 = lhs_params.ast(*this);
auto* lhs_type2 = lhs_params.ast(*this);
auto* rhs_type = rhs_params.ast(*this);
auto* rhs_value_expr = rhs_params.expr(*this, 0);
auto* rhs_value_expr = rhs_params.expr_from_double(*this, 0);
std::stringstream ss;
ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
@ -2437,7 +2438,7 @@ struct MatrixParams {
uint32_t columns;
name_func_ptr get_element_type_name;
builder::ast_type_func_ptr create_element_ast_type;
builder::ast_expr_func_ptr create_element_ast_value;
builder::ast_expr_from_double_func_ptr create_element_ast_value;
builder::ast_type_func_ptr create_column_ast_type;
builder::ast_type_func_ptr create_mat_ast_type;
};
@ -2449,7 +2450,7 @@ constexpr MatrixParams MatrixParamsFor() {
C,
DataType<T>::Name,
DataType<T>::AST,
DataType<T>::Expr,
DataType<T>::ExprFromDouble,
DataType<tint::resolver::builder::vec<R, T>>::AST,
DataType<tint::resolver::builder::mat<C, R, T>>::AST,
};
@ -3058,7 +3059,7 @@ TEST_P(StructConstructorInputsTest, TooFew) {
auto* struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
if (i < N - 1) {
auto* ctor_value_expr = str_params.expr(*this, 0);
auto* ctor_value_expr = str_params.expr_from_double(*this, 0);
values.Push(ctor_value_expr);
}
}
@ -3084,7 +3085,7 @@ TEST_P(StructConstructorInputsTest, TooMany) {
auto* struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
}
auto* ctor_value_expr = str_params.expr(*this, 0);
auto* ctor_value_expr = str_params.expr_from_double(*this, 0);
values.Push(ctor_value_expr);
}
auto* s = Structure("s", members);
@ -3122,8 +3123,8 @@ TEST_P(StructConstructorTypeTest, AllTypes) {
auto* struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
auto* ctor_value_expr = (i == constructor_value_with_different_type)
? ctor_params.expr(*this, 0)
: str_params.expr(*this, 0);
? ctor_params.expr_from_double(*this, 0)
: str_params.expr_from_double(*this, 0);
values.Push(ctor_value_expr);
}
auto* s = Structure("s", members);