From c57f7257978daf17f04f5b9a4ecc1afe46fd6774 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 9 Jun 2021 07:48:17 +0000 Subject: [PATCH] resolver: Enable AST type reachability checks Required a lot of test fixes. ProgramBuilder: :ConstructValueFilledWith() was a major source of unreached AST types, and this has been removed with more powerful type-building helpers in resolver_test_helper.h. Change-Id: I1f2007cdaef7f319ab4ef8b4fb8c37687a0fb5d8 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53800 Reviewed-by: Antonio Maiorano Kokoro: Kokoro Commit-Queue: Ben Clayton --- src/program_builder.cc | 53 -- src/program_builder.h | 11 - src/resolver/decoration_validation_test.cc | 113 +++-- src/resolver/entry_point_validation_test.cc | 84 ++-- src/resolver/inferred_type_test.cc | 95 ++-- src/resolver/resolver.cc | 4 +- src/resolver/resolver_test.cc | 468 ++++++++++-------- src/resolver/resolver_test_helper.h | 369 +++++++++----- .../type_constructor_validation_test.cc | 172 ++++--- src/resolver/type_validation_test.cc | 126 +++-- .../spirv/builder_accessor_expression_test.cc | 8 +- .../wgsl/generator_impl_alias_type_test.cc | 3 + src/writer/wgsl/generator_impl_type_test.cc | 38 +- 13 files changed, 870 insertions(+), 674 deletions(-) diff --git a/src/program_builder.cc b/src/program_builder.cc index a85ba4dbc7..9a2cb09134 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -102,59 +102,6 @@ const sem::Type* ProgramBuilder::TypeOf(const ast::Type* type) const { return Sem().Get(type); } -ast::ConstructorExpression* ProgramBuilder::ConstructValueFilledWith( - const ast::Type* type, - int elem_value) { - CloneContext ctx(this); - - if (type->Is()) { - return create( - create(elem_value == 0 ? false : true)); - } - if (type->Is()) { - return create( - create(static_cast(elem_value))); - } - if (type->Is()) { - return create( - create(static_cast(elem_value))); - } - if (type->Is()) { - return create( - create(static_cast(elem_value))); - } - if (auto* v = type->As()) { - ast::ExpressionList el(v->size()); - for (size_t i = 0; i < el.size(); i++) { - el[i] = ConstructValueFilledWith(ctx.Clone(v->type()), elem_value); - } - return create(const_cast(type), - std::move(el)); - } - if (auto* m = type->As()) { - ast::ExpressionList el(m->columns()); - for (size_t i = 0; i < el.size(); i++) { - auto* col_vec_type = create(ctx.Clone(m->type()), m->rows()); - el[i] = ConstructValueFilledWith(col_vec_type, elem_value); - } - return create(const_cast(type), - std::move(el)); - } - if (auto* tn = type->As()) { - if (auto* lookup = AST().LookupType(tn->name())) { - if (auto* alias = lookup->As()) { - return ConstructValueFilledWith(ctx.Clone(alias->type()), elem_value); - } - } - TINT_ICE(diagnostics_) << "unable to find NamedType '" - << Symbols().NameFor(tn->name()) << "'"; - return nullptr; - } - - TINT_ICE(diagnostics_) << "unhandled type: " << type->TypeInfo().name; - return nullptr; -} - ast::Type* ProgramBuilder::TypesBuilder::MaybeCreateTypename( ast::Type* type) const { if (auto* nt = As(type)) { diff --git a/src/program_builder.h b/src/program_builder.h index 8a38e182f6..82814a797e 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -1078,17 +1078,6 @@ class ProgramBuilder { type, ExprList(std::forward(args)...)); } - /// Creates a constructor expression that constructs an object of - /// `type` filled with `elem_value`. For example, - /// ConstructValueFilledWith(ty.mat3x4(), 5) returns a - /// TypeConstructorExpression for a Mat3x4 filled with 5.0f values. - /// @param type the type to construct - /// @param elem_value the initial or element value (for vec and mat) to - /// construct with - /// @return the constructor expression - ast::ConstructorExpression* ConstructValueFilledWith(const ast::Type* type, - int elem_value = 0); - /// @param args the arguments for the vector constructor /// @return an `ast::TypeConstructorExpression` of a 2-element vector of type /// `T`, constructed with the values `args`. diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index 1de719b092..614fec94a3 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -26,6 +26,33 @@ namespace tint { namespace resolver { +// Helpers and typedefs +template +using DataType = builder::DataType; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat2x2 = builder::mat2x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +template +using alias1 = builder::alias1; +template +using alias2 = builder::alias2; +template +using alias3 = builder::alias3; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; + namespace DecorationTests { namespace { @@ -357,17 +384,22 @@ namespace ArrayStrideTests { namespace { struct Params { - create_ast_type_func_ptr create_el_type; + builder::ast_type_func_ptr create_el_type; uint32_t stride; bool should_pass; }; +template +constexpr Params ParamsFor(uint32_t stride, bool should_pass) { + return Params{DataType::AST, stride, should_pass}; +} + struct TestWithParams : ResolverTestWithParam {}; using ArrayStrideTest = TestWithParams; TEST_P(ArrayStrideTest, All) { auto& params = GetParam(); - auto* el_ty = params.create_el_type(ty); + auto* el_ty = params.create_el_type(*this); std::stringstream ss; ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride @@ -389,11 +421,6 @@ TEST_P(ArrayStrideTest, All) { } } -// Helpers and typedefs -using i32 = ProgramBuilder::i32; -using u32 = ProgramBuilder::u32; -using f32 = ProgramBuilder::f32; - struct SizeAndAlignment { uint32_t size; uint32_t align; @@ -414,49 +441,49 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( // Succeed because stride >= element size (while being multiple of // element alignment) - Params{ast_u32, default_u32.size, true}, - Params{ast_i32, default_i32.size, true}, - Params{ast_f32, default_f32.size, true}, - Params{ast_vec2, default_vec2.size, true}, + ParamsFor(default_u32.size, true), + ParamsFor(default_i32.size, true), + ParamsFor(default_f32.size, true), + ParamsFor>(default_vec2.size, true), // vec3's default size is not a multiple of its alignment - // Params{ast_vec3, default_vec3.size, true}, - Params{ast_vec4, default_vec4.size, true}, - Params{ast_mat2x2, default_mat2x2.size, true}, - Params{ast_mat3x3, default_mat3x3.size, true}, - Params{ast_mat4x4, default_mat4x4.size, true}, + // ParamsFor, default_vec3.size, true}, + ParamsFor>(default_vec4.size, true), + ParamsFor>(default_mat2x2.size, true), + ParamsFor>(default_mat3x3.size, true), + ParamsFor>(default_mat4x4.size, true), // Fail because stride is < element size - Params{ast_u32, default_u32.size - 1, false}, - Params{ast_i32, default_i32.size - 1, false}, - Params{ast_f32, default_f32.size - 1, false}, - Params{ast_vec2, default_vec2.size - 1, false}, - Params{ast_vec3, default_vec3.size - 1, false}, - Params{ast_vec4, default_vec4.size - 1, false}, - Params{ast_mat2x2, default_mat2x2.size - 1, false}, - Params{ast_mat3x3, default_mat3x3.size - 1, false}, - Params{ast_mat4x4, default_mat4x4.size - 1, false}, + ParamsFor(default_u32.size - 1, false), + ParamsFor(default_i32.size - 1, false), + ParamsFor(default_f32.size - 1, false), + ParamsFor>(default_vec2.size - 1, false), + ParamsFor>(default_vec3.size - 1, false), + ParamsFor>(default_vec4.size - 1, false), + ParamsFor>(default_mat2x2.size - 1, false), + ParamsFor>(default_mat3x3.size - 1, false), + ParamsFor>(default_mat4x4.size - 1, false), // Succeed because stride equals multiple of element alignment - Params{ast_u32, default_u32.align * 7, true}, - Params{ast_i32, default_i32.align * 7, true}, - Params{ast_f32, default_f32.align * 7, true}, - Params{ast_vec2, default_vec2.align * 7, true}, - Params{ast_vec3, default_vec3.align * 7, true}, - Params{ast_vec4, default_vec4.align * 7, true}, - Params{ast_mat2x2, default_mat2x2.align * 7, true}, - Params{ast_mat3x3, default_mat3x3.align * 7, true}, - Params{ast_mat4x4, default_mat4x4.align * 7, true}, + ParamsFor(default_u32.align * 7, true), + ParamsFor(default_i32.align * 7, true), + ParamsFor(default_f32.align * 7, true), + ParamsFor>(default_vec2.align * 7, true), + ParamsFor>(default_vec3.align * 7, true), + ParamsFor>(default_vec4.align * 7, true), + ParamsFor>(default_mat2x2.align * 7, true), + ParamsFor>(default_mat3x3.align * 7, true), + ParamsFor>(default_mat4x4.align * 7, true), // Fail because stride is not multiple of element alignment - Params{ast_u32, (default_u32.align - 1) * 7, false}, - Params{ast_i32, (default_i32.align - 1) * 7, false}, - Params{ast_f32, (default_f32.align - 1) * 7, false}, - Params{ast_vec2, (default_vec2.align - 1) * 7, false}, - Params{ast_vec3, (default_vec3.align - 1) * 7, false}, - Params{ast_vec4, (default_vec4.align - 1) * 7, false}, - Params{ast_mat2x2, (default_mat2x2.align - 1) * 7, false}, - Params{ast_mat3x3, (default_mat3x3.align - 1) * 7, false}, - Params{ast_mat4x4, (default_mat4x4.align - 1) * 7, false})); + ParamsFor((default_u32.align - 1) * 7, false), + ParamsFor((default_i32.align - 1) * 7, false), + ParamsFor((default_f32.align - 1) * 7, false), + ParamsFor>((default_vec2.align - 1) * 7, false), + ParamsFor>((default_vec3.align - 1) * 7, false), + ParamsFor>((default_vec4.align - 1) * 7, false), + ParamsFor>((default_mat2x2.align - 1) * 7, false), + ParamsFor>((default_mat3x3.align - 1) * 7, false), + ParamsFor>((default_mat4x4.align - 1) * 7, false))); TEST_F(ArrayStrideTest, MultipleDecorations) { auto* arr = ty.array(Source{{12, 34}}, ty.i32(), 4, diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc index 3a7eef33ac..80d910bd8c 100644 --- a/src/resolver/entry_point_validation_test.cc +++ b/src/resolver/entry_point_validation_test.cc @@ -26,6 +26,27 @@ namespace tint { namespace resolver { namespace { +// Helpers and typedefs +template +using DataType = builder::DataType; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat2x2 = builder::mat2x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; + class ResolverEntryPointValidationTest : public TestHelper, public testing::Test {}; @@ -517,43 +538,48 @@ TEST_F(ResolverEntryPointValidationTest, namespace TypeValidationTests { struct Params { - create_ast_type_func_ptr create_ast_type; + builder::ast_type_func_ptr create_ast_type; bool is_valid; }; +template +constexpr Params ParamsFor(bool is_valid) { + return Params{DataType::AST, is_valid}; +} + using TypeValidationTest = resolver::ResolverTestWithParam; static constexpr Params cases[] = { - {ast_f32, true}, - {ast_i32, true}, - {ast_u32, true}, - {ast_bool, false}, - {ast_vec2, true}, - {ast_vec3, true}, - {ast_vec4, true}, - {ast_mat2x2, false}, - {ast_mat2x2, false}, - {ast_mat2x2, false}, - {ast_mat2x2, false}, - {ast_mat3x3, false}, - {ast_mat3x3, false}, - {ast_mat3x3, false}, - {ast_mat3x3, false}, - {ast_mat4x4, false}, - {ast_mat4x4, false}, - {ast_mat4x4, false}, - {ast_mat4x4, false}, - {ast_alias, true}, - {ast_alias, true}, - {ast_alias, true}, - {ast_alias, false}, + ParamsFor(true), // + ParamsFor(true), // + ParamsFor(true), // + ParamsFor(false), // + ParamsFor>(true), // + ParamsFor>(true), // + ParamsFor>(true), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(false), // + ParamsFor>(true), // + ParamsFor>(true), // + ParamsFor>(true), // + ParamsFor>(false), // }; TEST_P(TypeValidationTest, BareInputs) { // [[stage(fragment)]] // fn main([[location(0)]] a : *) {} auto params = GetParam(); - auto* a = Param("a", params.create_ast_type(ty), {Location(0)}); + auto* a = Param("a", params.create_ast_type(*this), {Location(0)}); Func(Source{{12, 34}}, "main", {a}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -572,7 +598,7 @@ TEST_P(TypeValidationTest, StructInputs) { // fn main(a : Input) {} auto params = GetParam(); auto* input = Structure( - "Input", {Member("a", params.create_ast_type(ty), {Location(0)})}); + "Input", {Member("a", params.create_ast_type(*this), {Location(0)})}); auto* a = Param("a", input, {}); Func(Source{{12, 34}}, "main", {a}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); @@ -590,8 +616,8 @@ TEST_P(TypeValidationTest, BareOutputs) { // return *(); // } auto params = GetParam(); - Func(Source{{12, 34}}, "main", {}, params.create_ast_type(ty), - {Return(Construct(params.create_ast_type(ty)))}, + Func(Source{{12, 34}}, "main", {}, params.create_ast_type(*this), + {Return(Construct(params.create_ast_type(*this)))}, {Stage(ast::PipelineStage::kFragment)}, {Location(0)}); if (params.is_valid) { @@ -611,7 +637,7 @@ TEST_P(TypeValidationTest, StructOutputs) { // } auto params = GetParam(); auto* output = Structure( - "Output", {Member("a", params.create_ast_type(ty), {Location(0)})}); + "Output", {Member("a", params.create_ast_type(*this), {Location(0)})}); Func(Source{{12, 34}}, "main", {}, output, {Return(Construct(output))}, {Stage(ast::PipelineStage::kFragment)}); diff --git a/src/resolver/inferred_type_test.cc b/src/resolver/inferred_type_test.cc index caceb4321f..538850deea 100644 --- a/src/resolver/inferred_type_test.cc +++ b/src/resolver/inferred_type_test.cc @@ -23,42 +23,62 @@ namespace resolver { namespace { // Helpers and typedefs -using i32 = ProgramBuilder::i32; -using u32 = ProgramBuilder::u32; -using f32 = ProgramBuilder::f32; +template +using DataType = builder::DataType; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat2x2 = builder::mat2x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; struct ResolverInferredTypeTest : public resolver::TestHelper, public testing::Test {}; struct Params { - create_ast_type_func_ptr create_type; - create_sem_type_func_ptr create_expected_type; + builder::ast_expr_func_ptr create_value; + builder::sem_type_func_ptr create_expected_type; }; -Params all_cases[] = { - {ast_bool, sem_bool}, - {ast_u32, sem_u32}, - {ast_i32, sem_i32}, - {ast_f32, sem_f32}, - {ast_vec3, sem_vec3}, - {ast_vec3, sem_vec3}, - {ast_vec3, sem_vec3}, - {ast_vec3, sem_vec3}, - {ast_mat3x3, sem_mat3x3}, - {ast_mat3x3, sem_mat3x3}, - {ast_mat3x3, sem_mat3x3}, +template +constexpr Params ParamsFor() { + return Params{DataType::Expr, DataType::Sem}; +} - {ast_alias, sem_bool}, - {ast_alias, sem_u32}, - {ast_alias, sem_i32}, - {ast_alias, sem_f32}, - {ast_alias>, sem_vec3}, - {ast_alias>, sem_vec3}, - {ast_alias>, sem_vec3}, - {ast_alias>, sem_vec3}, - {ast_alias>, sem_mat3x3}, - {ast_alias>, sem_mat3x3}, - {ast_alias>, sem_mat3x3}, +Params all_cases[] = { + ParamsFor(), // + ParamsFor(), // + ParamsFor(), // + ParamsFor(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>(), // + ParamsFor>>(), // + ParamsFor>>(), // + ParamsFor>>(), // + ParamsFor>>(), // + ParamsFor>>(), // + ParamsFor>>(), // + ParamsFor>>(), // }; using ResolverInferredTypeParamTest = ResolverTestWithParam; @@ -66,11 +86,10 @@ using ResolverInferredTypeParamTest = ResolverTestWithParam; TEST_P(ResolverInferredTypeParamTest, GlobalLet_Pass) { auto& params = GetParam(); - auto* type = params.create_type(ty); - auto* expected_type = params.create_expected_type(ty); + auto* expected_type = params.create_expected_type(*this); // let a = ; - auto* ctor_expr = ConstructValueFilledWith(type); + auto* ctor_expr = params.create_value(*this, 0); auto* var = GlobalConst("a", nullptr, ctor_expr); WrapInFunction(); @@ -81,10 +100,8 @@ TEST_P(ResolverInferredTypeParamTest, GlobalLet_Pass) { TEST_P(ResolverInferredTypeParamTest, GlobalVar_Fail) { auto& params = GetParam(); - auto* type = params.create_type(ty); - // var a = ; - auto* ctor_expr = ConstructValueFilledWith(type); + auto* ctor_expr = params.create_value(*this, 0); Global(Source{{12, 34}}, "a", nullptr, ast::StorageClass::kPrivate, ctor_expr); WrapInFunction(); @@ -97,11 +114,10 @@ TEST_P(ResolverInferredTypeParamTest, GlobalVar_Fail) { TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) { auto& params = GetParam(); - auto* type = params.create_type(ty); - auto* expected_type = params.create_expected_type(ty); + auto* expected_type = params.create_expected_type(*this); // let a = ; - auto* ctor_expr = ConstructValueFilledWith(type); + auto* ctor_expr = params.create_value(*this, 0); auto* var = Const("a", nullptr, ctor_expr); WrapInFunction(var); @@ -112,11 +128,10 @@ TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) { TEST_P(ResolverInferredTypeParamTest, LocalVar_Pass) { auto& params = GetParam(); - auto* type = params.create_type(ty); - auto* expected_type = params.create_expected_type(ty); + auto* expected_type = params.create_expected_type(*this); // var a = ; - auto* ctor_expr = ConstructValueFilledWith(type); + auto* ctor_expr = params.create_value(*this, 0); auto* var = Var("a", nullptr, ast::StorageClass::kFunction, ctor_expr); WrapInFunction(var); diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index e6c6875b05..ff4f455b40 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -245,8 +245,7 @@ bool Resolver::ResolveInternal() { for (auto* node : builder_->ASTNodes().Objects()) { if (marked_.count(node) == 0) { - if (node->IsAnyOf()) { + if (node->IsAnyOf()) { // TODO(crbug.com/tint/724) - Remove once tint:724 is complete. // ast::AccessDecorations are generated by the WGSL parser, used to // build sem::AccessControls and then leaked. @@ -254,7 +253,6 @@ bool Resolver::ResolveInternal() { // multiple arrays of the same stride, size and element type are // currently de-duplicated by the type manager, and we leak these // decorations. - // ast::Types are being built, but not yet being handled. This is WIP. continue; } TINT_ICE(diagnostics_) << "AST node '" << node->TypeInfo().name diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index ba3ef08763..91b819594d 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -51,9 +51,39 @@ namespace resolver { namespace { // Helpers and typedefs -using i32 = ProgramBuilder::i32; -using u32 = ProgramBuilder::u32; -using f32 = ProgramBuilder::f32; +template +using DataType = builder::DataType; +template +using vec = builder::vec; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat = builder::mat; +template +using mat2x2 = builder::mat2x2; +template +using mat2x3 = builder::mat2x3; +template +using mat3x2 = builder::mat3x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +template +using alias1 = builder::alias1; +template +using alias2 = builder::alias2; +template +using alias3 = builder::alias3; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; using Op = ast::BinaryOp; TEST_F(ResolverTest, Stmt_Assign) { @@ -1209,13 +1239,40 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) { namespace ExprBinaryTest { +template +struct Aliased { + using type = alias; +}; + +template +struct Aliased, ID> { + using type = vec>; +}; + +template +struct Aliased, ID> { + using type = mat>; +}; + struct Params { ast::BinaryOp op; - create_ast_type_func_ptr create_lhs_type; - create_ast_type_func_ptr create_rhs_type; - create_sem_type_func_ptr create_result_type; + builder::ast_type_func_ptr create_lhs_type; + builder::ast_type_func_ptr create_rhs_type; + builder::ast_type_func_ptr create_lhs_alias_type; + builder::ast_type_func_ptr create_rhs_alias_type; + builder::sem_type_func_ptr create_result_type; }; +template +constexpr Params ParamsFor(ast::BinaryOp op) { + return Params{op, + DataType::AST, + DataType::AST, + DataType::type>::AST, + DataType::type>::AST, + DataType::Sem}; +} + static constexpr ast::BinaryOp all_ops[] = { ast::BinaryOp::kAnd, ast::BinaryOp::kOr, @@ -1237,12 +1294,24 @@ static constexpr ast::BinaryOp all_ops[] = { ast::BinaryOp::kModulo, }; -static constexpr create_ast_type_func_ptr all_create_type_funcs[] = { - ast_bool, ast_u32, ast_i32, ast_f32, - ast_vec3, ast_vec3, ast_vec3, ast_vec3, - ast_mat3x3, ast_mat3x3, ast_mat3x3, // - ast_mat2x3, ast_mat2x3, ast_mat2x3, // - ast_mat3x2, ast_mat3x2, ast_mat3x2 // +static constexpr builder::ast_type_func_ptr all_create_type_funcs[] = { + DataType::AST, // + DataType::AST, // + DataType::AST, // + DataType::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST, // + DataType>::AST // }; // A list of all valid test cases for 'lhs op rhs', except that for vecN and @@ -1252,229 +1321,216 @@ static constexpr Params all_valid_cases[] = { // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr // Binary logical expressions - Params{Op::kLogicalAnd, ast_bool, ast_bool, sem_bool}, - Params{Op::kLogicalOr, ast_bool, ast_bool, sem_bool}, + ParamsFor(Op::kLogicalAnd), + ParamsFor(Op::kLogicalOr), - Params{Op::kAnd, ast_bool, ast_bool, sem_bool}, - Params{Op::kOr, ast_bool, ast_bool, sem_bool}, - Params{Op::kAnd, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kOr, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor(Op::kAnd), + ParamsFor(Op::kOr), + ParamsFor, vec3, vec3>(Op::kAnd), + ParamsFor, vec3, vec3>(Op::kOr), // Arithmetic expressions // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr // Binary arithmetic expressions over scalars - Params{Op::kAdd, ast_i32, ast_i32, sem_i32}, - Params{Op::kSubtract, ast_i32, ast_i32, sem_i32}, - Params{Op::kMultiply, ast_i32, ast_i32, sem_i32}, - Params{Op::kDivide, ast_i32, ast_i32, sem_i32}, - Params{Op::kModulo, ast_i32, ast_i32, sem_i32}, + ParamsFor(Op::kAdd), + ParamsFor(Op::kSubtract), + ParamsFor(Op::kMultiply), + ParamsFor(Op::kDivide), + ParamsFor(Op::kModulo), - Params{Op::kAdd, ast_u32, ast_u32, sem_u32}, - Params{Op::kSubtract, ast_u32, ast_u32, sem_u32}, - Params{Op::kMultiply, ast_u32, ast_u32, sem_u32}, - Params{Op::kDivide, ast_u32, ast_u32, sem_u32}, - Params{Op::kModulo, ast_u32, ast_u32, sem_u32}, + ParamsFor(Op::kAdd), + ParamsFor(Op::kSubtract), + ParamsFor(Op::kMultiply), + ParamsFor(Op::kDivide), + ParamsFor(Op::kModulo), - Params{Op::kAdd, ast_f32, ast_f32, sem_f32}, - Params{Op::kSubtract, ast_f32, ast_f32, sem_f32}, - Params{Op::kMultiply, ast_f32, ast_f32, sem_f32}, - Params{Op::kDivide, ast_f32, ast_f32, sem_f32}, - Params{Op::kModulo, ast_f32, ast_f32, sem_f32}, + ParamsFor(Op::kAdd), + ParamsFor(Op::kSubtract), + ParamsFor(Op::kMultiply), + ParamsFor(Op::kDivide), + ParamsFor(Op::kModulo), // Binary arithmetic expressions over vectors - Params{Op::kAdd, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kSubtract, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kDivide, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kModulo, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec3>(Op::kAdd), + ParamsFor, vec3, vec3>(Op::kSubtract), + ParamsFor, vec3, vec3>(Op::kMultiply), + ParamsFor, vec3, vec3>(Op::kDivide), + ParamsFor, vec3, vec3>(Op::kModulo), - Params{Op::kAdd, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kSubtract, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kDivide, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kModulo, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec3>(Op::kAdd), + ParamsFor, vec3, vec3>(Op::kSubtract), + ParamsFor, vec3, vec3>(Op::kMultiply), + ParamsFor, vec3, vec3>(Op::kDivide), + ParamsFor, vec3, vec3>(Op::kModulo), - Params{Op::kAdd, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kSubtract, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kDivide, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kModulo, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec3>(Op::kAdd), + ParamsFor, vec3, vec3>(Op::kSubtract), + ParamsFor, vec3, vec3>(Op::kMultiply), + ParamsFor, vec3, vec3>(Op::kDivide), + ParamsFor, vec3, vec3>(Op::kModulo), // Binary arithmetic expressions with mixed scalar and vector operands - Params{Op::kAdd, ast_vec3, ast_i32, sem_vec3}, - Params{Op::kSubtract, ast_vec3, ast_i32, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_i32, sem_vec3}, - Params{Op::kDivide, ast_vec3, ast_i32, sem_vec3}, - Params{Op::kModulo, ast_vec3, ast_i32, sem_vec3}, + ParamsFor, i32, vec3>(Op::kAdd), + ParamsFor, i32, vec3>(Op::kSubtract), + ParamsFor, i32, vec3>(Op::kMultiply), + ParamsFor, i32, vec3>(Op::kDivide), + ParamsFor, i32, vec3>(Op::kModulo), - Params{Op::kAdd, ast_i32, ast_vec3, sem_vec3}, - Params{Op::kSubtract, ast_i32, ast_vec3, sem_vec3}, - Params{Op::kMultiply, ast_i32, ast_vec3, sem_vec3}, - Params{Op::kDivide, ast_i32, ast_vec3, sem_vec3}, - Params{Op::kModulo, ast_i32, ast_vec3, sem_vec3}, + ParamsFor, vec3>(Op::kAdd), + ParamsFor, vec3>(Op::kSubtract), + ParamsFor, vec3>(Op::kMultiply), + ParamsFor, vec3>(Op::kDivide), + ParamsFor, vec3>(Op::kModulo), - Params{Op::kAdd, ast_vec3, ast_u32, sem_vec3}, - Params{Op::kSubtract, ast_vec3, ast_u32, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_u32, sem_vec3}, - Params{Op::kDivide, ast_vec3, ast_u32, sem_vec3}, - Params{Op::kModulo, ast_vec3, ast_u32, sem_vec3}, + ParamsFor, u32, vec3>(Op::kAdd), + ParamsFor, u32, vec3>(Op::kSubtract), + ParamsFor, u32, vec3>(Op::kMultiply), + ParamsFor, u32, vec3>(Op::kDivide), + ParamsFor, u32, vec3>(Op::kModulo), - Params{Op::kAdd, ast_u32, ast_vec3, sem_vec3}, - Params{Op::kSubtract, ast_u32, ast_vec3, sem_vec3}, - Params{Op::kMultiply, ast_u32, ast_vec3, sem_vec3}, - Params{Op::kDivide, ast_u32, ast_vec3, sem_vec3}, - Params{Op::kModulo, ast_u32, ast_vec3, sem_vec3}, + ParamsFor, vec3>(Op::kAdd), + ParamsFor, vec3>(Op::kSubtract), + ParamsFor, vec3>(Op::kMultiply), + ParamsFor, vec3>(Op::kDivide), + ParamsFor, vec3>(Op::kModulo), - Params{Op::kAdd, ast_vec3, ast_f32, sem_vec3}, - Params{Op::kSubtract, ast_vec3, ast_f32, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_f32, sem_vec3}, - Params{Op::kDivide, ast_vec3, ast_f32, sem_vec3}, - // NOTE: no kModulo for ast_vec3, ast_f32 - // Params{Op::kModulo, ast_vec3, ast_f32, sem_vec3}, + ParamsFor, f32, vec3>(Op::kAdd), + ParamsFor, f32, vec3>(Op::kSubtract), + ParamsFor, f32, vec3>(Op::kMultiply), + ParamsFor, f32, vec3>(Op::kDivide), + // NOTE: no kModulo for vec3, f32 + // ParamsFor, f32, vec3>(Op::kModulo), - Params{Op::kAdd, ast_f32, ast_vec3, sem_vec3}, - Params{Op::kSubtract, ast_f32, ast_vec3, sem_vec3}, - Params{Op::kMultiply, ast_f32, ast_vec3, sem_vec3}, - Params{Op::kDivide, ast_f32, ast_vec3, sem_vec3}, - // NOTE: no kModulo for ast_f32, ast_vec3 - // Params{Op::kModulo, ast_f32, ast_vec3, sem_vec3}, + ParamsFor, vec3>(Op::kAdd), + ParamsFor, vec3>(Op::kSubtract), + ParamsFor, vec3>(Op::kMultiply), + ParamsFor, vec3>(Op::kDivide), + // NOTE: no kModulo for f32, vec3 + // ParamsFor, vec3>(Op::kModulo), // Matrix arithmetic - Params{Op::kMultiply, ast_mat2x3, ast_f32, sem_mat2x3}, - Params{Op::kMultiply, ast_mat3x2, ast_f32, sem_mat3x2}, - Params{Op::kMultiply, ast_mat3x3, ast_f32, sem_mat3x3}, + ParamsFor, f32, mat2x3>(Op::kMultiply), + ParamsFor, f32, mat3x2>(Op::kMultiply), + ParamsFor, f32, mat3x3>(Op::kMultiply), - Params{Op::kMultiply, ast_f32, ast_mat2x3, sem_mat2x3}, - Params{Op::kMultiply, ast_f32, ast_mat3x2, sem_mat3x2}, - Params{Op::kMultiply, ast_f32, ast_mat3x3, sem_mat3x3}, + ParamsFor, mat2x3>(Op::kMultiply), + ParamsFor, mat3x2>(Op::kMultiply), + ParamsFor, mat3x3>(Op::kMultiply), - Params{Op::kMultiply, ast_vec3, ast_mat2x3, sem_vec2}, - Params{Op::kMultiply, ast_vec2, ast_mat3x2, sem_vec3}, - Params{Op::kMultiply, ast_vec3, ast_mat3x3, sem_vec3}, + ParamsFor, mat2x3, vec2>(Op::kMultiply), + ParamsFor, mat3x2, vec3>(Op::kMultiply), + ParamsFor, mat3x3, vec3>(Op::kMultiply), - Params{Op::kMultiply, ast_mat3x2, ast_vec3, sem_vec2}, - Params{Op::kMultiply, ast_mat2x3, ast_vec2, sem_vec3}, - Params{Op::kMultiply, ast_mat3x3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec2>(Op::kMultiply), + ParamsFor, vec2, vec3>(Op::kMultiply), + ParamsFor, vec3, vec3>(Op::kMultiply), - Params{Op::kMultiply, ast_mat2x3, ast_mat3x2, - sem_mat3x3}, - Params{Op::kMultiply, ast_mat3x2, ast_mat2x3, - sem_mat2x2}, - Params{Op::kMultiply, ast_mat3x2, ast_mat3x3, - sem_mat3x2}, - Params{Op::kMultiply, ast_mat3x3, ast_mat3x3, - sem_mat3x3}, - Params{Op::kMultiply, ast_mat3x3, ast_mat2x3, - sem_mat2x3}, + ParamsFor, mat3x2, mat3x3>(Op::kMultiply), + ParamsFor, mat2x3, mat2x2>(Op::kMultiply), + ParamsFor, mat3x3, mat3x2>(Op::kMultiply), + ParamsFor, mat3x3, mat3x3>(Op::kMultiply), + ParamsFor, mat2x3, mat2x3>(Op::kMultiply), - Params{Op::kAdd, ast_mat2x3, ast_mat2x3, sem_mat2x3}, - Params{Op::kAdd, ast_mat3x2, ast_mat3x2, sem_mat3x2}, - Params{Op::kAdd, ast_mat3x3, ast_mat3x3, sem_mat3x3}, + ParamsFor, mat2x3, mat2x3>(Op::kAdd), + ParamsFor, mat3x2, mat3x2>(Op::kAdd), + ParamsFor, mat3x3, mat3x3>(Op::kAdd), - Params{Op::kSubtract, ast_mat2x3, ast_mat2x3, - sem_mat2x3}, - Params{Op::kSubtract, ast_mat3x2, ast_mat3x2, - sem_mat3x2}, - Params{Op::kSubtract, ast_mat3x3, ast_mat3x3, - sem_mat3x3}, + ParamsFor, mat2x3, mat2x3>(Op::kSubtract), + ParamsFor, mat3x2, mat3x2>(Op::kSubtract), + ParamsFor, mat3x3, mat3x3>(Op::kSubtract), // Comparison expressions // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr // Comparisons over scalars - Params{Op::kEqual, ast_bool, ast_bool, sem_bool}, - Params{Op::kNotEqual, ast_bool, ast_bool, sem_bool}, + ParamsFor(Op::kEqual), + ParamsFor(Op::kNotEqual), - Params{Op::kEqual, ast_i32, ast_i32, sem_bool}, - Params{Op::kNotEqual, ast_i32, ast_i32, sem_bool}, - Params{Op::kLessThan, ast_i32, ast_i32, sem_bool}, - Params{Op::kLessThanEqual, ast_i32, ast_i32, sem_bool}, - Params{Op::kGreaterThan, ast_i32, ast_i32, sem_bool}, - Params{Op::kGreaterThanEqual, ast_i32, ast_i32, sem_bool}, + ParamsFor(Op::kEqual), + ParamsFor(Op::kNotEqual), + ParamsFor(Op::kLessThan), + ParamsFor(Op::kLessThanEqual), + ParamsFor(Op::kGreaterThan), + ParamsFor(Op::kGreaterThanEqual), - Params{Op::kEqual, ast_u32, ast_u32, sem_bool}, - Params{Op::kNotEqual, ast_u32, ast_u32, sem_bool}, - Params{Op::kLessThan, ast_u32, ast_u32, sem_bool}, - Params{Op::kLessThanEqual, ast_u32, ast_u32, sem_bool}, - Params{Op::kGreaterThan, ast_u32, ast_u32, sem_bool}, - Params{Op::kGreaterThanEqual, ast_u32, ast_u32, sem_bool}, + ParamsFor(Op::kEqual), + ParamsFor(Op::kNotEqual), + ParamsFor(Op::kLessThan), + ParamsFor(Op::kLessThanEqual), + ParamsFor(Op::kGreaterThan), + ParamsFor(Op::kGreaterThanEqual), - Params{Op::kEqual, ast_f32, ast_f32, sem_bool}, - Params{Op::kNotEqual, ast_f32, ast_f32, sem_bool}, - Params{Op::kLessThan, ast_f32, ast_f32, sem_bool}, - Params{Op::kLessThanEqual, ast_f32, ast_f32, sem_bool}, - Params{Op::kGreaterThan, ast_f32, ast_f32, sem_bool}, - Params{Op::kGreaterThanEqual, ast_f32, ast_f32, sem_bool}, + ParamsFor(Op::kEqual), + ParamsFor(Op::kNotEqual), + ParamsFor(Op::kLessThan), + ParamsFor(Op::kLessThanEqual), + ParamsFor(Op::kGreaterThan), + ParamsFor(Op::kGreaterThanEqual), // Comparisons over vectors - Params{Op::kEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kNotEqual, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec3>(Op::kEqual), + ParamsFor, vec3, vec3>(Op::kNotEqual), - Params{Op::kEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kNotEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kLessThan, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kLessThanEqual, ast_vec3, ast_vec3, - sem_vec3}, - Params{Op::kGreaterThan, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kGreaterThanEqual, ast_vec3, ast_vec3, - sem_vec3}, + ParamsFor, vec3, vec3>(Op::kEqual), + ParamsFor, vec3, vec3>(Op::kNotEqual), + ParamsFor, vec3, vec3>(Op::kLessThan), + ParamsFor, vec3, vec3>(Op::kLessThanEqual), + ParamsFor, vec3, vec3>(Op::kGreaterThan), + ParamsFor, vec3, vec3>(Op::kGreaterThanEqual), - Params{Op::kEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kNotEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kLessThan, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kLessThanEqual, ast_vec3, ast_vec3, - sem_vec3}, - Params{Op::kGreaterThan, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kGreaterThanEqual, ast_vec3, ast_vec3, - sem_vec3}, + ParamsFor, vec3, vec3>(Op::kEqual), + ParamsFor, vec3, vec3>(Op::kNotEqual), + ParamsFor, vec3, vec3>(Op::kLessThan), + ParamsFor, vec3, vec3>(Op::kLessThanEqual), + ParamsFor, vec3, vec3>(Op::kGreaterThan), + ParamsFor, vec3, vec3>(Op::kGreaterThanEqual), - Params{Op::kEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kNotEqual, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kLessThan, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kLessThanEqual, ast_vec3, ast_vec3, - sem_vec3}, - Params{Op::kGreaterThan, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kGreaterThanEqual, ast_vec3, ast_vec3, - sem_vec3}, + ParamsFor, vec3, vec3>(Op::kEqual), + ParamsFor, vec3, vec3>(Op::kNotEqual), + ParamsFor, vec3, vec3>(Op::kLessThan), + ParamsFor, vec3, vec3>(Op::kLessThanEqual), + ParamsFor, vec3, vec3>(Op::kGreaterThan), + ParamsFor, vec3, vec3>(Op::kGreaterThanEqual), // Binary bitwise operations - Params{Op::kOr, ast_i32, ast_i32, sem_i32}, - Params{Op::kAnd, ast_i32, ast_i32, sem_i32}, - Params{Op::kXor, ast_i32, ast_i32, sem_i32}, + ParamsFor(Op::kOr), + ParamsFor(Op::kAnd), + ParamsFor(Op::kXor), - Params{Op::kOr, ast_u32, ast_u32, sem_u32}, - Params{Op::kAnd, ast_u32, ast_u32, sem_u32}, - Params{Op::kXor, ast_u32, ast_u32, sem_u32}, + ParamsFor(Op::kOr), + ParamsFor(Op::kAnd), + ParamsFor(Op::kXor), - Params{Op::kOr, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kAnd, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kXor, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec3>(Op::kOr), + ParamsFor, vec3, vec3>(Op::kAnd), + ParamsFor, vec3, vec3>(Op::kXor), - Params{Op::kOr, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kAnd, ast_vec3, ast_vec3, sem_vec3}, - Params{Op::kXor, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor, vec3, vec3>(Op::kOr), + ParamsFor, vec3, vec3>(Op::kAnd), + ParamsFor, vec3, vec3>(Op::kXor), // Bit shift expressions - Params{Op::kShiftLeft, ast_i32, ast_u32, sem_i32}, - Params{Op::kShiftLeft, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor(Op::kShiftLeft), + ParamsFor, vec3, vec3>(Op::kShiftLeft), - Params{Op::kShiftLeft, ast_u32, ast_u32, sem_u32}, - Params{Op::kShiftLeft, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor(Op::kShiftLeft), + ParamsFor, vec3, vec3>(Op::kShiftLeft), - Params{Op::kShiftRight, ast_i32, ast_u32, sem_i32}, - Params{Op::kShiftRight, ast_vec3, ast_vec3, sem_vec3}, + ParamsFor(Op::kShiftRight), + ParamsFor, vec3, vec3>(Op::kShiftRight), - Params{Op::kShiftRight, ast_u32, ast_u32, sem_u32}, - Params{Op::kShiftRight, ast_vec3, ast_vec3, sem_vec3}}; + ParamsFor(Op::kShiftRight), + ParamsFor, vec3, vec3>(Op::kShiftRight), +}; using Expr_Binary_Test_Valid = ResolverTestWithParam; TEST_P(Expr_Binary_Test_Valid, All) { auto& params = GetParam(); - auto* lhs_type = params.create_lhs_type(ty); - auto* rhs_type = params.create_rhs_type(ty); - auto* result_type = params.create_result_type(ty); + auto* lhs_type = params.create_lhs_type(*this); + auto* rhs_type = params.create_rhs_type(*this); + auto* result_type = params.create_result_type(*this); std::stringstream ss; ss << FriendlyName(lhs_type) << " " << params.op << " " @@ -1503,38 +1559,22 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) { const Params& params = std::get<0>(GetParam()); BinaryExprSide side = std::get<1>(GetParam()); - auto* lhs_type = params.create_lhs_type(ty); - auto* rhs_type = params.create_rhs_type(ty); + auto* create_lhs_type = + (side == BinaryExprSide::Left || side == BinaryExprSide::Both) + ? params.create_lhs_alias_type + : params.create_lhs_type; + auto* create_rhs_type = + (side == BinaryExprSide::Right || side == BinaryExprSide::Both) + ? params.create_rhs_alias_type + : params.create_rhs_type; + + auto* lhs_type = create_lhs_type(*this); + auto* rhs_type = create_rhs_type(*this); std::stringstream ss; ss << FriendlyName(lhs_type) << " " << params.op << " " << FriendlyName(rhs_type); - // For vectors and matrices, wrap the sub type in an alias - auto make_alias = [this](ast::Type* type) -> ast::Type* { - if (auto* v = type->As()) { - auto* alias = ty.alias(Symbols().New(), v->type()); - AST().AddConstructedType(alias); - return ty.vec(alias, v->size()); - } - if (auto* m = type->As()) { - auto* alias = ty.alias(Symbols().New(), m->type()); - AST().AddConstructedType(alias); - return ty.mat(alias, m->columns(), m->rows()); - } - auto* alias = ty.alias(Symbols().New(), type); - AST().AddConstructedType(alias); - return ty.type_name(alias->name()); - }; - - // Wrap in alias - if (side == BinaryExprSide::Left || side == BinaryExprSide::Both) { - lhs_type = make_alias(lhs_type); - } - if (side == BinaryExprSide::Right || side == BinaryExprSide::Both) { - rhs_type = make_alias(rhs_type); - } - ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op << " " << FriendlyName(rhs_type); SCOPED_TRACE(ss.str()); @@ -1550,7 +1590,7 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) { ASSERT_NE(TypeOf(expr), nullptr); // TODO(amaiorano): Bring this back once we have a way to get the canonical // type - // auto* *result_type = params.create_result_type(ty); + // auto* *result_type = params.create_result_type(*this); // ASSERT_TRUE(TypeOf(expr) == result_type); } INSTANTIATE_TEST_SUITE_P( @@ -1565,13 +1605,13 @@ INSTANTIATE_TEST_SUITE_P( // (type * type * op), and processing only the triplets that are not found in // the `all_valid_cases` table. using Expr_Binary_Test_Invalid = - ResolverTestWithParam>; TEST_P(Expr_Binary_Test_Invalid, All) { - const create_ast_type_func_ptr& lhs_create_type_func = + const builder::ast_type_func_ptr& lhs_create_type_func = std::get<0>(GetParam()); - const create_ast_type_func_ptr& rhs_create_type_func = + const builder::ast_type_func_ptr& rhs_create_type_func = std::get<1>(GetParam()); const ast::BinaryOp op = std::get<2>(GetParam()); @@ -1584,8 +1624,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) { } } - auto* lhs_type = lhs_create_type_func(ty); - auto* rhs_type = rhs_create_type_func(ty); + auto* lhs_type = lhs_create_type_func(*this); + auto* rhs_type = rhs_create_type_func(*this); std::stringstream ss; ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type); diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index 3597e27566..37692fbbe0 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -120,172 +120,271 @@ template class ResolverTestWithParam : public TestHelper, public testing::TestWithParam {}; -inline ast::Type* ast_bool(const ProgramBuilder::TypesBuilder& ty) { - return ty.bool_(); -} -inline ast::Type* ast_i32(const ProgramBuilder::TypesBuilder& ty) { - return ty.i32(); -} -inline ast::Type* ast_u32(const ProgramBuilder::TypesBuilder& ty) { - return ty.u32(); -} -inline ast::Type* ast_f32(const ProgramBuilder::TypesBuilder& ty) { - return ty.f32(); -} +namespace builder { -using create_ast_type_func_ptr = - ast::Type* (*)(const ProgramBuilder::TypesBuilder& ty); +using i32 = ProgramBuilder::i32; +using u32 = ProgramBuilder::u32; +using f32 = ProgramBuilder::f32; + +template +struct vec {}; template -ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec2(); -} - -template -ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec2(create_type(ty)); -} +using vec2 = vec<2, T>; template -ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec3(); -} - -template -ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec3(create_type(ty)); -} +using vec3 = vec<3, T>; template -ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec4(); -} +using vec4 = vec<4, T>; -template -ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec4(create_type(ty)); -} +template +struct mat {}; template -ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat2x2(); -} - -template -ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat2x2(create_type(ty)); -} +using mat2x2 = mat<2, 2, T>; template -ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat2x3(); -} - -template -ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat2x3(create_type(ty)); -} +using mat2x3 = mat<2, 3, T>; template -ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat3x2(); -} - -template -ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat3x2(create_type(ty)); -} +using mat3x2 = mat<3, 2, T>; template -ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat3x3(); -} - -template -ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat3x3(create_type(ty)); -} +using mat3x3 = mat<3, 3, T>; template -ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat4x4(); -} +using mat4x4 = mat<4, 4, T>; -template -ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat4x4(create_type(ty)); -} +template +struct alias {}; -template -ast::Type* ast_alias(const ProgramBuilder::TypesBuilder& ty) { - auto* type = create_type(ty); - auto name = ty.builder->Symbols().Register("alias_" + type->type_name()); - if (!ty.builder->AST().LookupType(name)) { - ty.builder->AST().AddConstructedType(ty.alias(name, type)); +template +using alias1 = alias; + +template +using alias2 = alias; + +template +using alias3 = alias; + +using ast_type_func_ptr = ast::Type* (*)(ProgramBuilder& b); +using ast_expr_func_ptr = ast::Expression* (*)(ProgramBuilder& b, + int elem_value); +using sem_type_func_ptr = sem::Type* (*)(ProgramBuilder& b); + +template +struct DataType {}; + +/// Helper for building bool types and expressions +template <> +struct DataType { + /// false as bool is not a composite type + static constexpr bool is_composite = false; + + /// @param b the ProgramBuilder + /// @return a new AST bool type + static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); } + /// @param b the ProgramBuilder + /// @return the semantic bool type + static inline sem::Type* Sem(ProgramBuilder& b) { + return b.create(); } - return ty.builder->create(name); -} + /// @param b the ProgramBuilder + /// @param elem_value the b + /// @return a new AST expression of the bool type + static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) { + return b.Expr(elem_value == 0); + } +}; -inline sem::Type* sem_bool(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(); -} -inline sem::Type* sem_i32(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(); -} -inline sem::Type* sem_u32(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(); -} -inline sem::Type* sem_f32(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(); -} +/// Helper for building i32 types and expressions +template <> +struct DataType { + /// false as i32 is not a composite type + static constexpr bool is_composite = false; -using create_sem_type_func_ptr = - sem::Type* (*)(const ProgramBuilder::TypesBuilder& ty); + /// @param b the ProgramBuilder + /// @return a new AST i32 type + static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); } + /// @param b the ProgramBuilder + /// @return the semantic i32 type + static inline sem::Type* Sem(ProgramBuilder& b) { + return b.create(); + } + /// @param b the ProgramBuilder + /// @param elem_value the value i32 will be initialized with + /// @return a new AST i32 literal value expression + static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) { + return b.Expr(static_cast(elem_value)); + } +}; -template -sem::Type* sem_vec2(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(create_type(ty), 2); -} +/// Helper for building u32 types and expressions +template <> +struct DataType { + /// false as u32 is not a composite type + static constexpr bool is_composite = false; -template -sem::Type* sem_vec3(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(create_type(ty), 3); -} + /// @param b the ProgramBuilder + /// @return a new AST u32 type + static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); } + /// @param b the ProgramBuilder + /// @return the semantic u32 type + static inline sem::Type* Sem(ProgramBuilder& b) { + return b.create(); + } + /// @param b the ProgramBuilder + /// @param elem_value the value u32 will be initialized with + /// @return a new AST u32 literal value expression + static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) { + return b.Expr(static_cast(elem_value)); + } +}; -template -sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(create_type(ty), 4); -} +/// Helper for building f32 types and expressions +template <> +struct DataType { + /// false as f32 is not a composite type + static constexpr bool is_composite = false; -template -sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) { - auto* column_type = ty.builder->create(create_type(ty), 2u); - return ty.builder->create(column_type, 2u); -} + /// @param b the ProgramBuilder + /// @return a new AST f32 type + static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); } + /// @param b the ProgramBuilder + /// @return the semantic f32 type + static inline sem::Type* Sem(ProgramBuilder& b) { + return b.create(); + } + /// @param b the ProgramBuilder + /// @param elem_value the value f32 will be initialized with + /// @return a new AST f32 literal value expression + static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) { + return b.Expr(static_cast(elem_value)); + } +}; -template -sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) { - auto* column_type = ty.builder->create(create_type(ty), 3u); - return ty.builder->create(column_type, 2u); -} +/// Helper for building vector types and expressions +template +struct DataType> { + /// true as vectors are a composite type + static constexpr bool is_composite = true; -template -sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) { - auto* column_type = ty.builder->create(create_type(ty), 2u); - return ty.builder->create(column_type, 3u); -} + /// @param b the ProgramBuilder + /// @return a new AST vector type + static inline ast::Type* AST(ProgramBuilder& b) { + return b.ty.vec(DataType::AST(b), N); + } + /// @param b the ProgramBuilder + /// @return the semantic vector type + static inline sem::Type* Sem(ProgramBuilder& b) { + return b.create(DataType::Sem(b), N); + } + /// @param b the ProgramBuilder + /// @param elem_value the value each element in the vector will be initialized + /// with + /// @return a new AST vector value expression + static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) { + return b.Construct(AST(b), ExprArgs(b, elem_value)); + } -template -sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) { - auto* column_type = ty.builder->create(create_type(ty), 3u); - return ty.builder->create(column_type, 3u); -} + /// @param b the ProgramBuilder + /// @param elem_value the value each element will be initialized with + /// @return the list of expressions that are used to construct the vector + static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, + int elem_value) { + ast::ExpressionList args; + for (int i = 0; i < N; i++) { + args.emplace_back(DataType::Expr(b, elem_value)); + } + return args; + } +}; -template -sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) { - auto* column_type = ty.builder->create(create_type(ty), 4u); - return ty.builder->create(column_type, 4u); -} +/// Helper for building matrix types and expressions +template +struct DataType> { + /// true as matrices are a composite type + static constexpr bool is_composite = true; + + /// @param b the ProgramBuilder + /// @return a new AST matrix type + static inline ast::Type* AST(ProgramBuilder& b) { + return b.ty.mat(DataType::AST(b), N, M); + } + /// @param b the ProgramBuilder + /// @return the semantic matrix type + static inline sem::Type* Sem(ProgramBuilder& b) { + auto* column_type = b.create(DataType::Sem(b), M); + return b.create(column_type, N); + } + /// @param b the ProgramBuilder + /// @param elem_value the value each element in the matrix will be initialized + /// with + /// @return a new AST matrix value expression + static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) { + return b.Construct(AST(b), ExprArgs(b, elem_value)); + } + + /// @param b the ProgramBuilder + /// @param elem_value the value each element will be initialized with + /// @return the list of expressions that are used to construct the matrix + static inline ast::ExpressionList ExprArgs(ProgramBuilder& b, + int elem_value) { + ast::ExpressionList args; + for (int i = 0; i < N; i++) { + args.emplace_back(DataType>::Expr(b, elem_value)); + } + return args; + } +}; + +/// Helper for building alias types and expressions +template +struct DataType> { + /// true if the aliased type is a composite type + static constexpr bool is_composite = DataType::is_composite; + + /// @param b the ProgramBuilder + /// @return a new AST alias type + static inline ast::Type* AST(ProgramBuilder& b) { + auto name = b.Symbols().Register("alias_" + std::to_string(ID)); + if (!b.AST().LookupType(name)) { + auto* type = DataType::AST(b); + b.AST().AddConstructedType(b.ty.alias(name, type)); + } + return b.create(name); + } + /// @param b the ProgramBuilder + /// @return the semantic aliased type + static inline sem::Type* Sem(ProgramBuilder& b) { + return DataType::Sem(b); + } + + /// @param b the ProgramBuilder + /// @param elem_value the value nested elements will be initialized with + /// @return a new AST expression of the alias type + template + static inline traits::EnableIf Expr( + ProgramBuilder& b, + int elem_value) { + // Cast + return b.Construct(AST(b), DataType::Expr(b, elem_value)); + } + + /// @param b the ProgramBuilder + /// @param elem_value the value nested elements will be initialized with + /// @return a new AST expression of the alias type + template + static inline traits::EnableIf Expr( + ProgramBuilder& b, + int elem_value) { + // Construct + return b.Construct(AST(b), DataType::ExprArgs(b, elem_value)); + } +}; + +} // namespace builder } // namespace resolver } // namespace tint diff --git a/src/resolver/type_constructor_validation_test.cc b/src/resolver/type_constructor_validation_test.cc index 24a19b9096..383a22c2e3 100644 --- a/src/resolver/type_constructor_validation_test.cc +++ b/src/resolver/type_constructor_validation_test.cc @@ -19,30 +19,47 @@ namespace tint { namespace resolver { namespace { -/// @return the element type of `type` for vec and mat, otherwise `type` itself -ast::Type* ElementTypeOf(ast::Type* type) { - if (auto* v = type->As()) { - return v->type(); - } - if (auto* m = type->As()) { - return m->type(); - } - return type; -} +// Helpers and typedefs +template +using DataType = builder::DataType; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat2x2 = builder::mat2x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +template +using alias1 = builder::alias1; +template +using alias2 = builder::alias2; +template +using alias3 = builder::alias3; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; class ResolverTypeConstructorValidationTest : public resolver::TestHelper, public testing::Test {}; namespace InferTypeTest { struct Params { - create_ast_type_func_ptr create_rhs_ast_type; - create_sem_type_func_ptr create_rhs_sem_type; + builder::ast_type_func_ptr create_rhs_ast_type; + builder::ast_expr_func_ptr create_rhs_ast_value; + builder::sem_type_func_ptr create_rhs_sem_type; }; -// Helpers and typedefs -using i32 = ProgramBuilder::i32; -using u32 = ProgramBuilder::u32; -using f32 = ProgramBuilder::f32; +template +constexpr Params ParamsFor() { + return Params{DataType::AST, DataType::Expr, DataType::Sem}; +} TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) { // var a = 1; @@ -75,8 +92,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) { // } auto& params = GetParam(); - auto* rhs_type = params.create_rhs_ast_type(ty); - auto* constructor_expr = ConstructValueFilledWith(rhs_type, 0); + auto* constructor_expr = params.create_rhs_ast_value(*this, 0); auto* a = Var("a", nullptr, ast::StorageClass::kNone, constructor_expr); // Self-assign 'a' to force the expression to be resolved so we can test its @@ -86,7 +102,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) { ASSERT_TRUE(r()->Resolve()) << r()->error(); auto* got = TypeOf(a_ident); - auto* expected = create(params.create_rhs_sem_type(ty), + auto* expected = create(params.create_rhs_sem_type(*this), ast::StorageClass::kFunction, ast::Access::kReadWrite); ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" @@ -94,26 +110,26 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) { } static constexpr Params from_constructor_expression_cases[] = { - Params{ast_bool, sem_bool}, - Params{ast_i32, sem_i32}, - Params{ast_u32, sem_u32}, - Params{ast_f32, sem_f32}, - Params{ast_vec3, sem_vec3}, - Params{ast_vec3, sem_vec3}, - Params{ast_vec3, sem_vec3}, - Params{ast_mat3x3, sem_mat3x3}, - Params{ast_mat3x3, sem_mat3x3}, - Params{ast_mat3x3, sem_mat3x3}, - Params{ast_alias, sem_bool}, - Params{ast_alias, sem_i32}, - Params{ast_alias, sem_u32}, - Params{ast_alias, sem_f32}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>, sem_mat3x3}, - Params{ast_alias>, sem_mat3x3}, - Params{ast_alias>, sem_mat3x3}, + ParamsFor(), + ParamsFor(), + ParamsFor(), + ParamsFor(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), }; INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, InferTypeTest_FromConstructorExpression, @@ -127,13 +143,11 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) { // } auto& params = GetParam(); - auto* rhs_type = params.create_rhs_ast_type(ty); - - auto* arith_lhs_expr = ConstructValueFilledWith(rhs_type, 2); - auto* arith_rhs_expr = ConstructValueFilledWith(ElementTypeOf(rhs_type), 3); + auto* arith_lhs_expr = params.create_rhs_ast_value(*this, 2); + auto* arith_rhs_expr = params.create_rhs_ast_value(*this, 3); auto* constructor_expr = Mul(arith_lhs_expr, arith_rhs_expr); - auto* a = Var("a", nullptr, ast::StorageClass::kNone, constructor_expr); + auto* a = Var("a", nullptr, constructor_expr); // Self-assign 'a' to force the expression to be resolved so we can test its // type below auto* a_ident = Expr("a"); @@ -141,25 +155,22 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) { ASSERT_TRUE(r()->Resolve()) << r()->error(); auto* got = TypeOf(a_ident); - auto* expected = create(params.create_rhs_sem_type(ty), + auto* expected = create(params.create_rhs_sem_type(*this), ast::StorageClass::kFunction, ast::Access::kReadWrite); ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" << "expected: " << FriendlyName(expected) << "\n"; } static constexpr Params from_arithmetic_expression_cases[] = { - Params{ast_i32, sem_i32}, - Params{ast_u32, sem_u32}, - Params{ast_f32, sem_f32}, - Params{ast_vec3, sem_vec3}, - Params{ast_mat3x3, sem_mat3x3}, + ParamsFor(), ParamsFor(), ParamsFor(), + ParamsFor>(), ParamsFor>(), // TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed - // Params{ty_alias}, - // Params{ty_alias}, - // Params{ty_alias}, - // Params{ty_alias>}, - // Params{ty_alias>}, + // ParamsFor>(), + // ParamsFor>(), + // ParamsFor>(), + // ParamsFor>>(), + // ParamsFor>>(), }; INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, InferTypeTest_FromArithmeticExpression, @@ -170,7 +181,7 @@ TEST_P(InferTypeTest_FromCallExpression, All) { // e.g. for vec3 // // fn foo() -> vec3 { - // return vec3(0.0, 0.0, 0.0); + // return vec3(); // } // // fn bar() @@ -179,11 +190,10 @@ TEST_P(InferTypeTest_FromCallExpression, All) { // } auto& params = GetParam(); - Func("foo", {}, params.create_rhs_ast_type(ty), - {Return(ConstructValueFilledWith(params.create_rhs_ast_type(ty), 0))}, - {}); + Func("foo", {}, params.create_rhs_ast_type(*this), + {Return(Construct(params.create_rhs_ast_type(*this)))}, {}); - auto* a = Var("a", nullptr, ast::StorageClass::kNone, Call(Expr("foo"))); + auto* a = Var("a", nullptr, Call("foo")); // Self-assign 'a' to force the expression to be resolved so we can test its // type below auto* a_ident = Expr("a"); @@ -191,33 +201,33 @@ TEST_P(InferTypeTest_FromCallExpression, All) { ASSERT_TRUE(r()->Resolve()) << r()->error(); auto* got = TypeOf(a_ident); - auto* expected = create(params.create_rhs_sem_type(ty), + auto* expected = create(params.create_rhs_sem_type(*this), ast::StorageClass::kFunction, ast::Access::kReadWrite); ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" << "expected: " << FriendlyName(expected) << "\n"; } static constexpr Params from_call_expression_cases[] = { - Params{ast_bool, sem_bool}, - Params{ast_i32, sem_i32}, - Params{ast_u32, sem_u32}, - Params{ast_f32, sem_f32}, - Params{ast_vec3, sem_vec3}, - Params{ast_vec3, sem_vec3}, - Params{ast_vec3, sem_vec3}, - Params{ast_mat3x3, sem_mat3x3}, - Params{ast_mat3x3, sem_mat3x3}, - Params{ast_mat3x3, sem_mat3x3}, - Params{ast_alias, sem_bool}, - Params{ast_alias, sem_i32}, - Params{ast_alias, sem_u32}, - Params{ast_alias, sem_f32}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>, sem_mat3x3}, - Params{ast_alias>, sem_mat3x3}, - Params{ast_alias>, sem_mat3x3}, + ParamsFor(), + ParamsFor(), + ParamsFor(), + ParamsFor(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), + ParamsFor>>(), }; INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, InferTypeTest_FromCallExpression, diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc index c7882c067c..5c9973e71c 100644 --- a/src/resolver/type_validation_test.cc +++ b/src/resolver/type_validation_test.cc @@ -27,6 +27,33 @@ namespace tint { namespace resolver { namespace { +// Helpers and typedefs +template +using DataType = builder::DataType; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using mat2x2 = builder::mat2x2; +template +using mat3x3 = builder::mat3x3; +template +using mat4x4 = builder::mat4x4; +template +using alias = builder::alias; +template +using alias1 = builder::alias1; +template +using alias2 = builder::alias2; +template +using alias3 = builder::alias3; +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; + class ResolverTypeValidationTest : public resolver::TestHelper, public testing::Test {}; @@ -366,43 +393,44 @@ TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) { namespace GetCanonicalTests { struct Params { - create_ast_type_func_ptr create_ast_type; - create_sem_type_func_ptr create_sem_type; + builder::ast_type_func_ptr create_ast_type; + builder::sem_type_func_ptr create_sem_type; }; +template +constexpr Params ParamsFor() { + return Params{DataType::AST, DataType::Sem}; +} + static constexpr Params cases[] = { - Params{ast_bool, sem_bool}, - Params{ast_alias, sem_bool}, - Params{ast_alias>, sem_bool}, + ParamsFor(), + ParamsFor>(), + ParamsFor>>(), - Params{ast_vec3, sem_vec3}, - Params{ast_alias>, sem_vec3}, - Params{ast_alias>>, sem_vec3}, + ParamsFor>(), + ParamsFor>>(), + ParamsFor>>>(), - Params{ast_vec3>, sem_vec3}, - Params{ast_alias>>, sem_vec3}, - Params{ast_alias>>>, - sem_vec3}, - Params{ast_alias>>>>, - sem_vec3}, + ParamsFor>>(), + ParamsFor>>>(), + ParamsFor>>>>(), + ParamsFor>>>>>(), - Params{ast_mat3x3>, sem_mat3x3}, - Params{ast_alias>>, sem_mat3x3}, - Params{ast_alias>>>, - sem_mat3x3}, - Params{ast_alias>>>>, - sem_mat3x3}, + ParamsFor>>(), + ParamsFor>>>(), + ParamsFor>>>>(), + ParamsFor>>>>>(), - Params{ast_alias>, sem_bool}, - Params{ast_alias>>, sem_vec3}, - Params{ast_alias>>, sem_mat3x3}, + ParamsFor>>(), + ParamsFor>>>(), + ParamsFor>>>(), }; using CanonicalTest = ResolverTestWithParam; TEST_P(CanonicalTest, All) { auto& params = GetParam(); - auto* type = params.create_ast_type(ty); + auto* type = params.create_ast_type(*this); auto* var = Var("v", type); auto* expr = Expr("v"); @@ -411,7 +439,7 @@ TEST_P(CanonicalTest, All) { EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* got = TypeOf(expr)->UnwrapRef(); - auto* expected = params.create_sem_type(ty); + auto* expected = params.create_sem_type(*this); EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" << "expected: " << FriendlyName(expected) << "\n"; @@ -459,38 +487,44 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest, testing::ValuesIn(dimension_cases)); struct TypeParams { - create_ast_type_func_ptr type_func; + builder::ast_type_func_ptr type_func; bool is_valid; }; +template +constexpr TypeParams TypeParamsFor(bool is_valid) { + return TypeParams{DataType::AST, is_valid}; +} + static constexpr TypeParams type_cases[] = { - TypeParams{ast_bool, false}, - TypeParams{ast_i32, true}, - TypeParams{ast_u32, true}, - TypeParams{ast_f32, true}, + TypeParamsFor(false), + TypeParamsFor(true), + TypeParamsFor(true), + TypeParamsFor(true), - TypeParams{ast_alias, false}, - TypeParams{ast_alias, true}, - TypeParams{ast_alias, true}, - TypeParams{ast_alias, true}, + TypeParamsFor>(false), + TypeParamsFor>(true), + TypeParamsFor>(true), + TypeParamsFor>(true), - TypeParams{ast_vec3, false}, - TypeParams{ast_mat3x3, false}, + TypeParamsFor>(false), + TypeParamsFor>(false), - TypeParams{ast_alias>, false}, - TypeParams{ast_alias>, false}}; + TypeParamsFor>>(false), + TypeParamsFor>>(false), +}; using MultisampledTextureTypeTest = ResolverTestWithParam; TEST_P(MultisampledTextureTypeTest, All) { auto& params = GetParam(); - Global( - Source{{12, 34}}, "a", - ty.multisampled_texture(ast::TextureDimension::k2d, params.type_func(ty)), - ast::StorageClass::kNone, nullptr, - ast::DecorationList{ - create(0), - create(0), - }); + Global(Source{{12, 34}}, "a", + ty.multisampled_texture(ast::TextureDimension::k2d, + params.type_func(*this)), + ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(0), + create(0), + }); if (params.is_valid) { EXPECT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc index f1afe46b81..b7400ee59b 100644 --- a/src/writer/spirv/builder_accessor_expression_test.cc +++ b/src/writer/spirv/builder_accessor_expression_test.cc @@ -378,10 +378,11 @@ TEST_F(BuilderTest, MemberAccessor_Nested_NonPointer) { } TEST_F(BuilderTest, MemberAccessor_Nested_WithAlias) { - // type Inner = struct { + // struct Inner { // a : f32 // b : f32 - // } + // }; + // type Alias = Inner; // my_struct { // inner : Inner // } @@ -393,7 +394,8 @@ TEST_F(BuilderTest, MemberAccessor_Nested_WithAlias) { Member("b", ty.f32()), }); - auto* alias = ty.alias("Inner", inner_struct); + auto* alias = ty.alias("Alias", inner_struct); + AST().AddConstructedType(alias); auto* s_type = Structure("Outer", {Member("inner", alias)}); auto* var = Var("ident", s_type); diff --git a/src/writer/wgsl/generator_impl_alias_type_test.cc b/src/writer/wgsl/generator_impl_alias_type_test.cc index a1ecefa400..4b7311bc85 100644 --- a/src/writer/wgsl/generator_impl_alias_type_test.cc +++ b/src/writer/wgsl/generator_impl_alias_type_test.cc @@ -23,6 +23,7 @@ using WgslGeneratorImplTest = TestHelper; TEST_F(WgslGeneratorImplTest, EmitAlias_F32) { auto* alias = ty.alias("a", ty.f32()); + AST().AddConstructedType(alias); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.EmitConstructedType(alias)) << gen.error(); @@ -37,6 +38,7 @@ TEST_F(WgslGeneratorImplTest, EmitConstructedType_Struct) { }); auto* alias = ty.alias("B", s); + AST().AddConstructedType(alias); GeneratorImpl& gen = Build(); @@ -57,6 +59,7 @@ TEST_F(WgslGeneratorImplTest, EmitAlias_ToStruct) { }); auto* alias = ty.alias("B", s); + AST().AddConstructedType(alias); GeneratorImpl& gen = Build(); diff --git a/src/writer/wgsl/generator_impl_type_test.cc b/src/writer/wgsl/generator_impl_type_test.cc index 5bb5abd17f..5494e7b907 100644 --- a/src/writer/wgsl/generator_impl_type_test.cc +++ b/src/writer/wgsl/generator_impl_type_test.cc @@ -250,6 +250,7 @@ struct S { TEST_F(WgslGeneratorImplTest, EmitType_U32) { auto* u32 = ty.u32(); + AST().AddConstructedType(ty.alias("make_type_reachable", u32)); GeneratorImpl& gen = Build(); @@ -402,6 +403,11 @@ TEST_P(WgslGenerator_StorageTextureTest, EmitType_StorageTexture) { auto param = GetParam(); auto* t = ty.storage_texture(param.dim, param.fmt, param.access); + Global("g", t, + ast::DecorationList{ + create(1), + create(2), + }); GeneratorImpl& gen = Build(); @@ -412,30 +418,30 @@ INSTANTIATE_TEST_SUITE_P( WgslGeneratorImplTest, WgslGenerator_StorageTextureTest, testing::Values( - StorageTextureData{ast::ImageFormat::kR8Unorm, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k1d, ast::Access::kRead, - "texture_storage_1d"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_1d"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k2d, ast::Access::kRead, - "texture_storage_2d"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_2d"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k2dArray, ast::Access::kRead, - "texture_storage_2d_array"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_2d_array"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k3d, ast::Access::kRead, - "texture_storage_3d"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_3d"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k1d, ast::Access::kWrite, - "texture_storage_1d"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_1d"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k2d, ast::Access::kWrite, - "texture_storage_2d"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_2d"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k2dArray, ast::Access::kWrite, - "texture_storage_2d_array"}, - StorageTextureData{ast::ImageFormat::kR8Unorm, + "texture_storage_2d_array"}, + StorageTextureData{ast::ImageFormat::kRgba8Sint, ast::TextureDimension::k3d, ast::Access::kWrite, - "texture_storage_3d"})); + "texture_storage_3d"})); struct ImageFormatData { ast::ImageFormat fmt;