resolver: Enable AST type reachability checks

Required a lot of test fixes.

ProgramBuilder: :ConstructValueFilledWith() was a major source of unreached AST types, and this has been removed with more powerful type-building helpers in resolver_test_helper.h.
Change-Id: I1f2007cdaef7f319ab4ef8b4fb8c37687a0fb5d8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53800
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-06-09 07:48:17 +00:00 committed by Tint LUCI CQ
parent fc645f7489
commit c57f725797
13 changed files with 870 additions and 674 deletions

View File

@ -102,59 +102,6 @@ const sem::Type* ProgramBuilder::TypeOf(const ast::Type* type) const {
return Sem().Get(type); return Sem().Get(type);
} }
ast::ConstructorExpression* ProgramBuilder::ConstructValueFilledWith(
const ast::Type* type,
int elem_value) {
CloneContext ctx(this);
if (type->Is<ast::Bool>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::BoolLiteral>(elem_value == 0 ? false : true));
}
if (type->Is<ast::I32>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::SintLiteral>(static_cast<i32>(elem_value)));
}
if (type->Is<ast::U32>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::UintLiteral>(static_cast<u32>(elem_value)));
}
if (type->Is<ast::F32>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(static_cast<f32>(elem_value)));
}
if (auto* v = type->As<ast::Vector>()) {
ast::ExpressionList el(v->size());
for (size_t i = 0; i < el.size(); i++) {
el[i] = ConstructValueFilledWith(ctx.Clone(v->type()), elem_value);
}
return create<ast::TypeConstructorExpression>(const_cast<ast::Type*>(type),
std::move(el));
}
if (auto* m = type->As<ast::Matrix>()) {
ast::ExpressionList el(m->columns());
for (size_t i = 0; i < el.size(); i++) {
auto* col_vec_type = create<ast::Vector>(ctx.Clone(m->type()), m->rows());
el[i] = ConstructValueFilledWith(col_vec_type, elem_value);
}
return create<ast::TypeConstructorExpression>(const_cast<ast::Type*>(type),
std::move(el));
}
if (auto* tn = type->As<ast::TypeName>()) {
if (auto* lookup = AST().LookupType(tn->name())) {
if (auto* alias = lookup->As<ast::Alias>()) {
return ConstructValueFilledWith(ctx.Clone(alias->type()), elem_value);
}
}
TINT_ICE(diagnostics_) << "unable to find NamedType '"
<< Symbols().NameFor(tn->name()) << "'";
return nullptr;
}
TINT_ICE(diagnostics_) << "unhandled type: " << type->TypeInfo().name;
return nullptr;
}
ast::Type* ProgramBuilder::TypesBuilder::MaybeCreateTypename( ast::Type* ProgramBuilder::TypesBuilder::MaybeCreateTypename(
ast::Type* type) const { ast::Type* type) const {
if (auto* nt = As<ast::NamedType>(type)) { if (auto* nt = As<ast::NamedType>(type)) {

View File

@ -1078,17 +1078,6 @@ class ProgramBuilder {
type, ExprList(std::forward<ARGS>(args)...)); type, ExprList(std::forward<ARGS>(args)...));
} }
/// Creates a constructor expression that constructs an object of
/// `type` filled with `elem_value`. For example,
/// ConstructValueFilledWith(ty.mat3x4<float>(), 5) returns a
/// TypeConstructorExpression for a Mat3x4 filled with 5.0f values.
/// @param type the type to construct
/// @param elem_value the initial or element value (for vec and mat) to
/// construct with
/// @return the constructor expression
ast::ConstructorExpression* ConstructValueFilledWith(const ast::Type* type,
int elem_value = 0);
/// @param args the arguments for the vector constructor /// @param args the arguments for the vector constructor
/// @return an `ast::TypeConstructorExpression` of a 2-element vector of type /// @return an `ast::TypeConstructorExpression` of a 2-element vector of type
/// `T`, constructed with the values `args`. /// `T`, constructed with the values `args`.

View File

@ -26,6 +26,33 @@
namespace tint { namespace tint {
namespace resolver { namespace resolver {
// Helpers and typedefs
template <typename T>
using DataType = builder::DataType<T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T, int ID = 0>
using alias = builder::alias<T, ID>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
namespace DecorationTests { namespace DecorationTests {
namespace { namespace {
@ -357,17 +384,22 @@ namespace ArrayStrideTests {
namespace { namespace {
struct Params { struct Params {
create_ast_type_func_ptr create_el_type; builder::ast_type_func_ptr create_el_type;
uint32_t stride; uint32_t stride;
bool should_pass; bool should_pass;
}; };
template <typename T>
constexpr Params ParamsFor(uint32_t stride, bool should_pass) {
return Params{DataType<T>::AST, stride, should_pass};
}
struct TestWithParams : ResolverTestWithParam<Params> {}; struct TestWithParams : ResolverTestWithParam<Params> {};
using ArrayStrideTest = TestWithParams; using ArrayStrideTest = TestWithParams;
TEST_P(ArrayStrideTest, All) { TEST_P(ArrayStrideTest, All) {
auto& params = GetParam(); auto& params = GetParam();
auto* el_ty = params.create_el_type(ty); auto* el_ty = params.create_el_type(*this);
std::stringstream ss; std::stringstream ss;
ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride
@ -389,11 +421,6 @@ TEST_P(ArrayStrideTest, All) {
} }
} }
// Helpers and typedefs
using i32 = ProgramBuilder::i32;
using u32 = ProgramBuilder::u32;
using f32 = ProgramBuilder::f32;
struct SizeAndAlignment { struct SizeAndAlignment {
uint32_t size; uint32_t size;
uint32_t align; uint32_t align;
@ -414,49 +441,49 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values( testing::Values(
// Succeed because stride >= element size (while being multiple of // Succeed because stride >= element size (while being multiple of
// element alignment) // element alignment)
Params{ast_u32, default_u32.size, true}, ParamsFor<u32>(default_u32.size, true),
Params{ast_i32, default_i32.size, true}, ParamsFor<i32>(default_i32.size, true),
Params{ast_f32, default_f32.size, true}, ParamsFor<f32>(default_f32.size, true),
Params{ast_vec2<f32>, default_vec2.size, true}, ParamsFor<vec2<f32>>(default_vec2.size, true),
// vec3's default size is not a multiple of its alignment // vec3's default size is not a multiple of its alignment
// Params{ast_vec3<f32>, default_vec3.size, true}, // ParamsFor<vec3<f32>, default_vec3.size, true},
Params{ast_vec4<f32>, default_vec4.size, true}, ParamsFor<vec4<f32>>(default_vec4.size, true),
Params{ast_mat2x2<f32>, default_mat2x2.size, true}, ParamsFor<mat2x2<f32>>(default_mat2x2.size, true),
Params{ast_mat3x3<f32>, default_mat3x3.size, true}, ParamsFor<mat3x3<f32>>(default_mat3x3.size, true),
Params{ast_mat4x4<f32>, default_mat4x4.size, true}, ParamsFor<mat4x4<f32>>(default_mat4x4.size, true),
// Fail because stride is < element size // Fail because stride is < element size
Params{ast_u32, default_u32.size - 1, false}, ParamsFor<u32>(default_u32.size - 1, false),
Params{ast_i32, default_i32.size - 1, false}, ParamsFor<i32>(default_i32.size - 1, false),
Params{ast_f32, default_f32.size - 1, false}, ParamsFor<f32>(default_f32.size - 1, false),
Params{ast_vec2<f32>, default_vec2.size - 1, false}, ParamsFor<vec2<f32>>(default_vec2.size - 1, false),
Params{ast_vec3<f32>, default_vec3.size - 1, false}, ParamsFor<vec3<f32>>(default_vec3.size - 1, false),
Params{ast_vec4<f32>, default_vec4.size - 1, false}, ParamsFor<vec4<f32>>(default_vec4.size - 1, false),
Params{ast_mat2x2<f32>, default_mat2x2.size - 1, false}, ParamsFor<mat2x2<f32>>(default_mat2x2.size - 1, false),
Params{ast_mat3x3<f32>, default_mat3x3.size - 1, false}, ParamsFor<mat3x3<f32>>(default_mat3x3.size - 1, false),
Params{ast_mat4x4<f32>, default_mat4x4.size - 1, false}, ParamsFor<mat4x4<f32>>(default_mat4x4.size - 1, false),
// Succeed because stride equals multiple of element alignment // Succeed because stride equals multiple of element alignment
Params{ast_u32, default_u32.align * 7, true}, ParamsFor<u32>(default_u32.align * 7, true),
Params{ast_i32, default_i32.align * 7, true}, ParamsFor<i32>(default_i32.align * 7, true),
Params{ast_f32, default_f32.align * 7, true}, ParamsFor<f32>(default_f32.align * 7, true),
Params{ast_vec2<f32>, default_vec2.align * 7, true}, ParamsFor<vec2<f32>>(default_vec2.align * 7, true),
Params{ast_vec3<f32>, default_vec3.align * 7, true}, ParamsFor<vec3<f32>>(default_vec3.align * 7, true),
Params{ast_vec4<f32>, default_vec4.align * 7, true}, ParamsFor<vec4<f32>>(default_vec4.align * 7, true),
Params{ast_mat2x2<f32>, default_mat2x2.align * 7, true}, ParamsFor<mat2x2<f32>>(default_mat2x2.align * 7, true),
Params{ast_mat3x3<f32>, default_mat3x3.align * 7, true}, ParamsFor<mat3x3<f32>>(default_mat3x3.align * 7, true),
Params{ast_mat4x4<f32>, default_mat4x4.align * 7, true}, ParamsFor<mat4x4<f32>>(default_mat4x4.align * 7, true),
// Fail because stride is not multiple of element alignment // Fail because stride is not multiple of element alignment
Params{ast_u32, (default_u32.align - 1) * 7, false}, ParamsFor<u32>((default_u32.align - 1) * 7, false),
Params{ast_i32, (default_i32.align - 1) * 7, false}, ParamsFor<i32>((default_i32.align - 1) * 7, false),
Params{ast_f32, (default_f32.align - 1) * 7, false}, ParamsFor<f32>((default_f32.align - 1) * 7, false),
Params{ast_vec2<f32>, (default_vec2.align - 1) * 7, false}, ParamsFor<vec2<f32>>((default_vec2.align - 1) * 7, false),
Params{ast_vec3<f32>, (default_vec3.align - 1) * 7, false}, ParamsFor<vec3<f32>>((default_vec3.align - 1) * 7, false),
Params{ast_vec4<f32>, (default_vec4.align - 1) * 7, false}, ParamsFor<vec4<f32>>((default_vec4.align - 1) * 7, false),
Params{ast_mat2x2<f32>, (default_mat2x2.align - 1) * 7, false}, ParamsFor<mat2x2<f32>>((default_mat2x2.align - 1) * 7, false),
Params{ast_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false}, ParamsFor<mat3x3<f32>>((default_mat3x3.align - 1) * 7, false),
Params{ast_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false})); ParamsFor<mat4x4<f32>>((default_mat4x4.align - 1) * 7, false)));
TEST_F(ArrayStrideTest, MultipleDecorations) { TEST_F(ArrayStrideTest, MultipleDecorations) {
auto* arr = ty.array(Source{{12, 34}}, ty.i32(), 4, auto* arr = ty.array(Source{{12, 34}}, ty.i32(), 4,

View File

@ -26,6 +26,27 @@ namespace tint {
namespace resolver { namespace resolver {
namespace { namespace {
// Helpers and typedefs
template <typename T>
using DataType = builder::DataType<T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T>
using alias = builder::alias<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
class ResolverEntryPointValidationTest : public TestHelper, class ResolverEntryPointValidationTest : public TestHelper,
public testing::Test {}; public testing::Test {};
@ -517,43 +538,48 @@ TEST_F(ResolverEntryPointValidationTest,
namespace TypeValidationTests { namespace TypeValidationTests {
struct Params { struct Params {
create_ast_type_func_ptr create_ast_type; builder::ast_type_func_ptr create_ast_type;
bool is_valid; bool is_valid;
}; };
template <typename T>
constexpr Params ParamsFor(bool is_valid) {
return Params{DataType<T>::AST, is_valid};
}
using TypeValidationTest = resolver::ResolverTestWithParam<Params>; using TypeValidationTest = resolver::ResolverTestWithParam<Params>;
static constexpr Params cases[] = { static constexpr Params cases[] = {
{ast_f32, true}, ParamsFor<f32>(true), //
{ast_i32, true}, ParamsFor<i32>(true), //
{ast_u32, true}, ParamsFor<u32>(true), //
{ast_bool, false}, ParamsFor<bool>(false), //
{ast_vec2<ast_f32>, true}, ParamsFor<vec2<f32>>(true), //
{ast_vec3<ast_f32>, true}, ParamsFor<vec3<f32>>(true), //
{ast_vec4<ast_f32>, true}, ParamsFor<vec4<f32>>(true), //
{ast_mat2x2<ast_f32>, false}, ParamsFor<mat2x2<f32>>(false), //
{ast_mat2x2<ast_i32>, false}, ParamsFor<mat2x2<i32>>(false), //
{ast_mat2x2<ast_u32>, false}, ParamsFor<mat2x2<u32>>(false), //
{ast_mat2x2<ast_bool>, false}, ParamsFor<mat2x2<bool>>(false), //
{ast_mat3x3<ast_f32>, false}, ParamsFor<mat3x3<f32>>(false), //
{ast_mat3x3<ast_i32>, false}, ParamsFor<mat3x3<i32>>(false), //
{ast_mat3x3<ast_u32>, false}, ParamsFor<mat3x3<u32>>(false), //
{ast_mat3x3<ast_bool>, false}, ParamsFor<mat3x3<bool>>(false), //
{ast_mat4x4<ast_f32>, false}, ParamsFor<mat4x4<f32>>(false), //
{ast_mat4x4<ast_i32>, false}, ParamsFor<mat4x4<i32>>(false), //
{ast_mat4x4<ast_u32>, false}, ParamsFor<mat4x4<u32>>(false), //
{ast_mat4x4<ast_bool>, false}, ParamsFor<mat4x4<bool>>(false), //
{ast_alias<ast_f32>, true}, ParamsFor<alias<f32>>(true), //
{ast_alias<ast_i32>, true}, ParamsFor<alias<i32>>(true), //
{ast_alias<ast_u32>, true}, ParamsFor<alias<u32>>(true), //
{ast_alias<ast_bool>, false}, ParamsFor<alias<bool>>(false), //
}; };
TEST_P(TypeValidationTest, BareInputs) { TEST_P(TypeValidationTest, BareInputs) {
// [[stage(fragment)]] // [[stage(fragment)]]
// fn main([[location(0)]] a : *) {} // fn main([[location(0)]] a : *) {}
auto params = GetParam(); auto params = GetParam();
auto* a = Param("a", params.create_ast_type(ty), {Location(0)}); auto* a = Param("a", params.create_ast_type(*this), {Location(0)});
Func(Source{{12, 34}}, "main", {a}, ty.void_(), {}, Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment)}); {Stage(ast::PipelineStage::kFragment)});
@ -572,7 +598,7 @@ TEST_P(TypeValidationTest, StructInputs) {
// fn main(a : Input) {} // fn main(a : Input) {}
auto params = GetParam(); auto params = GetParam();
auto* input = Structure( auto* input = Structure(
"Input", {Member("a", params.create_ast_type(ty), {Location(0)})}); "Input", {Member("a", params.create_ast_type(*this), {Location(0)})});
auto* a = Param("a", input, {}); auto* a = Param("a", input, {});
Func(Source{{12, 34}}, "main", {a}, ty.void_(), {}, Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment)}); {Stage(ast::PipelineStage::kFragment)});
@ -590,8 +616,8 @@ TEST_P(TypeValidationTest, BareOutputs) {
// return *(); // return *();
// } // }
auto params = GetParam(); auto params = GetParam();
Func(Source{{12, 34}}, "main", {}, params.create_ast_type(ty), Func(Source{{12, 34}}, "main", {}, params.create_ast_type(*this),
{Return(Construct(params.create_ast_type(ty)))}, {Return(Construct(params.create_ast_type(*this)))},
{Stage(ast::PipelineStage::kFragment)}, {Location(0)}); {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
if (params.is_valid) { if (params.is_valid) {
@ -611,7 +637,7 @@ TEST_P(TypeValidationTest, StructOutputs) {
// } // }
auto params = GetParam(); auto params = GetParam();
auto* output = Structure( auto* output = Structure(
"Output", {Member("a", params.create_ast_type(ty), {Location(0)})}); "Output", {Member("a", params.create_ast_type(*this), {Location(0)})});
Func(Source{{12, 34}}, "main", {}, output, {Return(Construct(output))}, Func(Source{{12, 34}}, "main", {}, output, {Return(Construct(output))},
{Stage(ast::PipelineStage::kFragment)}); {Stage(ast::PipelineStage::kFragment)});

View File

@ -23,42 +23,62 @@ namespace resolver {
namespace { namespace {
// Helpers and typedefs // Helpers and typedefs
using i32 = ProgramBuilder::i32; template <typename T>
using u32 = ProgramBuilder::u32; using DataType = builder::DataType<T>;
using f32 = ProgramBuilder::f32; template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T>
using alias = builder::alias<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
struct ResolverInferredTypeTest : public resolver::TestHelper, struct ResolverInferredTypeTest : public resolver::TestHelper,
public testing::Test {}; public testing::Test {};
struct Params { struct Params {
create_ast_type_func_ptr create_type; builder::ast_expr_func_ptr create_value;
create_sem_type_func_ptr create_expected_type; builder::sem_type_func_ptr create_expected_type;
}; };
Params all_cases[] = { template <typename T>
{ast_bool, sem_bool}, constexpr Params ParamsFor() {
{ast_u32, sem_u32}, return Params{DataType<T>::Expr, DataType<T>::Sem};
{ast_i32, sem_i32}, }
{ast_f32, sem_f32},
{ast_vec3<bool>, sem_vec3<sem_bool>},
{ast_vec3<i32>, sem_vec3<sem_i32>},
{ast_vec3<u32>, sem_vec3<sem_u32>},
{ast_vec3<f32>, sem_vec3<sem_f32>},
{ast_mat3x3<i32>, sem_mat3x3<sem_i32>},
{ast_mat3x3<u32>, sem_mat3x3<sem_u32>},
{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
{ast_alias<ast_bool>, sem_bool}, Params all_cases[] = {
{ast_alias<ast_u32>, sem_u32}, ParamsFor<bool>(), //
{ast_alias<ast_i32>, sem_i32}, ParamsFor<u32>(), //
{ast_alias<ast_f32>, sem_f32}, ParamsFor<i32>(), //
{ast_alias<ast_vec3<bool>>, sem_vec3<sem_bool>}, ParamsFor<f32>(), //
{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>}, ParamsFor<vec3<bool>>(), //
{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>}, ParamsFor<vec3<i32>>(), //
{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>}, ParamsFor<vec3<u32>>(), //
{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>}, ParamsFor<vec3<f32>>(), //
{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>}, ParamsFor<mat3x3<i32>>(), //
{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>}, ParamsFor<mat3x3<u32>>(), //
ParamsFor<mat3x3<f32>>(), //
ParamsFor<alias<bool>>(), //
ParamsFor<alias<u32>>(), //
ParamsFor<alias<i32>>(), //
ParamsFor<alias<f32>>(), //
ParamsFor<alias<vec3<bool>>>(), //
ParamsFor<alias<vec3<i32>>>(), //
ParamsFor<alias<vec3<u32>>>(), //
ParamsFor<alias<vec3<f32>>>(), //
ParamsFor<alias<mat3x3<i32>>>(), //
ParamsFor<alias<mat3x3<u32>>>(), //
ParamsFor<alias<mat3x3<f32>>>(), //
}; };
using ResolverInferredTypeParamTest = ResolverTestWithParam<Params>; using ResolverInferredTypeParamTest = ResolverTestWithParam<Params>;
@ -66,11 +86,10 @@ using ResolverInferredTypeParamTest = ResolverTestWithParam<Params>;
TEST_P(ResolverInferredTypeParamTest, GlobalLet_Pass) { TEST_P(ResolverInferredTypeParamTest, GlobalLet_Pass) {
auto& params = GetParam(); auto& params = GetParam();
auto* type = params.create_type(ty); auto* expected_type = params.create_expected_type(*this);
auto* expected_type = params.create_expected_type(ty);
// let a = <type constructor>; // let a = <type constructor>;
auto* ctor_expr = ConstructValueFilledWith(type); auto* ctor_expr = params.create_value(*this, 0);
auto* var = GlobalConst("a", nullptr, ctor_expr); auto* var = GlobalConst("a", nullptr, ctor_expr);
WrapInFunction(); WrapInFunction();
@ -81,10 +100,8 @@ TEST_P(ResolverInferredTypeParamTest, GlobalLet_Pass) {
TEST_P(ResolverInferredTypeParamTest, GlobalVar_Fail) { TEST_P(ResolverInferredTypeParamTest, GlobalVar_Fail) {
auto& params = GetParam(); auto& params = GetParam();
auto* type = params.create_type(ty);
// var a = <type constructor>; // var a = <type constructor>;
auto* ctor_expr = ConstructValueFilledWith(type); auto* ctor_expr = params.create_value(*this, 0);
Global(Source{{12, 34}}, "a", nullptr, ast::StorageClass::kPrivate, Global(Source{{12, 34}}, "a", nullptr, ast::StorageClass::kPrivate,
ctor_expr); ctor_expr);
WrapInFunction(); WrapInFunction();
@ -97,11 +114,10 @@ TEST_P(ResolverInferredTypeParamTest, GlobalVar_Fail) {
TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) { TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) {
auto& params = GetParam(); auto& params = GetParam();
auto* type = params.create_type(ty); auto* expected_type = params.create_expected_type(*this);
auto* expected_type = params.create_expected_type(ty);
// let a = <type constructor>; // let a = <type constructor>;
auto* ctor_expr = ConstructValueFilledWith(type); auto* ctor_expr = params.create_value(*this, 0);
auto* var = Const("a", nullptr, ctor_expr); auto* var = Const("a", nullptr, ctor_expr);
WrapInFunction(var); WrapInFunction(var);
@ -112,11 +128,10 @@ TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) {
TEST_P(ResolverInferredTypeParamTest, LocalVar_Pass) { TEST_P(ResolverInferredTypeParamTest, LocalVar_Pass) {
auto& params = GetParam(); auto& params = GetParam();
auto* type = params.create_type(ty); auto* expected_type = params.create_expected_type(*this);
auto* expected_type = params.create_expected_type(ty);
// var a = <type constructor>; // var a = <type constructor>;
auto* ctor_expr = ConstructValueFilledWith(type); auto* ctor_expr = params.create_value(*this, 0);
auto* var = Var("a", nullptr, ast::StorageClass::kFunction, ctor_expr); auto* var = Var("a", nullptr, ast::StorageClass::kFunction, ctor_expr);
WrapInFunction(var); WrapInFunction(var);

View File

@ -245,8 +245,7 @@ bool Resolver::ResolveInternal() {
for (auto* node : builder_->ASTNodes().Objects()) { for (auto* node : builder_->ASTNodes().Objects()) {
if (marked_.count(node) == 0) { if (marked_.count(node) == 0) {
if (node->IsAnyOf<ast::AccessDecoration, ast::StrideDecoration, if (node->IsAnyOf<ast::AccessDecoration, ast::StrideDecoration>()) {
ast::Type>()) {
// TODO(crbug.com/tint/724) - Remove once tint:724 is complete. // TODO(crbug.com/tint/724) - Remove once tint:724 is complete.
// ast::AccessDecorations are generated by the WGSL parser, used to // ast::AccessDecorations are generated by the WGSL parser, used to
// build sem::AccessControls and then leaked. // build sem::AccessControls and then leaked.
@ -254,7 +253,6 @@ bool Resolver::ResolveInternal() {
// multiple arrays of the same stride, size and element type are // multiple arrays of the same stride, size and element type are
// currently de-duplicated by the type manager, and we leak these // currently de-duplicated by the type manager, and we leak these
// decorations. // decorations.
// ast::Types are being built, but not yet being handled. This is WIP.
continue; continue;
} }
TINT_ICE(diagnostics_) << "AST node '" << node->TypeInfo().name TINT_ICE(diagnostics_) << "AST node '" << node->TypeInfo().name

View File

@ -51,9 +51,39 @@ namespace resolver {
namespace { namespace {
// Helpers and typedefs // Helpers and typedefs
using i32 = ProgramBuilder::i32; template <typename T>
using u32 = ProgramBuilder::u32; using DataType = builder::DataType<T>;
using f32 = ProgramBuilder::f32; template <int N, typename T>
using vec = builder::vec<N, T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <int N, int M, typename T>
using mat = builder::mat<N, M, T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat2x3 = builder::mat2x3<T>;
template <typename T>
using mat3x2 = builder::mat3x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T, int ID = 0>
using alias = builder::alias<T, ID>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
using Op = ast::BinaryOp; using Op = ast::BinaryOp;
TEST_F(ResolverTest, Stmt_Assign) { TEST_F(ResolverTest, Stmt_Assign) {
@ -1209,13 +1239,40 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
namespace ExprBinaryTest { namespace ExprBinaryTest {
template <typename T, int ID>
struct Aliased {
using type = alias<T, ID>;
};
template <int N, typename T, int ID>
struct Aliased<vec<N, T>, ID> {
using type = vec<N, alias<T, ID>>;
};
template <int N, int M, typename T, int ID>
struct Aliased<mat<N, M, T>, ID> {
using type = mat<N, M, alias<T, ID>>;
};
struct Params { struct Params {
ast::BinaryOp op; ast::BinaryOp op;
create_ast_type_func_ptr create_lhs_type; builder::ast_type_func_ptr create_lhs_type;
create_ast_type_func_ptr create_rhs_type; builder::ast_type_func_ptr create_rhs_type;
create_sem_type_func_ptr create_result_type; builder::ast_type_func_ptr create_lhs_alias_type;
builder::ast_type_func_ptr create_rhs_alias_type;
builder::sem_type_func_ptr create_result_type;
}; };
template <typename LHS, typename RHS, typename RES>
constexpr Params ParamsFor(ast::BinaryOp op) {
return Params{op,
DataType<LHS>::AST,
DataType<RHS>::AST,
DataType<typename Aliased<LHS, 0>::type>::AST,
DataType<typename Aliased<RHS, 1>::type>::AST,
DataType<RES>::Sem};
}
static constexpr ast::BinaryOp all_ops[] = { static constexpr ast::BinaryOp all_ops[] = {
ast::BinaryOp::kAnd, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kOr,
@ -1237,12 +1294,24 @@ static constexpr ast::BinaryOp all_ops[] = {
ast::BinaryOp::kModulo, ast::BinaryOp::kModulo,
}; };
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = { static constexpr builder::ast_type_func_ptr all_create_type_funcs[] = {
ast_bool, ast_u32, ast_i32, ast_f32, DataType<bool>::AST, //
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>, DataType<u32>::AST, //
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>, // DataType<i32>::AST, //
ast_mat2x3<i32>, ast_mat2x3<u32>, ast_mat2x3<f32>, // DataType<f32>::AST, //
ast_mat3x2<i32>, ast_mat3x2<u32>, ast_mat3x2<f32> // DataType<vec3<bool>>::AST, //
DataType<vec3<i32>>::AST, //
DataType<vec3<u32>>::AST, //
DataType<vec3<f32>>::AST, //
DataType<mat3x3<i32>>::AST, //
DataType<mat3x3<u32>>::AST, //
DataType<mat3x3<f32>>::AST, //
DataType<mat2x3<i32>>::AST, //
DataType<mat2x3<u32>>::AST, //
DataType<mat2x3<f32>>::AST, //
DataType<mat3x2<i32>>::AST, //
DataType<mat3x2<u32>>::AST, //
DataType<mat3x2<f32>>::AST //
}; };
// A list of all valid test cases for 'lhs op rhs', except that for vecN and // A list of all valid test cases for 'lhs op rhs', except that for vecN and
@ -1252,229 +1321,216 @@ static constexpr Params all_valid_cases[] = {
// https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
// Binary logical expressions // Binary logical expressions
Params{Op::kLogicalAnd, ast_bool, ast_bool, sem_bool}, ParamsFor<bool, bool, bool>(Op::kLogicalAnd),
Params{Op::kLogicalOr, ast_bool, ast_bool, sem_bool}, ParamsFor<bool, bool, bool>(Op::kLogicalOr),
Params{Op::kAnd, ast_bool, ast_bool, sem_bool}, ParamsFor<bool, bool, bool>(Op::kAnd),
Params{Op::kOr, ast_bool, ast_bool, sem_bool}, ParamsFor<bool, bool, bool>(Op::kOr),
Params{Op::kAnd, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>}, ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kAnd),
Params{Op::kOr, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>}, ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kOr),
// Arithmetic expressions // Arithmetic expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
// Binary arithmetic expressions over scalars // Binary arithmetic expressions over scalars
Params{Op::kAdd, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kAdd),
Params{Op::kSubtract, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kSubtract),
Params{Op::kMultiply, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kMultiply),
Params{Op::kDivide, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kDivide),
Params{Op::kModulo, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kModulo),
Params{Op::kAdd, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kAdd),
Params{Op::kSubtract, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kSubtract),
Params{Op::kMultiply, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kMultiply),
Params{Op::kDivide, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kDivide),
Params{Op::kModulo, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kModulo),
Params{Op::kAdd, ast_f32, ast_f32, sem_f32}, ParamsFor<f32, f32, f32>(Op::kAdd),
Params{Op::kSubtract, ast_f32, ast_f32, sem_f32}, ParamsFor<f32, f32, f32>(Op::kSubtract),
Params{Op::kMultiply, ast_f32, ast_f32, sem_f32}, ParamsFor<f32, f32, f32>(Op::kMultiply),
Params{Op::kDivide, ast_f32, ast_f32, sem_f32}, ParamsFor<f32, f32, f32>(Op::kDivide),
Params{Op::kModulo, ast_f32, ast_f32, sem_f32}, ParamsFor<f32, f32, f32>(Op::kModulo),
// Binary arithmetic expressions over vectors // Binary arithmetic expressions over vectors
Params{Op::kAdd, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kAdd),
Params{Op::kSubtract, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kSubtract),
Params{Op::kMultiply, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kMultiply),
Params{Op::kDivide, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kDivide),
Params{Op::kModulo, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kModulo),
Params{Op::kAdd, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kAdd),
Params{Op::kSubtract, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kSubtract),
Params{Op::kMultiply, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kMultiply),
Params{Op::kDivide, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kDivide),
Params{Op::kModulo, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kModulo),
Params{Op::kAdd, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kAdd),
Params{Op::kSubtract, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kSubtract),
Params{Op::kMultiply, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kMultiply),
Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kDivide),
Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kModulo),
// Binary arithmetic expressions with mixed scalar and vector operands // Binary arithmetic expressions with mixed scalar and vector operands
Params{Op::kAdd, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kAdd),
Params{Op::kSubtract, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kSubtract),
Params{Op::kMultiply, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kMultiply),
Params{Op::kDivide, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kDivide),
Params{Op::kModulo, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kModulo),
Params{Op::kAdd, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kAdd),
Params{Op::kSubtract, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kSubtract),
Params{Op::kMultiply, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kMultiply),
Params{Op::kDivide, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kDivide),
Params{Op::kModulo, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kModulo),
Params{Op::kAdd, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kAdd),
Params{Op::kSubtract, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kSubtract),
Params{Op::kMultiply, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kMultiply),
Params{Op::kDivide, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kDivide),
Params{Op::kModulo, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kModulo),
Params{Op::kAdd, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kAdd),
Params{Op::kSubtract, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kSubtract),
Params{Op::kMultiply, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kMultiply),
Params{Op::kDivide, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kDivide),
Params{Op::kModulo, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kModulo),
Params{Op::kAdd, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kAdd),
Params{Op::kSubtract, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kSubtract),
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kMultiply),
Params{Op::kDivide, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kDivide),
// NOTE: no kModulo for ast_vec3<f32>, ast_f32 // NOTE: no kModulo for vec3<f32>, f32
// Params{Op::kModulo, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, // ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kModulo),
Params{Op::kAdd, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kAdd),
Params{Op::kSubtract, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kSubtract),
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kMultiply),
Params{Op::kDivide, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kDivide),
// NOTE: no kModulo for ast_f32, ast_vec3<f32> // NOTE: no kModulo for f32, vec3<f32>
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, // ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kModulo),
// Matrix arithmetic // Matrix arithmetic
Params{Op::kMultiply, ast_mat2x3<f32>, ast_f32, sem_mat2x3<sem_f32>}, ParamsFor<mat2x3<f32>, f32, mat2x3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat3x2<f32>, ast_f32, sem_mat3x2<sem_f32>}, ParamsFor<mat3x2<f32>, f32, mat3x2<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>}, ParamsFor<mat3x3<f32>, f32, mat3x3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_f32, ast_mat2x3<f32>, sem_mat2x3<sem_f32>}, ParamsFor<f32, mat2x3<f32>, mat2x3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_f32, ast_mat3x2<f32>, sem_mat3x2<sem_f32>}, ParamsFor<f32, mat3x2<f32>, mat3x2<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>}, ParamsFor<f32, mat3x3<f32>, mat3x3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_vec3<f32>, ast_mat2x3<f32>, sem_vec2<sem_f32>}, ParamsFor<vec3<f32>, mat2x3<f32>, vec2<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_vec2<f32>, ast_mat3x2<f32>, sem_vec3<sem_f32>}, ParamsFor<vec2<f32>, mat3x2<f32>, vec3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>, mat3x3<f32>, vec3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat3x2<f32>, ast_vec3<f32>, sem_vec2<sem_f32>}, ParamsFor<mat3x2<f32>, vec3<f32>, vec2<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat2x3<f32>, ast_vec2<f32>, sem_vec3<sem_f32>}, ParamsFor<mat2x3<f32>, vec2<f32>, vec3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<mat3x3<f32>, vec3<f32>, vec3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat2x3<f32>, ast_mat3x2<f32>, ParamsFor<mat2x3<f32>, mat3x2<f32>, mat3x3<f32>>(Op::kMultiply),
sem_mat3x3<sem_f32>}, ParamsFor<mat3x2<f32>, mat2x3<f32>, mat2x2<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat2x3<f32>, ParamsFor<mat3x2<f32>, mat3x3<f32>, mat3x2<f32>>(Op::kMultiply),
sem_mat2x2<sem_f32>}, ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kMultiply),
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat3x3<f32>, ParamsFor<mat3x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kMultiply),
sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat2x3<f32>,
sem_mat2x3<sem_f32>},
Params{Op::kAdd, ast_mat2x3<f32>, ast_mat2x3<f32>, sem_mat2x3<sem_f32>}, ParamsFor<mat2x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kAdd),
Params{Op::kAdd, ast_mat3x2<f32>, ast_mat3x2<f32>, sem_mat3x2<sem_f32>}, ParamsFor<mat3x2<f32>, mat3x2<f32>, mat3x2<f32>>(Op::kAdd),
Params{Op::kAdd, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>}, ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kAdd),
Params{Op::kSubtract, ast_mat2x3<f32>, ast_mat2x3<f32>, ParamsFor<mat2x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kSubtract),
sem_mat2x3<sem_f32>}, ParamsFor<mat3x2<f32>, mat3x2<f32>, mat3x2<f32>>(Op::kSubtract),
Params{Op::kSubtract, ast_mat3x2<f32>, ast_mat3x2<f32>, ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kSubtract),
sem_mat3x2<sem_f32>},
Params{Op::kSubtract, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},
// Comparison expressions // Comparison expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
// Comparisons over scalars // Comparisons over scalars
Params{Op::kEqual, ast_bool, ast_bool, sem_bool}, ParamsFor<bool, bool, bool>(Op::kEqual),
Params{Op::kNotEqual, ast_bool, ast_bool, sem_bool}, ParamsFor<bool, bool, bool>(Op::kNotEqual),
Params{Op::kEqual, ast_i32, ast_i32, sem_bool}, ParamsFor<i32, i32, bool>(Op::kEqual),
Params{Op::kNotEqual, ast_i32, ast_i32, sem_bool}, ParamsFor<i32, i32, bool>(Op::kNotEqual),
Params{Op::kLessThan, ast_i32, ast_i32, sem_bool}, ParamsFor<i32, i32, bool>(Op::kLessThan),
Params{Op::kLessThanEqual, ast_i32, ast_i32, sem_bool}, ParamsFor<i32, i32, bool>(Op::kLessThanEqual),
Params{Op::kGreaterThan, ast_i32, ast_i32, sem_bool}, ParamsFor<i32, i32, bool>(Op::kGreaterThan),
Params{Op::kGreaterThanEqual, ast_i32, ast_i32, sem_bool}, ParamsFor<i32, i32, bool>(Op::kGreaterThanEqual),
Params{Op::kEqual, ast_u32, ast_u32, sem_bool}, ParamsFor<u32, u32, bool>(Op::kEqual),
Params{Op::kNotEqual, ast_u32, ast_u32, sem_bool}, ParamsFor<u32, u32, bool>(Op::kNotEqual),
Params{Op::kLessThan, ast_u32, ast_u32, sem_bool}, ParamsFor<u32, u32, bool>(Op::kLessThan),
Params{Op::kLessThanEqual, ast_u32, ast_u32, sem_bool}, ParamsFor<u32, u32, bool>(Op::kLessThanEqual),
Params{Op::kGreaterThan, ast_u32, ast_u32, sem_bool}, ParamsFor<u32, u32, bool>(Op::kGreaterThan),
Params{Op::kGreaterThanEqual, ast_u32, ast_u32, sem_bool}, ParamsFor<u32, u32, bool>(Op::kGreaterThanEqual),
Params{Op::kEqual, ast_f32, ast_f32, sem_bool}, ParamsFor<f32, f32, bool>(Op::kEqual),
Params{Op::kNotEqual, ast_f32, ast_f32, sem_bool}, ParamsFor<f32, f32, bool>(Op::kNotEqual),
Params{Op::kLessThan, ast_f32, ast_f32, sem_bool}, ParamsFor<f32, f32, bool>(Op::kLessThan),
Params{Op::kLessThanEqual, ast_f32, ast_f32, sem_bool}, ParamsFor<f32, f32, bool>(Op::kLessThanEqual),
Params{Op::kGreaterThan, ast_f32, ast_f32, sem_bool}, ParamsFor<f32, f32, bool>(Op::kGreaterThan),
Params{Op::kGreaterThanEqual, ast_f32, ast_f32, sem_bool}, ParamsFor<f32, f32, bool>(Op::kGreaterThanEqual),
// Comparisons over vectors // Comparisons over vectors
Params{Op::kEqual, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>}, ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kEqual),
Params{Op::kNotEqual, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>}, ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kNotEqual),
Params{Op::kEqual, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kEqual),
Params{Op::kNotEqual, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kNotEqual),
Params{Op::kLessThan, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kLessThan),
Params{Op::kLessThanEqual, ast_vec3<i32>, ast_vec3<i32>, ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kLessThanEqual),
sem_vec3<sem_bool>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kGreaterThan),
Params{Op::kGreaterThan, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kGreaterThanEqual),
Params{Op::kGreaterThanEqual, ast_vec3<i32>, ast_vec3<i32>,
sem_vec3<sem_bool>},
Params{Op::kEqual, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kEqual),
Params{Op::kNotEqual, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kNotEqual),
Params{Op::kLessThan, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kLessThan),
Params{Op::kLessThanEqual, ast_vec3<u32>, ast_vec3<u32>, ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kLessThanEqual),
sem_vec3<sem_bool>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kGreaterThan),
Params{Op::kGreaterThan, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kGreaterThanEqual),
Params{Op::kGreaterThanEqual, ast_vec3<u32>, ast_vec3<u32>,
sem_vec3<sem_bool>},
Params{Op::kEqual, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kEqual),
Params{Op::kNotEqual, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kNotEqual),
Params{Op::kLessThan, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kLessThan),
Params{Op::kLessThanEqual, ast_vec3<f32>, ast_vec3<f32>, ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kLessThanEqual),
sem_vec3<sem_bool>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kGreaterThan),
Params{Op::kGreaterThan, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>}, ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kGreaterThanEqual),
Params{Op::kGreaterThanEqual, ast_vec3<f32>, ast_vec3<f32>,
sem_vec3<sem_bool>},
// Binary bitwise operations // Binary bitwise operations
Params{Op::kOr, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kOr),
Params{Op::kAnd, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kAnd),
Params{Op::kXor, ast_i32, ast_i32, sem_i32}, ParamsFor<i32, i32, i32>(Op::kXor),
Params{Op::kOr, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kOr),
Params{Op::kAnd, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kAnd),
Params{Op::kXor, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kXor),
Params{Op::kOr, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kOr),
Params{Op::kAnd, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kAnd),
Params{Op::kXor, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kXor),
Params{Op::kOr, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kOr),
Params{Op::kAnd, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kAnd),
Params{Op::kXor, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kXor),
// Bit shift expressions // Bit shift expressions
Params{Op::kShiftLeft, ast_i32, ast_u32, sem_i32}, ParamsFor<i32, u32, i32>(Op::kShiftLeft),
Params{Op::kShiftLeft, ast_vec3<i32>, ast_vec3<u32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<u32>, vec3<i32>>(Op::kShiftLeft),
Params{Op::kShiftLeft, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kShiftLeft),
Params{Op::kShiftLeft, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kShiftLeft),
Params{Op::kShiftRight, ast_i32, ast_u32, sem_i32}, ParamsFor<i32, u32, i32>(Op::kShiftRight),
Params{Op::kShiftRight, ast_vec3<i32>, ast_vec3<u32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>, vec3<u32>, vec3<i32>>(Op::kShiftRight),
Params{Op::kShiftRight, ast_u32, ast_u32, sem_u32}, ParamsFor<u32, u32, u32>(Op::kShiftRight),
Params{Op::kShiftRight, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}}; ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kShiftRight),
};
using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>; using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
TEST_P(Expr_Binary_Test_Valid, All) { TEST_P(Expr_Binary_Test_Valid, All) {
auto& params = GetParam(); auto& params = GetParam();
auto* lhs_type = params.create_lhs_type(ty); auto* lhs_type = params.create_lhs_type(*this);
auto* rhs_type = params.create_rhs_type(ty); auto* rhs_type = params.create_rhs_type(*this);
auto* result_type = params.create_result_type(ty); auto* result_type = params.create_result_type(*this);
std::stringstream ss; std::stringstream ss;
ss << FriendlyName(lhs_type) << " " << params.op << " " ss << FriendlyName(lhs_type) << " " << params.op << " "
@ -1503,38 +1559,22 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
const Params& params = std::get<0>(GetParam()); const Params& params = std::get<0>(GetParam());
BinaryExprSide side = std::get<1>(GetParam()); BinaryExprSide side = std::get<1>(GetParam());
auto* lhs_type = params.create_lhs_type(ty); auto* create_lhs_type =
auto* rhs_type = params.create_rhs_type(ty); (side == BinaryExprSide::Left || side == BinaryExprSide::Both)
? params.create_lhs_alias_type
: params.create_lhs_type;
auto* create_rhs_type =
(side == BinaryExprSide::Right || side == BinaryExprSide::Both)
? params.create_rhs_alias_type
: params.create_rhs_type;
auto* lhs_type = create_lhs_type(*this);
auto* rhs_type = create_rhs_type(*this);
std::stringstream ss; std::stringstream ss;
ss << FriendlyName(lhs_type) << " " << params.op << " " ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type); << FriendlyName(rhs_type);
// For vectors and matrices, wrap the sub type in an alias
auto make_alias = [this](ast::Type* type) -> ast::Type* {
if (auto* v = type->As<ast::Vector>()) {
auto* alias = ty.alias(Symbols().New(), v->type());
AST().AddConstructedType(alias);
return ty.vec(alias, v->size());
}
if (auto* m = type->As<ast::Matrix>()) {
auto* alias = ty.alias(Symbols().New(), m->type());
AST().AddConstructedType(alias);
return ty.mat(alias, m->columns(), m->rows());
}
auto* alias = ty.alias(Symbols().New(), type);
AST().AddConstructedType(alias);
return ty.type_name(alias->name());
};
// Wrap in alias
if (side == BinaryExprSide::Left || side == BinaryExprSide::Both) {
lhs_type = make_alias(lhs_type);
}
if (side == BinaryExprSide::Right || side == BinaryExprSide::Both) {
rhs_type = make_alias(rhs_type);
}
ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op
<< " " << FriendlyName(rhs_type); << " " << FriendlyName(rhs_type);
SCOPED_TRACE(ss.str()); SCOPED_TRACE(ss.str());
@ -1550,7 +1590,7 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
ASSERT_NE(TypeOf(expr), nullptr); ASSERT_NE(TypeOf(expr), nullptr);
// TODO(amaiorano): Bring this back once we have a way to get the canonical // TODO(amaiorano): Bring this back once we have a way to get the canonical
// type // type
// auto* *result_type = params.create_result_type(ty); // auto* *result_type = params.create_result_type(*this);
// ASSERT_TRUE(TypeOf(expr) == result_type); // ASSERT_TRUE(TypeOf(expr) == result_type);
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
@ -1565,13 +1605,13 @@ INSTANTIATE_TEST_SUITE_P(
// (type * type * op), and processing only the triplets that are not found in // (type * type * op), and processing only the triplets that are not found in
// the `all_valid_cases` table. // the `all_valid_cases` table.
using Expr_Binary_Test_Invalid = using Expr_Binary_Test_Invalid =
ResolverTestWithParam<std::tuple<create_ast_type_func_ptr, ResolverTestWithParam<std::tuple<builder::ast_type_func_ptr,
create_ast_type_func_ptr, builder::ast_type_func_ptr,
ast::BinaryOp>>; ast::BinaryOp>>;
TEST_P(Expr_Binary_Test_Invalid, All) { TEST_P(Expr_Binary_Test_Invalid, All) {
const create_ast_type_func_ptr& lhs_create_type_func = const builder::ast_type_func_ptr& lhs_create_type_func =
std::get<0>(GetParam()); std::get<0>(GetParam());
const create_ast_type_func_ptr& rhs_create_type_func = const builder::ast_type_func_ptr& rhs_create_type_func =
std::get<1>(GetParam()); std::get<1>(GetParam());
const ast::BinaryOp op = std::get<2>(GetParam()); const ast::BinaryOp op = std::get<2>(GetParam());
@ -1584,8 +1624,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
} }
} }
auto* lhs_type = lhs_create_type_func(ty); auto* lhs_type = lhs_create_type_func(*this);
auto* rhs_type = rhs_create_type_func(ty); auto* rhs_type = rhs_create_type_func(*this);
std::stringstream ss; std::stringstream ss;
ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type); ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type);

View File

@ -120,172 +120,271 @@ template <typename T>
class ResolverTestWithParam : public TestHelper, class ResolverTestWithParam : public TestHelper,
public testing::TestWithParam<T> {}; public testing::TestWithParam<T> {};
inline ast::Type* ast_bool(const ProgramBuilder::TypesBuilder& ty) { namespace builder {
return ty.bool_();
}
inline ast::Type* ast_i32(const ProgramBuilder::TypesBuilder& ty) {
return ty.i32();
}
inline ast::Type* ast_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.u32();
}
inline ast::Type* ast_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.f32();
}
using create_ast_type_func_ptr = using i32 = ProgramBuilder::i32;
ast::Type* (*)(const ProgramBuilder::TypesBuilder& ty); using u32 = ProgramBuilder::u32;
using f32 = ProgramBuilder::f32;
template <int N, typename T>
struct vec {};
template <typename T> template <typename T>
ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) { using vec2 = vec<2, T>;
return ty.vec2<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec2(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) { using vec3 = vec<3, T>;
return ty.vec3<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) { using vec4 = vec<4, T>;
return ty.vec4<T>();
}
template <create_ast_type_func_ptr create_type> template <int N, int M, typename T>
ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) { struct mat {};
return ty.vec4(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) { using mat2x2 = mat<2, 2, T>;
return ty.mat2x2<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x2(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) { using mat2x3 = mat<2, 3, T>;
return ty.mat2x3<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x3(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) { using mat3x2 = mat<3, 2, T>;
return ty.mat3x2<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x2(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) { using mat3x3 = mat<3, 3, T>;
return ty.mat3x3<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) { using mat4x4 = mat<4, 4, T>;
return ty.mat4x4<T>();
}
template <create_ast_type_func_ptr create_type> template <typename TO, int ID = 0>
ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) { struct alias {};
return ty.mat4x4(create_type(ty));
}
template <create_ast_type_func_ptr create_type> template <typename TO>
ast::Type* ast_alias(const ProgramBuilder::TypesBuilder& ty) { using alias1 = alias<TO, 1>;
auto* type = create_type(ty);
auto name = ty.builder->Symbols().Register("alias_" + type->type_name()); template <typename TO>
if (!ty.builder->AST().LookupType(name)) { using alias2 = alias<TO, 2>;
ty.builder->AST().AddConstructedType(ty.alias(name, type));
template <typename TO>
using alias3 = alias<TO, 3>;
using ast_type_func_ptr = ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = ast::Expression* (*)(ProgramBuilder& b,
int elem_value);
using sem_type_func_ptr = sem::Type* (*)(ProgramBuilder& b);
template <typename T>
struct DataType {};
/// Helper for building bool types and expressions
template <>
struct DataType<bool> {
/// false as bool is not a composite type
static constexpr bool is_composite = false;
/// @param b the ProgramBuilder
/// @return a new AST bool type
static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); }
/// @param b the ProgramBuilder
/// @return the semantic bool type
static inline sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Bool>();
} }
return ty.builder->create<ast::TypeName>(name); /// @param b the ProgramBuilder
} /// @param elem_value the b
/// @return a new AST expression of the bool type
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(elem_value == 0);
}
};
inline sem::Type* sem_bool(const ProgramBuilder::TypesBuilder& ty) { /// Helper for building i32 types and expressions
return ty.builder->create<sem::Bool>(); template <>
} struct DataType<i32> {
inline sem::Type* sem_i32(const ProgramBuilder::TypesBuilder& ty) { /// false as i32 is not a composite type
return ty.builder->create<sem::I32>(); static constexpr bool is_composite = false;
}
inline sem::Type* sem_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::U32>();
}
inline sem::Type* sem_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::F32>();
}
using create_sem_type_func_ptr = /// @param b the ProgramBuilder
sem::Type* (*)(const ProgramBuilder::TypesBuilder& ty); /// @return a new AST i32 type
static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); }
/// @param b the ProgramBuilder
/// @return the semantic i32 type
static inline sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::I32>();
}
/// @param b the ProgramBuilder
/// @param elem_value the value i32 will be initialized with
/// @return a new AST i32 literal value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(static_cast<i32>(elem_value));
}
};
template <create_sem_type_func_ptr create_type> /// Helper for building u32 types and expressions
sem::Type* sem_vec2(const ProgramBuilder::TypesBuilder& ty) { template <>
return ty.builder->create<sem::Vector>(create_type(ty), 2); struct DataType<u32> {
} /// false as u32 is not a composite type
static constexpr bool is_composite = false;
template <create_sem_type_func_ptr create_type> /// @param b the ProgramBuilder
sem::Type* sem_vec3(const ProgramBuilder::TypesBuilder& ty) { /// @return a new AST u32 type
return ty.builder->create<sem::Vector>(create_type(ty), 3); static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); }
} /// @param b the ProgramBuilder
/// @return the semantic u32 type
static inline sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::U32>();
}
/// @param b the ProgramBuilder
/// @param elem_value the value u32 will be initialized with
/// @return a new AST u32 literal value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(static_cast<u32>(elem_value));
}
};
template <create_sem_type_func_ptr create_type> /// Helper for building f32 types and expressions
sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) { template <>
return ty.builder->create<sem::Vector>(create_type(ty), 4); struct DataType<f32> {
} /// false as f32 is not a composite type
static constexpr bool is_composite = false;
template <create_sem_type_func_ptr create_type> /// @param b the ProgramBuilder
sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) { /// @return a new AST f32 type
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u); static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); }
return ty.builder->create<sem::Matrix>(column_type, 2u); /// @param b the ProgramBuilder
} /// @return the semantic f32 type
static inline sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::F32>();
}
/// @param b the ProgramBuilder
/// @param elem_value the value f32 will be initialized with
/// @return a new AST f32 literal value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(static_cast<f32>(elem_value));
}
};
template <create_sem_type_func_ptr create_type> /// Helper for building vector types and expressions
sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) { template <int N, typename T>
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u); struct DataType<vec<N, T>> {
return ty.builder->create<sem::Matrix>(column_type, 2u); /// true as vectors are a composite type
} static constexpr bool is_composite = true;
template <create_sem_type_func_ptr create_type> /// @param b the ProgramBuilder
sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) { /// @return a new AST vector type
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u); static inline ast::Type* AST(ProgramBuilder& b) {
return ty.builder->create<sem::Matrix>(column_type, 3u); return b.ty.vec(DataType<T>::AST(b), N);
} }
/// @param b the ProgramBuilder
/// @return the semantic vector type
static inline sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Vector>(DataType<T>::Sem(b), N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the vector will be initialized
/// with
/// @return a new AST vector value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
template <create_sem_type_func_ptr create_type> /// @param b the ProgramBuilder
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) { /// @param elem_value the value each element will be initialized with
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u); /// @return the list of expressions that are used to construct the vector
return ty.builder->create<sem::Matrix>(column_type, 3u); static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
} int elem_value) {
ast::ExpressionList args;
for (int i = 0; i < N; i++) {
args.emplace_back(DataType<T>::Expr(b, elem_value));
}
return args;
}
};
template <create_sem_type_func_ptr create_type> /// Helper for building matrix types and expressions
sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) { template <int N, int M, typename T>
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 4u); struct DataType<mat<N, M, T>> {
return ty.builder->create<sem::Matrix>(column_type, 4u); /// true as matrices are a composite type
} static constexpr bool is_composite = true;
/// @param b the ProgramBuilder
/// @return a new AST matrix type
static inline ast::Type* AST(ProgramBuilder& b) {
return b.ty.mat(DataType<T>::AST(b), N, M);
}
/// @param b the ProgramBuilder
/// @return the semantic matrix type
static inline sem::Type* Sem(ProgramBuilder& b) {
auto* column_type = b.create<sem::Vector>(DataType<T>::Sem(b), M);
return b.create<sem::Matrix>(column_type, N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the matrix will be initialized
/// with
/// @return a new AST matrix value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @return the list of expressions that are used to construct the matrix
static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
int elem_value) {
ast::ExpressionList args;
for (int i = 0; i < N; i++) {
args.emplace_back(DataType<vec<M, T>>::Expr(b, elem_value));
}
return args;
}
};
/// Helper for building alias types and expressions
template <typename T, int ID>
struct DataType<alias<T, ID>> {
/// true if the aliased type is a composite type
static constexpr bool is_composite = DataType<T>::is_composite;
/// @param b the ProgramBuilder
/// @return a new AST alias type
static inline ast::Type* AST(ProgramBuilder& b) {
auto name = b.Symbols().Register("alias_" + std::to_string(ID));
if (!b.AST().LookupType(name)) {
auto* type = DataType<T>::AST(b);
b.AST().AddConstructedType(b.ty.alias(name, type));
}
return b.create<ast::TypeName>(name);
}
/// @param b the ProgramBuilder
/// @return the semantic aliased type
static inline sem::Type* Sem(ProgramBuilder& b) {
return DataType<T>::Sem(b);
}
/// @param b the ProgramBuilder
/// @param elem_value the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<!IS_COMPOSITE, ast::Expression*> Expr(
ProgramBuilder& b,
int elem_value) {
// Cast
return b.Construct(AST(b), DataType<T>::Expr(b, elem_value));
}
/// @param b the ProgramBuilder
/// @param elem_value the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<IS_COMPOSITE, ast::Expression*> Expr(
ProgramBuilder& b,
int elem_value) {
// Construct
return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value));
}
};
} // namespace builder
} // namespace resolver } // namespace resolver
} // namespace tint } // namespace tint

View File

@ -19,30 +19,47 @@ namespace tint {
namespace resolver { namespace resolver {
namespace { namespace {
/// @return the element type of `type` for vec and mat, otherwise `type` itself // Helpers and typedefs
ast::Type* ElementTypeOf(ast::Type* type) { template <typename T>
if (auto* v = type->As<ast::Vector>()) { using DataType = builder::DataType<T>;
return v->type(); template <typename T>
} using vec2 = builder::vec2<T>;
if (auto* m = type->As<ast::Matrix>()) { template <typename T>
return m->type(); using vec3 = builder::vec3<T>;
} template <typename T>
return type; using vec4 = builder::vec4<T>;
} template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T>
using alias = builder::alias<T>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
class ResolverTypeConstructorValidationTest : public resolver::TestHelper, class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
public testing::Test {}; public testing::Test {};
namespace InferTypeTest { namespace InferTypeTest {
struct Params { struct Params {
create_ast_type_func_ptr create_rhs_ast_type; builder::ast_type_func_ptr create_rhs_ast_type;
create_sem_type_func_ptr create_rhs_sem_type; builder::ast_expr_func_ptr create_rhs_ast_value;
builder::sem_type_func_ptr create_rhs_sem_type;
}; };
// Helpers and typedefs template <typename T>
using i32 = ProgramBuilder::i32; constexpr Params ParamsFor() {
using u32 = ProgramBuilder::u32; return Params{DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
using f32 = ProgramBuilder::f32; }
TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) { TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) {
// var a = 1; // var a = 1;
@ -75,8 +92,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
// } // }
auto& params = GetParam(); auto& params = GetParam();
auto* rhs_type = params.create_rhs_ast_type(ty); auto* constructor_expr = params.create_rhs_ast_value(*this, 0);
auto* constructor_expr = ConstructValueFilledWith(rhs_type, 0);
auto* a = Var("a", nullptr, ast::StorageClass::kNone, constructor_expr); auto* a = Var("a", nullptr, ast::StorageClass::kNone, constructor_expr);
// Self-assign 'a' to force the expression to be resolved so we can test its // Self-assign 'a' to force the expression to be resolved so we can test its
@ -86,7 +102,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(a_ident); auto* got = TypeOf(a_ident);
auto* expected = create<sem::Reference>(params.create_rhs_sem_type(ty), auto* expected = create<sem::Reference>(params.create_rhs_sem_type(*this),
ast::StorageClass::kFunction, ast::StorageClass::kFunction,
ast::Access::kReadWrite); ast::Access::kReadWrite);
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
@ -94,26 +110,26 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
} }
static constexpr Params from_constructor_expression_cases[] = { static constexpr Params from_constructor_expression_cases[] = {
Params{ast_bool, sem_bool}, ParamsFor<bool>(),
Params{ast_i32, sem_i32}, ParamsFor<i32>(),
Params{ast_u32, sem_u32}, ParamsFor<u32>(),
Params{ast_f32, sem_f32}, ParamsFor<f32>(),
Params{ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>>(),
Params{ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>>(),
Params{ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>>(),
Params{ast_mat3x3<i32>, sem_mat3x3<sem_i32>}, ParamsFor<mat3x3<i32>>(),
Params{ast_mat3x3<u32>, sem_mat3x3<sem_u32>}, ParamsFor<mat3x3<u32>>(),
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>}, ParamsFor<mat3x3<f32>>(),
Params{ast_alias<ast_bool>, sem_bool}, ParamsFor<alias<bool>>(),
Params{ast_alias<ast_i32>, sem_i32}, ParamsFor<alias<i32>>(),
Params{ast_alias<ast_u32>, sem_u32}, ParamsFor<alias<u32>>(),
Params{ast_alias<ast_f32>, sem_f32}, ParamsFor<alias<f32>>(),
Params{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>}, ParamsFor<alias<vec3<i32>>>(),
Params{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>}, ParamsFor<alias<vec3<u32>>>(),
Params{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>}, ParamsFor<alias<vec3<f32>>>(),
Params{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>}, ParamsFor<alias<mat3x3<i32>>>(),
Params{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>}, ParamsFor<alias<mat3x3<u32>>>(),
Params{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>}, ParamsFor<alias<mat3x3<f32>>>(),
}; };
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromConstructorExpression, InferTypeTest_FromConstructorExpression,
@ -127,13 +143,11 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) {
// } // }
auto& params = GetParam(); auto& params = GetParam();
auto* rhs_type = params.create_rhs_ast_type(ty); auto* arith_lhs_expr = params.create_rhs_ast_value(*this, 2);
auto* arith_rhs_expr = params.create_rhs_ast_value(*this, 3);
auto* arith_lhs_expr = ConstructValueFilledWith(rhs_type, 2);
auto* arith_rhs_expr = ConstructValueFilledWith(ElementTypeOf(rhs_type), 3);
auto* constructor_expr = Mul(arith_lhs_expr, arith_rhs_expr); auto* constructor_expr = Mul(arith_lhs_expr, arith_rhs_expr);
auto* a = Var("a", nullptr, ast::StorageClass::kNone, constructor_expr); auto* a = Var("a", nullptr, constructor_expr);
// Self-assign 'a' to force the expression to be resolved so we can test its // Self-assign 'a' to force the expression to be resolved so we can test its
// type below // type below
auto* a_ident = Expr("a"); auto* a_ident = Expr("a");
@ -141,25 +155,22 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) {
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(a_ident); auto* got = TypeOf(a_ident);
auto* expected = create<sem::Reference>(params.create_rhs_sem_type(ty), auto* expected = create<sem::Reference>(params.create_rhs_sem_type(*this),
ast::StorageClass::kFunction, ast::StorageClass::kFunction,
ast::Access::kReadWrite); ast::Access::kReadWrite);
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n"; << "expected: " << FriendlyName(expected) << "\n";
} }
static constexpr Params from_arithmetic_expression_cases[] = { static constexpr Params from_arithmetic_expression_cases[] = {
Params{ast_i32, sem_i32}, ParamsFor<i32>(), ParamsFor<u32>(), ParamsFor<f32>(),
Params{ast_u32, sem_u32}, ParamsFor<vec3<f32>>(), ParamsFor<mat3x3<f32>>(),
Params{ast_f32, sem_f32},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
// TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed // TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed
// Params{ty_alias<ty_i32>}, // ParamsFor<alias<ty_i32>>(),
// Params{ty_alias<ty_u32>}, // ParamsFor<alias<ty_u32>>(),
// Params{ty_alias<ty_f32>}, // ParamsFor<alias<ty_f32>>(),
// Params{ty_alias<ty_vec3<f32>>}, // ParamsFor<alias<ty_vec3<f32>>>(),
// Params{ty_alias<ty_mat3x3<f32>>}, // ParamsFor<alias<ty_mat3x3<f32>>>(),
}; };
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromArithmeticExpression, InferTypeTest_FromArithmeticExpression,
@ -170,7 +181,7 @@ TEST_P(InferTypeTest_FromCallExpression, All) {
// e.g. for vec3<f32> // e.g. for vec3<f32>
// //
// fn foo() -> vec3<f32> { // fn foo() -> vec3<f32> {
// return vec3<f32>(0.0, 0.0, 0.0); // return vec3<f32>();
// } // }
// //
// fn bar() // fn bar()
@ -179,11 +190,10 @@ TEST_P(InferTypeTest_FromCallExpression, All) {
// } // }
auto& params = GetParam(); auto& params = GetParam();
Func("foo", {}, params.create_rhs_ast_type(ty), Func("foo", {}, params.create_rhs_ast_type(*this),
{Return(ConstructValueFilledWith(params.create_rhs_ast_type(ty), 0))}, {Return(Construct(params.create_rhs_ast_type(*this)))}, {});
{});
auto* a = Var("a", nullptr, ast::StorageClass::kNone, Call(Expr("foo"))); auto* a = Var("a", nullptr, Call("foo"));
// Self-assign 'a' to force the expression to be resolved so we can test its // Self-assign 'a' to force the expression to be resolved so we can test its
// type below // type below
auto* a_ident = Expr("a"); auto* a_ident = Expr("a");
@ -191,33 +201,33 @@ TEST_P(InferTypeTest_FromCallExpression, All) {
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(a_ident); auto* got = TypeOf(a_ident);
auto* expected = create<sem::Reference>(params.create_rhs_sem_type(ty), auto* expected = create<sem::Reference>(params.create_rhs_sem_type(*this),
ast::StorageClass::kFunction, ast::StorageClass::kFunction,
ast::Access::kReadWrite); ast::Access::kReadWrite);
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n"; << "expected: " << FriendlyName(expected) << "\n";
} }
static constexpr Params from_call_expression_cases[] = { static constexpr Params from_call_expression_cases[] = {
Params{ast_bool, sem_bool}, ParamsFor<bool>(),
Params{ast_i32, sem_i32}, ParamsFor<i32>(),
Params{ast_u32, sem_u32}, ParamsFor<u32>(),
Params{ast_f32, sem_f32}, ParamsFor<f32>(),
Params{ast_vec3<i32>, sem_vec3<sem_i32>}, ParamsFor<vec3<i32>>(),
Params{ast_vec3<u32>, sem_vec3<sem_u32>}, ParamsFor<vec3<u32>>(),
Params{ast_vec3<f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>>(),
Params{ast_mat3x3<i32>, sem_mat3x3<sem_i32>}, ParamsFor<mat3x3<i32>>(),
Params{ast_mat3x3<u32>, sem_mat3x3<sem_u32>}, ParamsFor<mat3x3<u32>>(),
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>}, ParamsFor<mat3x3<f32>>(),
Params{ast_alias<ast_bool>, sem_bool}, ParamsFor<alias<bool>>(),
Params{ast_alias<ast_i32>, sem_i32}, ParamsFor<alias<i32>>(),
Params{ast_alias<ast_u32>, sem_u32}, ParamsFor<alias<u32>>(),
Params{ast_alias<ast_f32>, sem_f32}, ParamsFor<alias<f32>>(),
Params{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>}, ParamsFor<alias<vec3<i32>>>(),
Params{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>}, ParamsFor<alias<vec3<u32>>>(),
Params{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>}, ParamsFor<alias<vec3<f32>>>(),
Params{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>}, ParamsFor<alias<mat3x3<i32>>>(),
Params{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>}, ParamsFor<alias<mat3x3<u32>>>(),
Params{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>}, ParamsFor<alias<mat3x3<f32>>>(),
}; };
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromCallExpression, InferTypeTest_FromCallExpression,

View File

@ -27,6 +27,33 @@ namespace tint {
namespace resolver { namespace resolver {
namespace { namespace {
// Helpers and typedefs
template <typename T>
using DataType = builder::DataType<T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T>
using alias = builder::alias<T>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
class ResolverTypeValidationTest : public resolver::TestHelper, class ResolverTypeValidationTest : public resolver::TestHelper,
public testing::Test {}; public testing::Test {};
@ -366,43 +393,44 @@ TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) {
namespace GetCanonicalTests { namespace GetCanonicalTests {
struct Params { struct Params {
create_ast_type_func_ptr create_ast_type; builder::ast_type_func_ptr create_ast_type;
create_sem_type_func_ptr create_sem_type; builder::sem_type_func_ptr create_sem_type;
}; };
template <typename T>
constexpr Params ParamsFor() {
return Params{DataType<T>::AST, DataType<T>::Sem};
}
static constexpr Params cases[] = { static constexpr Params cases[] = {
Params{ast_bool, sem_bool}, ParamsFor<bool>(),
Params{ast_alias<ast_bool>, sem_bool}, ParamsFor<alias<bool>>(),
Params{ast_alias<ast_alias<ast_bool>>, sem_bool}, ParamsFor<alias1<alias<bool>>>(),
Params{ast_vec3<ast_f32>, sem_vec3<sem_f32>}, ParamsFor<vec3<f32>>(),
Params{ast_alias<ast_vec3<ast_f32>>, sem_vec3<sem_f32>}, ParamsFor<alias<vec3<f32>>>(),
Params{ast_alias<ast_alias<ast_vec3<ast_f32>>>, sem_vec3<sem_f32>}, ParamsFor<alias1<alias<vec3<f32>>>>(),
Params{ast_vec3<ast_alias<ast_f32>>, sem_vec3<sem_f32>}, ParamsFor<vec3<alias<f32>>>(),
Params{ast_alias<ast_vec3<ast_alias<ast_f32>>>, sem_vec3<sem_f32>}, ParamsFor<alias1<vec3<alias<f32>>>>(),
Params{ast_alias<ast_alias<ast_vec3<ast_alias<ast_f32>>>>, ParamsFor<alias2<alias1<vec3<alias<f32>>>>>(),
sem_vec3<sem_f32>}, ParamsFor<alias3<alias2<vec3<alias1<alias<f32>>>>>>(),
Params{ast_alias<ast_alias<ast_vec3<ast_alias<ast_alias<ast_f32>>>>>,
sem_vec3<sem_f32>},
Params{ast_mat3x3<ast_alias<ast_f32>>, sem_mat3x3<sem_f32>}, ParamsFor<mat3x3<alias<f32>>>(),
Params{ast_alias<ast_mat3x3<ast_alias<ast_f32>>>, sem_mat3x3<sem_f32>}, ParamsFor<alias1<mat3x3<alias<f32>>>>(),
Params{ast_alias<ast_alias<ast_mat3x3<ast_alias<ast_f32>>>>, ParamsFor<alias2<alias1<mat3x3<alias<f32>>>>>(),
sem_mat3x3<sem_f32>}, ParamsFor<alias3<alias2<mat3x3<alias1<alias<f32>>>>>>(),
Params{ast_alias<ast_alias<ast_mat3x3<ast_alias<ast_alias<ast_f32>>>>>,
sem_mat3x3<sem_f32>},
Params{ast_alias<ast_alias<ast_bool>>, sem_bool}, ParamsFor<alias1<alias<bool>>>(),
Params{ast_alias<ast_alias<ast_vec3<ast_f32>>>, sem_vec3<sem_f32>}, ParamsFor<alias1<alias<vec3<f32>>>>(),
Params{ast_alias<ast_alias<ast_mat3x3<ast_f32>>>, sem_mat3x3<sem_f32>}, ParamsFor<alias1<alias<mat3x3<f32>>>>(),
}; };
using CanonicalTest = ResolverTestWithParam<Params>; using CanonicalTest = ResolverTestWithParam<Params>;
TEST_P(CanonicalTest, All) { TEST_P(CanonicalTest, All) {
auto& params = GetParam(); auto& params = GetParam();
auto* type = params.create_ast_type(ty); auto* type = params.create_ast_type(*this);
auto* var = Var("v", type); auto* var = Var("v", type);
auto* expr = Expr("v"); auto* expr = Expr("v");
@ -411,7 +439,7 @@ TEST_P(CanonicalTest, All) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(expr)->UnwrapRef(); auto* got = TypeOf(expr)->UnwrapRef();
auto* expected = params.create_sem_type(ty); auto* expected = params.create_sem_type(*this);
EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n"; << "expected: " << FriendlyName(expected) << "\n";
@ -459,38 +487,44 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
testing::ValuesIn(dimension_cases)); testing::ValuesIn(dimension_cases));
struct TypeParams { struct TypeParams {
create_ast_type_func_ptr type_func; builder::ast_type_func_ptr type_func;
bool is_valid; bool is_valid;
}; };
template <typename T>
constexpr TypeParams TypeParamsFor(bool is_valid) {
return TypeParams{DataType<T>::AST, is_valid};
}
static constexpr TypeParams type_cases[] = { static constexpr TypeParams type_cases[] = {
TypeParams{ast_bool, false}, TypeParamsFor<bool>(false),
TypeParams{ast_i32, true}, TypeParamsFor<i32>(true),
TypeParams{ast_u32, true}, TypeParamsFor<u32>(true),
TypeParams{ast_f32, true}, TypeParamsFor<f32>(true),
TypeParams{ast_alias<ast_bool>, false}, TypeParamsFor<alias<bool>>(false),
TypeParams{ast_alias<ast_i32>, true}, TypeParamsFor<alias<i32>>(true),
TypeParams{ast_alias<ast_u32>, true}, TypeParamsFor<alias<u32>>(true),
TypeParams{ast_alias<ast_f32>, true}, TypeParamsFor<alias<f32>>(true),
TypeParams{ast_vec3<ast_f32>, false}, TypeParamsFor<vec3<f32>>(false),
TypeParams{ast_mat3x3<ast_f32>, false}, TypeParamsFor<mat3x3<f32>>(false),
TypeParams{ast_alias<ast_vec3<ast_f32>>, false}, TypeParamsFor<alias<vec3<f32>>>(false),
TypeParams{ast_alias<ast_mat3x3<ast_f32>>, false}}; TypeParamsFor<alias<mat3x3<f32>>>(false),
};
using MultisampledTextureTypeTest = ResolverTestWithParam<TypeParams>; using MultisampledTextureTypeTest = ResolverTestWithParam<TypeParams>;
TEST_P(MultisampledTextureTypeTest, All) { TEST_P(MultisampledTextureTypeTest, All) {
auto& params = GetParam(); auto& params = GetParam();
Global( Global(Source{{12, 34}}, "a",
Source{{12, 34}}, "a", ty.multisampled_texture(ast::TextureDimension::k2d,
ty.multisampled_texture(ast::TextureDimension::k2d, params.type_func(ty)), params.type_func(*this)),
ast::StorageClass::kNone, nullptr, ast::StorageClass::kNone, nullptr,
ast::DecorationList{ ast::DecorationList{
create<ast::BindingDecoration>(0), create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0), create<ast::GroupDecoration>(0),
}); });
if (params.is_valid) { if (params.is_valid) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();

View File

@ -378,10 +378,11 @@ TEST_F(BuilderTest, MemberAccessor_Nested_NonPointer) {
} }
TEST_F(BuilderTest, MemberAccessor_Nested_WithAlias) { TEST_F(BuilderTest, MemberAccessor_Nested_WithAlias) {
// type Inner = struct { // struct Inner {
// a : f32 // a : f32
// b : f32 // b : f32
// } // };
// type Alias = Inner;
// my_struct { // my_struct {
// inner : Inner // inner : Inner
// } // }
@ -393,7 +394,8 @@ TEST_F(BuilderTest, MemberAccessor_Nested_WithAlias) {
Member("b", ty.f32()), Member("b", ty.f32()),
}); });
auto* alias = ty.alias("Inner", inner_struct); auto* alias = ty.alias("Alias", inner_struct);
AST().AddConstructedType(alias);
auto* s_type = Structure("Outer", {Member("inner", alias)}); auto* s_type = Structure("Outer", {Member("inner", alias)});
auto* var = Var("ident", s_type); auto* var = Var("ident", s_type);

View File

@ -23,6 +23,7 @@ using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, EmitAlias_F32) { TEST_F(WgslGeneratorImplTest, EmitAlias_F32) {
auto* alias = ty.alias("a", ty.f32()); auto* alias = ty.alias("a", ty.f32());
AST().AddConstructedType(alias);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitConstructedType(alias)) << gen.error(); ASSERT_TRUE(gen.EmitConstructedType(alias)) << gen.error();
@ -37,6 +38,7 @@ TEST_F(WgslGeneratorImplTest, EmitConstructedType_Struct) {
}); });
auto* alias = ty.alias("B", s); auto* alias = ty.alias("B", s);
AST().AddConstructedType(alias);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -57,6 +59,7 @@ TEST_F(WgslGeneratorImplTest, EmitAlias_ToStruct) {
}); });
auto* alias = ty.alias("B", s); auto* alias = ty.alias("B", s);
AST().AddConstructedType(alias);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -250,6 +250,7 @@ struct S {
TEST_F(WgslGeneratorImplTest, EmitType_U32) { TEST_F(WgslGeneratorImplTest, EmitType_U32) {
auto* u32 = ty.u32(); auto* u32 = ty.u32();
AST().AddConstructedType(ty.alias("make_type_reachable", u32));
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -402,6 +403,11 @@ TEST_P(WgslGenerator_StorageTextureTest, EmitType_StorageTexture) {
auto param = GetParam(); auto param = GetParam();
auto* t = ty.storage_texture(param.dim, param.fmt, param.access); auto* t = ty.storage_texture(param.dim, param.fmt, param.access);
Global("g", t,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -412,30 +418,30 @@ INSTANTIATE_TEST_SUITE_P(
WgslGeneratorImplTest, WgslGeneratorImplTest,
WgslGenerator_StorageTextureTest, WgslGenerator_StorageTextureTest,
testing::Values( testing::Values(
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k1d, ast::Access::kRead, ast::TextureDimension::k1d, ast::Access::kRead,
"texture_storage_1d<r8unorm, read>"}, "texture_storage_1d<rgba8sint, read>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k2d, ast::Access::kRead, ast::TextureDimension::k2d, ast::Access::kRead,
"texture_storage_2d<r8unorm, read>"}, "texture_storage_2d<rgba8sint, read>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k2dArray, ast::Access::kRead, ast::TextureDimension::k2dArray, ast::Access::kRead,
"texture_storage_2d_array<r8unorm, read>"}, "texture_storage_2d_array<rgba8sint, read>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k3d, ast::Access::kRead, ast::TextureDimension::k3d, ast::Access::kRead,
"texture_storage_3d<r8unorm, read>"}, "texture_storage_3d<rgba8sint, read>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k1d, ast::Access::kWrite, ast::TextureDimension::k1d, ast::Access::kWrite,
"texture_storage_1d<r8unorm, write>"}, "texture_storage_1d<rgba8sint, write>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k2d, ast::Access::kWrite, ast::TextureDimension::k2d, ast::Access::kWrite,
"texture_storage_2d<r8unorm, write>"}, "texture_storage_2d<rgba8sint, write>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k2dArray, ast::Access::kWrite, ast::TextureDimension::k2dArray, ast::Access::kWrite,
"texture_storage_2d_array<r8unorm, write>"}, "texture_storage_2d_array<rgba8sint, write>"},
StorageTextureData{ast::ImageFormat::kR8Unorm, StorageTextureData{ast::ImageFormat::kRgba8Sint,
ast::TextureDimension::k3d, ast::Access::kWrite, ast::TextureDimension::k3d, ast::Access::kWrite,
"texture_storage_3d<r8unorm, write>"})); "texture_storage_3d<rgba8sint, write>"}));
struct ImageFormatData { struct ImageFormatData {
ast::ImageFormat fmt; ast::ImageFormat fmt;