// Copyright 2020 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/resolver/resolver.h" #include #include "gmock/gmock.h" #include "gtest/gtest-spi.h" #include "src/tint/ast/assignment_statement.h" #include "src/tint/ast/bitcast_expression.h" #include "src/tint/ast/break_statement.h" #include "src/tint/ast/builtin_texture_helper_test.h" #include "src/tint/ast/call_statement.h" #include "src/tint/ast/continue_statement.h" #include "src/tint/ast/float_literal_expression.h" #include "src/tint/ast/id_attribute.h" #include "src/tint/ast/if_statement.h" #include "src/tint/ast/loop_statement.h" #include "src/tint/ast/return_statement.h" #include "src/tint/ast/stage_attribute.h" #include "src/tint/ast/switch_statement.h" #include "src/tint/ast/unary_op_expression.h" #include "src/tint/ast/variable_decl_statement.h" #include "src/tint/ast/workgroup_attribute.h" #include "src/tint/resolver/resolver_test_helper.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/module.h" #include "src/tint/sem/reference.h" #include "src/tint/sem/sampled_texture.h" #include "src/tint/sem/statement.h" #include "src/tint/sem/switch_statement.h" #include "src/tint/sem/variable.h" using ::testing::ElementsAre; using ::testing::HasSubstr; using namespace tint::number_suffixes; // NOLINT namespace tint::resolver { namespace { // Helpers and typedefs template using DataType = 
builder::DataType; template using vec = builder::vec; template using vec2 = builder::vec2; template using vec3 = builder::vec3; template using vec4 = builder::vec4; template using mat = builder::mat; template using mat2x2 = builder::mat2x2; template using mat2x3 = builder::mat2x3; template using mat3x2 = builder::mat3x2; template using mat3x3 = builder::mat3x3; template using mat4x4 = builder::mat4x4; template using alias = builder::alias; template using alias1 = builder::alias1; template using alias2 = builder::alias2; template using alias3 = builder::alias3; using Op = ast::BinaryOp; TEST_F(ResolverTest, Stmt_Assign) { auto* v = Var("v", ty.f32()); auto* lhs = Expr("v"); auto* rhs = Expr(2.3_f); auto* assign = Assign(lhs, rhs); WrapInFunction(v, assign); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(lhs), nullptr); ASSERT_NE(TypeOf(rhs), nullptr); EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(rhs), assign); } TEST_F(ResolverTest, Stmt_Case) { auto* v = Var("v", ty.f32()); auto* lhs = Expr("v"); auto* rhs = Expr(2.3_f); auto* assign = Assign(lhs, rhs); auto* block = Block(assign); auto* sel = CaseSelector(3_i); auto* cse = Case(sel, block); auto* def = DefaultCase(); auto* cond_var = Var("c", ty.i32()); auto* sw = Switch(cond_var, cse, def); WrapInFunction(v, cond_var, sw); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(lhs), nullptr); ASSERT_NE(TypeOf(rhs), nullptr); EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(BlockOf(assign), block); auto* sem = Sem().Get(sw); ASSERT_EQ(sem->Cases().size(), 2u); EXPECT_EQ(sem->Cases()[0]->Declaration(), cse); ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u); EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 1u); } TEST_F(ResolverTest, Stmt_Case_AddressOf_Invalid) { auto* cond_var = Var("i", ty.i32()); 
WrapInFunction(cond_var, Switch("i", Case(CaseSelector(AddressOf(1_a)), Block()))); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "error: cannot take the address of expression"); } TEST_F(ResolverTest, Stmt_Block) { auto* v = Var("v", ty.f32()); auto* lhs = Expr("v"); auto* rhs = Expr(2.3_f); auto* assign = Assign(lhs, rhs); auto* block = Block(assign); WrapInFunction(v, block); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(lhs), nullptr); ASSERT_NE(TypeOf(rhs), nullptr); EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(BlockOf(lhs), block); EXPECT_EQ(BlockOf(rhs), block); EXPECT_EQ(BlockOf(assign), block); } TEST_F(ResolverTest, Stmt_If) { auto* v = Var("v", ty.f32()); auto* else_lhs = Expr("v"); auto* else_rhs = Expr(2.3_f); auto* else_body = Block(Assign(else_lhs, else_rhs)); auto* else_cond = Expr(true); auto* else_stmt = If(else_cond, else_body); auto* lhs = Expr("v"); auto* rhs = Expr(2.3_f); auto* assign = Assign(lhs, rhs); auto* body = Block(assign); auto* cond = Expr(true); auto* stmt = If(cond, body, Else(else_stmt)); WrapInFunction(v, stmt); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(stmt->condition), nullptr); ASSERT_NE(TypeOf(else_lhs), nullptr); ASSERT_NE(TypeOf(else_rhs), nullptr); ASSERT_NE(TypeOf(lhs), nullptr); ASSERT_NE(TypeOf(rhs), nullptr); EXPECT_TRUE(TypeOf(stmt->condition)->Is()); EXPECT_TRUE(TypeOf(else_lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(else_rhs)->Is()); EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(StmtOf(cond), stmt); EXPECT_EQ(StmtOf(else_cond), else_stmt); EXPECT_EQ(BlockOf(lhs), body); EXPECT_EQ(BlockOf(rhs), body); EXPECT_EQ(BlockOf(else_lhs), else_body); EXPECT_EQ(BlockOf(else_rhs), else_body); } TEST_F(ResolverTest, Stmt_Loop) { auto* v = Var("v", ty.f32()); auto* body_lhs = 
Expr("v"); auto* body_rhs = Expr(2.3_f); auto* body = Block(Assign(body_lhs, body_rhs), Break()); auto* continuing_lhs = Expr("v"); auto* continuing_rhs = Expr(2.3_f); auto* break_if = BreakIf(false); auto* continuing = Block(Assign(continuing_lhs, continuing_rhs), break_if); auto* stmt = Loop(body, continuing); WrapInFunction(v, stmt); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(body_lhs), nullptr); ASSERT_NE(TypeOf(body_rhs), nullptr); ASSERT_NE(TypeOf(continuing_lhs), nullptr); ASSERT_NE(TypeOf(continuing_rhs), nullptr); EXPECT_TRUE(TypeOf(body_lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(body_rhs)->Is()); EXPECT_TRUE(TypeOf(continuing_lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(continuing_rhs)->Is()); EXPECT_EQ(BlockOf(body_lhs), body); EXPECT_EQ(BlockOf(body_rhs), body); EXPECT_EQ(BlockOf(continuing_lhs), continuing); EXPECT_EQ(BlockOf(continuing_rhs), continuing); EXPECT_EQ(BlockOf(break_if), continuing); } TEST_F(ResolverTest, Stmt_Return) { auto* cond = Expr(2_i); auto* ret = Return(cond); Func("test", utils::Empty, ty.i32(), utils::Vector{ret}, utils::Empty); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(cond), nullptr); EXPECT_TRUE(TypeOf(cond)->Is()); } TEST_F(ResolverTest, Stmt_Return_WithoutValue) { auto* ret = Return(); WrapInFunction(ret); EXPECT_TRUE(r()->Resolve()) << r()->error(); } TEST_F(ResolverTest, Stmt_Switch) { auto* v = Var("v", ty.f32()); auto* lhs = Expr("v"); auto* rhs = Expr(2.3_f); auto* case_block = Block(Assign(lhs, rhs)); auto* stmt = Switch(Expr(2_i), Case(CaseSelector(3_i), case_block), DefaultCase()); WrapInFunction(v, stmt); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(stmt->condition), nullptr); ASSERT_NE(TypeOf(lhs), nullptr); ASSERT_NE(TypeOf(rhs), nullptr); EXPECT_TRUE(TypeOf(stmt->condition)->Is()); EXPECT_TRUE(TypeOf(lhs)->UnwrapRef()->Is()); EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(BlockOf(lhs), case_block); EXPECT_EQ(BlockOf(rhs), case_block); } TEST_F(ResolverTest, 
Stmt_Call) { Func("my_func", utils::Empty, ty.void_(), utils::Vector{ Return(), }); auto* expr = Call("my_func"); auto* call = CallStmt(expr); WrapInFunction(call); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(expr), nullptr); EXPECT_TRUE(TypeOf(expr)->Is()); EXPECT_EQ(StmtOf(expr), call); } TEST_F(ResolverTest, Stmt_VariableDecl) { auto* var = Var("my_var", ty.i32(), Expr(2_i)); auto* init = var->initializer; auto* decl = Decl(var); WrapInFunction(decl); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(init), nullptr); EXPECT_TRUE(TypeOf(init)->Is()); } TEST_F(ResolverTest, Stmt_VariableDecl_Alias) { auto* my_int = Alias("MyInt", ty.i32()); auto* var = Var("my_var", ty.Of(my_int), Expr(2_i)); auto* init = var->initializer; auto* decl = Decl(var); WrapInFunction(decl); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(init), nullptr); EXPECT_TRUE(TypeOf(init)->Is()); } TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScope) { auto* init = Expr(2_i); GlobalVar("my_var", ty.i32(), ast::AddressSpace::kPrivate, init); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(init), nullptr); EXPECT_TRUE(TypeOf(init)->Is()); EXPECT_EQ(StmtOf(init), nullptr); } TEST_F(ResolverTest, Stmt_VariableDecl_OuterScopeAfterInnerScope) { // fn func_i32() { // { // var foo : i32 = 2; // var bar : i32 = foo; // } // var foo : f32 = 2.0; // var bar : f32 = foo; // } // Declare i32 "foo" inside a block auto* foo_i32 = Var("foo", ty.i32(), Expr(2_i)); auto* foo_i32_init = foo_i32->initializer; auto* foo_i32_decl = Decl(foo_i32); // Reference "foo" inside the block auto* bar_i32 = Var("bar", ty.i32(), Expr("foo")); auto* bar_i32_init = bar_i32->initializer; auto* bar_i32_decl = Decl(bar_i32); auto* inner = Block(foo_i32_decl, bar_i32_decl); // Declare f32 "foo" at function scope auto* foo_f32 = Var("foo", ty.f32(), Expr(2_f)); auto* foo_f32_init = foo_f32->initializer; auto* foo_f32_decl = Decl(foo_f32); // Reference "foo" at function scope 
auto* bar_f32 = Var("bar", ty.f32(), Expr("foo")); auto* bar_f32_init = bar_f32->initializer; auto* bar_f32_decl = Decl(bar_f32); Func("func", utils::Empty, ty.void_(), utils::Vector{inner, foo_f32_decl, bar_f32_decl}); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(foo_i32_init), nullptr); EXPECT_TRUE(TypeOf(foo_i32_init)->Is()); ASSERT_NE(TypeOf(foo_f32_init), nullptr); EXPECT_TRUE(TypeOf(foo_f32_init)->Is()); ASSERT_NE(TypeOf(bar_i32_init), nullptr); EXPECT_TRUE(TypeOf(bar_i32_init)->UnwrapRef()->Is()); ASSERT_NE(TypeOf(bar_f32_init), nullptr); EXPECT_TRUE(TypeOf(bar_f32_init)->UnwrapRef()->Is()); EXPECT_EQ(StmtOf(foo_i32_init), foo_i32_decl); EXPECT_EQ(StmtOf(bar_i32_init), bar_i32_decl); EXPECT_EQ(StmtOf(foo_f32_init), foo_f32_decl); EXPECT_EQ(StmtOf(bar_f32_init), bar_f32_decl); EXPECT_TRUE(CheckVarUsers(foo_i32, utils::Vector{bar_i32->initializer})); EXPECT_TRUE(CheckVarUsers(foo_f32, utils::Vector{bar_f32->initializer})); ASSERT_NE(VarOf(bar_i32->initializer), nullptr); EXPECT_EQ(VarOf(bar_i32->initializer)->Declaration(), foo_i32); ASSERT_NE(VarOf(bar_f32->initializer), nullptr); EXPECT_EQ(VarOf(bar_f32->initializer)->Declaration(), foo_f32); } TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScopeAfterFunctionScope) { // fn func_i32() { // var foo : i32 = 2; // } // var foo : f32 = 2.0; // fn func_f32() { // var bar : f32 = foo; // } // Declare i32 "foo" inside a function auto* fn_i32 = Var("foo", ty.i32(), Expr(2_i)); auto* fn_i32_init = fn_i32->initializer; auto* fn_i32_decl = Decl(fn_i32); Func("func_i32", utils::Empty, ty.void_(), utils::Vector{fn_i32_decl}); // Declare f32 "foo" at module scope auto* mod_f32 = Var("foo", ty.f32(), ast::AddressSpace::kPrivate, Expr(2_f)); auto* mod_init = mod_f32->initializer; AST().AddGlobalVariable(mod_f32); // Reference "foo" in another function auto* fn_f32 = Var("bar", ty.f32(), Expr("foo")); auto* fn_f32_init = fn_f32->initializer; auto* fn_f32_decl = Decl(fn_f32); Func("func_f32", utils::Empty, 
ty.void_(), utils::Vector{fn_f32_decl}); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(mod_init), nullptr); EXPECT_TRUE(TypeOf(mod_init)->Is()); ASSERT_NE(TypeOf(fn_i32_init), nullptr); EXPECT_TRUE(TypeOf(fn_i32_init)->Is()); ASSERT_NE(TypeOf(fn_f32_init), nullptr); EXPECT_TRUE(TypeOf(fn_f32_init)->UnwrapRef()->Is()); EXPECT_EQ(StmtOf(fn_i32_init), fn_i32_decl); EXPECT_EQ(StmtOf(mod_init), nullptr); EXPECT_EQ(StmtOf(fn_f32_init), fn_f32_decl); EXPECT_TRUE(CheckVarUsers(fn_i32, utils::Empty)); EXPECT_TRUE(CheckVarUsers(mod_f32, utils::Vector{fn_f32->initializer})); ASSERT_NE(VarOf(fn_f32->initializer), nullptr); EXPECT_EQ(VarOf(fn_f32->initializer)->Declaration(), mod_f32); } TEST_F(ResolverTest, ArraySize_UnsignedLiteral) { // var a : array; auto* a = GlobalVar("a", ty.array(ty.f32(), Expr(10_u)), ast::AddressSpace::kPrivate); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(a), nullptr); auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); } TEST_F(ResolverTest, ArraySize_SignedLiteral) { // var a : array; auto* a = GlobalVar("a", ty.array(ty.f32(), Expr(10_i)), ast::AddressSpace::kPrivate); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(a), nullptr); auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); } TEST_F(ResolverTest, ArraySize_UnsignedConst) { // const size = 10u; // var a : array; GlobalConst("size", Expr(10_u)); auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kPrivate); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(a), nullptr); auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); } TEST_F(ResolverTest, ArraySize_SignedConst) { // const size = 0; // var a : array; GlobalConst("size", 
Expr(10_i)); auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kPrivate); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(a), nullptr); auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); } TEST_F(ResolverTest, ArraySize_Override) { // override size = 0; // var a : array; auto* override = Override("size", Expr(10_i)); auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kWorkgroup); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(a), nullptr); auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); auto* sem_override = Sem().Get(override); ASSERT_NE(sem_override, nullptr); EXPECT_EQ(ary->Count(), sem::OverrideArrayCount{sem_override}); } TEST_F(ResolverTest, ArraySize_Override_Equivalence) { // override size = 0; // var a : array; // var b : array; auto* override = Override("size", Expr(10_i)); auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kWorkgroup); auto* b = GlobalVar("b", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kWorkgroup); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(a), nullptr); auto* ref_a = TypeOf(a)->As(); ASSERT_NE(ref_a, nullptr); auto* ary_a = ref_a->StoreType()->As(); ASSERT_NE(TypeOf(b), nullptr); auto* ref_b = TypeOf(b)->As(); ASSERT_NE(ref_b, nullptr); auto* ary_b = ref_b->StoreType()->As(); auto* sem_override = Sem().Get(override); ASSERT_NE(sem_override, nullptr); EXPECT_EQ(ary_a->Count(), sem::OverrideArrayCount{sem_override}); EXPECT_EQ(ary_b->Count(), sem::OverrideArrayCount{sem_override}); EXPECT_EQ(ary_a, ary_b); } TEST_F(ResolverTest, Expr_Bitcast) { GlobalVar("name", ty.f32(), ast::AddressSpace::kPrivate); auto* bitcast = create(ty.f32(), Expr("name")); WrapInFunction(bitcast); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(bitcast), nullptr); 
EXPECT_TRUE(TypeOf(bitcast)->Is()); } TEST_F(ResolverTest, Expr_Call) { Func("my_func", utils::Empty, ty.f32(), utils::Vector{Return(0_f)}); auto* call = Call("my_func"); WrapInFunction(call); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(call), nullptr); EXPECT_TRUE(TypeOf(call)->Is()); } TEST_F(ResolverTest, Expr_Call_InBinaryOp) { Func("func", utils::Empty, ty.f32(), utils::Vector{Return(0_f)}); auto* expr = Add(Call("func"), Call("func")); WrapInFunction(expr); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(expr), nullptr); EXPECT_TRUE(TypeOf(expr)->Is()); } TEST_F(ResolverTest, Expr_Call_WithParams) { Func("my_func", utils::Vector{Param(Sym(), ty.f32())}, ty.f32(), utils::Vector{ Return(1.2_f), }); auto* param = Expr(2.4_f); auto* call = Call("my_func", param); WrapInFunction(call); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(param), nullptr); EXPECT_TRUE(TypeOf(param)->Is()); } TEST_F(ResolverTest, Expr_Call_Builtin) { auto* call = Call("round", 2.4_f); WrapInFunction(call); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(call), nullptr); EXPECT_TRUE(TypeOf(call)->Is()); } TEST_F(ResolverTest, Expr_Cast) { GlobalVar("name", ty.f32(), ast::AddressSpace::kPrivate); auto* cast = Construct(ty.f32(), "name"); WrapInFunction(cast); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(cast), nullptr); EXPECT_TRUE(TypeOf(cast)->Is()); } TEST_F(ResolverTest, Expr_Initializer_Scalar) { auto* s = Expr(1_f); WrapInFunction(s); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(s), nullptr); EXPECT_TRUE(TypeOf(s)->Is()); } TEST_F(ResolverTest, Expr_Initializer_Type_Vec2) { auto* tc = vec2(1_f, 1_f); WrapInFunction(tc); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(tc), nullptr); ASSERT_TRUE(TypeOf(tc)->Is()); EXPECT_TRUE(TypeOf(tc)->As()->type()->Is()); EXPECT_EQ(TypeOf(tc)->As()->Width(), 2u); } TEST_F(ResolverTest, Expr_Initializer_Type_Vec3) { auto* tc = vec3(1_f, 1_f, 
1_f); WrapInFunction(tc); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(tc), nullptr); ASSERT_TRUE(TypeOf(tc)->Is()); EXPECT_TRUE(TypeOf(tc)->As()->type()->Is()); EXPECT_EQ(TypeOf(tc)->As()->Width(), 3u); } TEST_F(ResolverTest, Expr_Initializer_Type_Vec4) { auto* tc = vec4(1_f, 1_f, 1_f, 1_f); WrapInFunction(tc); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(tc), nullptr); ASSERT_TRUE(TypeOf(tc)->Is()); EXPECT_TRUE(TypeOf(tc)->As()->type()->Is()); EXPECT_EQ(TypeOf(tc)->As()->Width(), 4u); } TEST_F(ResolverTest, Expr_Identifier_GlobalVariable) { auto* my_var = GlobalVar("my_var", ty.f32(), ast::AddressSpace::kPrivate); auto* ident = Expr("my_var"); WrapInFunction(ident); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(ident), nullptr); ASSERT_TRUE(TypeOf(ident)->Is()); EXPECT_TRUE(TypeOf(ident)->UnwrapRef()->Is()); EXPECT_TRUE(CheckVarUsers(my_var, utils::Vector{ident})); ASSERT_NE(VarOf(ident), nullptr); EXPECT_EQ(VarOf(ident)->Declaration(), my_var); } TEST_F(ResolverTest, Expr_Identifier_GlobalConst) { auto* my_var = GlobalConst("my_var", ty.f32(), Construct(ty.f32())); auto* ident = Expr("my_var"); WrapInFunction(ident); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(ident), nullptr); EXPECT_TRUE(TypeOf(ident)->Is()); EXPECT_TRUE(CheckVarUsers(my_var, utils::Vector{ident})); ASSERT_NE(VarOf(ident), nullptr); EXPECT_EQ(VarOf(ident)->Declaration(), my_var); } TEST_F(ResolverTest, Expr_Identifier_FunctionVariable_Const) { auto* my_var_a = Expr("my_var"); auto* var = Let("my_var", ty.f32(), Construct(ty.f32())); auto* decl = Decl(Var("b", ty.f32(), my_var_a)); Func("my_func", utils::Empty, ty.void_(), utils::Vector{ Decl(var), decl, }); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(my_var_a), nullptr); EXPECT_TRUE(TypeOf(my_var_a)->Is()); EXPECT_EQ(StmtOf(my_var_a), decl); EXPECT_TRUE(CheckVarUsers(var, utils::Vector{my_var_a})); ASSERT_NE(VarOf(my_var_a), nullptr); 
EXPECT_EQ(VarOf(my_var_a)->Declaration(), var); } TEST_F(ResolverTest, IndexAccessor_Dynamic_Ref_F32) { // var a : array = 0; // var idx : f32 = f32(); // var f : f32 = a[idx]; auto* a = Var("a", ty.array(), array()); auto* idx = Var("idx", ty.f32(), Construct(ty.f32())); auto* f = Var("f", ty.f32(), IndexAccessor("a", Expr(Source{{12, 34}}, idx))); Func("my_func", utils::Empty, ty.void_(), utils::Vector{ Decl(a), Decl(idx), Decl(f), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: index must be of type 'i32' or 'u32', found: 'f32'"); } TEST_F(ResolverTest, Expr_Identifier_FunctionVariable) { auto* my_var_a = Expr("my_var"); auto* my_var_b = Expr("my_var"); auto* assign = Assign(my_var_a, my_var_b); auto* var = Var("my_var", ty.f32()); Func("my_func", utils::Empty, ty.void_(), utils::Vector{ Decl(var), assign, }); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(my_var_a), nullptr); ASSERT_TRUE(TypeOf(my_var_a)->Is()); EXPECT_TRUE(TypeOf(my_var_a)->UnwrapRef()->Is()); EXPECT_EQ(StmtOf(my_var_a), assign); ASSERT_NE(TypeOf(my_var_b), nullptr); ASSERT_TRUE(TypeOf(my_var_b)->Is()); EXPECT_TRUE(TypeOf(my_var_b)->UnwrapRef()->Is()); EXPECT_EQ(StmtOf(my_var_b), assign); EXPECT_TRUE(CheckVarUsers(var, utils::Vector{my_var_a, my_var_b})); ASSERT_NE(VarOf(my_var_a), nullptr); EXPECT_EQ(VarOf(my_var_a)->Declaration(), var); ASSERT_NE(VarOf(my_var_b), nullptr); EXPECT_EQ(VarOf(my_var_b)->Declaration(), var); } TEST_F(ResolverTest, Expr_Identifier_Function_Ptr) { auto* v = Expr("v"); auto* p = Expr("p"); auto* v_decl = Decl(Var("v", ty.f32())); auto* p_decl = Decl(Let("p", ty.pointer(ast::AddressSpace::kFunction), AddressOf(v))); auto* assign = Assign(Deref(p), 1.23_f); Func("my_func", utils::Empty, ty.void_(), utils::Vector{ v_decl, p_decl, assign, }); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(v), nullptr); ASSERT_TRUE(TypeOf(v)->Is()); EXPECT_TRUE(TypeOf(v)->UnwrapRef()->Is()); EXPECT_EQ(StmtOf(v), p_decl); 
ASSERT_NE(TypeOf(p), nullptr); ASSERT_TRUE(TypeOf(p)->Is()); EXPECT_TRUE(TypeOf(p)->UnwrapPtr()->Is()); EXPECT_EQ(StmtOf(p), assign); } TEST_F(ResolverTest, Expr_Call_Function) { Func("my_func", utils::Empty, ty.f32(), utils::Vector{ Return(0_f), }); auto* call = Call("my_func"); WrapInFunction(call); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(call), nullptr); EXPECT_TRUE(TypeOf(call)->Is()); } TEST_F(ResolverTest, Expr_Identifier_Unknown) { auto* a = Expr("a"); WrapInFunction(a); EXPECT_FALSE(r()->Resolve()); } TEST_F(ResolverTest, Function_Parameters) { auto* param_a = Param("a", ty.f32()); auto* param_b = Param("b", ty.i32()); auto* param_c = Param("c", ty.u32()); auto* func = Func("my_func", utils::Vector{ param_a, param_b, param_c, }, ty.void_(), utils::Empty); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); EXPECT_EQ(func_sem->Parameters().Length(), 3u); EXPECT_TRUE(func_sem->Parameters()[0]->Type()->Is()); EXPECT_TRUE(func_sem->Parameters()[1]->Type()->Is()); EXPECT_TRUE(func_sem->Parameters()[2]->Type()->Is()); EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a); EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b); EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c); EXPECT_TRUE(func_sem->ReturnType()->Is()); } TEST_F(ResolverTest, Function_Parameters_Locations) { auto* param_a = Param("a", ty.f32(), utils::Vector{Location(3_a)}); auto* param_b = Param("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)}); auto* param_c = Param("c", ty.u32(), utils::Vector{Location(1_a)}); GlobalVar("my_vec", ty.vec4(), ast::AddressSpace::kPrivate); auto* func = Func("my_func", utils::Vector{ param_a, param_b, param_c, }, ty.vec4(), utils::Vector{ Return("my_vec"), }, utils::Vector{ Stage(ast::PipelineStage::kVertex), }, utils::Vector{ Builtin(ast::BuiltinValue::kPosition), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* func_sem = 
Sem().Get(func); ASSERT_NE(func_sem, nullptr); EXPECT_EQ(func_sem->Parameters().Length(), 3u); EXPECT_EQ(3u, func_sem->Parameters()[0]->Location()); EXPECT_FALSE(func_sem->Parameters()[1]->Location().has_value()); EXPECT_EQ(1u, func_sem->Parameters()[2]->Location()); } TEST_F(ResolverTest, Function_GlobalVariable_Location) { auto* var = GlobalVar( "my_vec", ty.vec4(), ast::AddressSpace::kIn, utils::Vector{Location(3_a), Disable(ast::DisabledValidation::kIgnoreAddressSpace)}); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(var); ASSERT_NE(sem, nullptr); EXPECT_EQ(3u, sem->Location()); } TEST_F(ResolverTest, Function_RegisterInputOutputVariables) { auto* s = Structure("S", utils::Vector{Member("m", ty.u32())}); auto* sb_var = GlobalVar("sb_var", ty.Of(s), ast::AddressSpace::kStorage, ast::Access::kReadWrite, Binding(0_a), Group(0_a)); auto* wg_var = GlobalVar("wg_var", ty.f32(), ast::AddressSpace::kWorkgroup); auto* priv_var = GlobalVar("priv_var", ty.f32(), ast::AddressSpace::kPrivate); auto* func = Func("my_func", utils::Empty, ty.void_(), utils::Vector{ Assign("wg_var", "wg_var"), Assign("sb_var", "sb_var"), Assign("priv_var", "priv_var"), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); EXPECT_EQ(func_sem->Parameters().Length(), 0u); EXPECT_TRUE(func_sem->ReturnType()->Is()); const auto& vars = func_sem->TransitivelyReferencedGlobals(); ASSERT_EQ(vars.Length(), 3u); EXPECT_EQ(vars[0]->Declaration(), wg_var); EXPECT_EQ(vars[1]->Declaration(), sb_var); EXPECT_EQ(vars[2]->Declaration(), priv_var); } TEST_F(ResolverTest, Function_ReturnType_Location) { auto* func = Func("my_func", utils::Empty, ty.f32(), utils::Vector{ Return(1_f), }, utils::Vector{ Stage(ast::PipelineStage::kFragment), }, utils::Vector{ Location(2_a), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(func); ASSERT_NE(nullptr, sem); EXPECT_EQ(2u, sem->ReturnLocation()); } TEST_F(ResolverTest, 
Function_ReturnType_NoLocation) { GlobalVar("my_vec", ty.vec4(), ast::AddressSpace::kPrivate); auto* func = Func("my_func", utils::Empty, ty.vec4(), utils::Vector{ Return("my_vec"), }, utils::Vector{ Stage(ast::PipelineStage::kVertex), }, utils::Vector{ Builtin(ast::BuiltinValue::kPosition), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(func); ASSERT_NE(nullptr, sem); EXPECT_FALSE(sem->ReturnLocation()); } TEST_F(ResolverTest, Function_RegisterInputOutputVariables_SubFunction) { auto* s = Structure("S", utils::Vector{Member("m", ty.u32())}); auto* sb_var = GlobalVar("sb_var", ty.Of(s), ast::AddressSpace::kStorage, ast::Access::kReadWrite, Binding(0_a), Group(0_a)); auto* wg_var = GlobalVar("wg_var", ty.f32(), ast::AddressSpace::kWorkgroup); auto* priv_var = GlobalVar("priv_var", ty.f32(), ast::AddressSpace::kPrivate); Func("my_func", utils::Empty, ty.f32(), utils::Vector{Assign("wg_var", "wg_var"), Assign("sb_var", "sb_var"), Assign("priv_var", "priv_var"), Return(0_f)}); auto* func2 = Func("func", utils::Empty, ty.void_(), utils::Vector{ WrapInStatement(Call("my_func")), }, utils::Empty); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* func2_sem = Sem().Get(func2); ASSERT_NE(func2_sem, nullptr); EXPECT_EQ(func2_sem->Parameters().Length(), 0u); const auto& vars = func2_sem->TransitivelyReferencedGlobals(); ASSERT_EQ(vars.Length(), 3u); EXPECT_EQ(vars[0]->Declaration(), wg_var); EXPECT_EQ(vars[1]->Declaration(), sb_var); EXPECT_EQ(vars[2]->Declaration(), priv_var); } TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) { auto* func = Func("my_func", utils::Empty, ty.void_(), utils::Vector{ Decl(Var("var", ty.f32())), Assign("var", 1_f), }); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().Length(), 0u); EXPECT_TRUE(func_sem->ReturnType()->Is()); } TEST_F(ResolverTest, Function_NotRegisterFunctionConstant) { auto* func = 
Func("my_func", utils::Empty, ty.void_(),
         utils::Vector{
             Decl(Let("var", ty.f32(), Construct(ty.f32()))),
         });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().Length(), 0u);
    EXPECT_TRUE(func_sem->ReturnType()->Is());
}

// A function parameter (here named `var`) must not be recorded as a
// transitively-referenced global of the function.
TEST_F(ResolverTest, Function_NotRegisterFunctionParams) {
    auto* func = Func("my_func", utils::Vector{Param("var", ty.f32())}, ty.void_(), utils::Empty);
    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().Length(), 0u);
    EXPECT_TRUE(func_sem->ReturnType()->Is());
}

// `bar` calls `foo` twice: foo's CallSites() must list both call expressions in
// source order, and bar (never called) must have no call sites.
TEST_F(ResolverTest, Function_CallSites) {
    auto* foo = Func("foo", utils::Empty, ty.void_(), utils::Empty);

    auto* call_1 = Call("foo");
    auto* call_2 = Call("foo");
    auto* bar = Func("bar", utils::Empty, ty.void_(),
                     utils::Vector{
                         CallStmt(call_1),
                         CallStmt(call_2),
                     });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* foo_sem = Sem().Get(foo);
    ASSERT_NE(foo_sem, nullptr);
    ASSERT_EQ(foo_sem->CallSites().size(), 2u);
    EXPECT_EQ(foo_sem->CallSites()[0]->Declaration(), call_1);
    EXPECT_EQ(foo_sem->CallSites()[1]->Declaration(), call_2);

    auto* bar_sem = Sem().Get(bar);
    ASSERT_NE(bar_sem, nullptr);
    EXPECT_EQ(bar_sem->CallSites().size(), 0u);
}

TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
    // fn main() {}
    //
    // The function has no @compute / @workgroup_size attributes; every
    // WorkgroupSize() component is expected to read back as 1.
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], 1u);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], 1u);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u);
}

TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
    // @compute @workgroup_size(8, 2, 3)
    // fn main() {}
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize(8_i, 2_i, 3_i),
                      });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], 2u);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u);
}

TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) {
    // const width = 16i;
    // const height = 8i;
    // const depth = 2i;
    // @compute @workgroup_size(width, height, depth)
    // fn main() {}
    //
    // Module-scope `const` sizes are expected to fold to concrete values.
    GlobalConst("width", ty.i32(), Expr(16_i));
    GlobalConst("height", ty.i32(), Expr(8_i));
    GlobalConst("depth", ty.i32(), Expr(2_i));
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize("width", "height", "depth"),
                      });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], 16u);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], 8u);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], 2u);
}

TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) {
    // const width = i32(i32(i32(8i)));
    // const height = i32(i32(i32(4i)));
    // @compute @workgroup_size(width, height)
    // fn main() {}
    //
    // Only two dimensions are given, so the third is expected to read back as 1.
    GlobalConst("width", ty.i32(),
                Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8_i))));
    GlobalConst("height", ty.i32(),
                Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 4_i))));
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize("width", "height"),
                      });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], 4u);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u);
}

TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
    // @id(0) override width = 16i;
    // @id(1) override height = 8i;
    // @id(2) override depth = 2i;
    // @compute @workgroup_size(width, height, depth)
    // fn main() {}
    //
    // Dimensions that come from `override` declarations are expected to be left
    // unset (std::nullopt), even when the override has an initializer.
    Override("width", ty.i32(), Expr(16_i), Id(0_a));
    Override("height", ty.i32(), Expr(8_i), Id(1_a));
    Override("depth", ty.i32(), Expr(2_i), Id(2_a));
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize("width", "height", "depth"),
                      });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt);
}

TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
    // @id(0) override width : i32;
    // @id(1) override height : i32;
    // @id(2) override depth : i32;
    // @compute @workgroup_size(width, height, depth)
    // fn main() {}
    Override("width", ty.i32(), Id(0_a));
    Override("height", ty.i32(), Id(1_a));
    Override("depth", ty.i32(), Id(2_a));
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize("width", "height", "depth"),
                      });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt);
}

TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
    // @id(0) override height = 2i;
    // const depth = 3i;
    // @compute @workgroup_size(8, height, depth)
    // fn main() {}
    //
    // Literal and `const` dimensions resolve to values; the `override`
    // dimension is expected to stay unset.
    Override("height", ty.i32(), Expr(2_i), Id(0_a));
    GlobalConst("depth", ty.i32(), Expr(3_i));
    auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize(8_i, "height", "depth"),
                      });

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto* func_sem = Sem().Get(func);
    ASSERT_NE(func_sem, nullptr);

    EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
    EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
    EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u);
}

// Accessing `my_struct.second_member` on a private global: checks the reference
// type, the member's index (1) and the member's declared symbol.
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
    auto* st = Structure(
        "S", utils::Vector{Member("first_member", ty.i32()), Member("second_member", ty.f32())});
    GlobalVar("my_struct", ty.Of(st), ast::AddressSpace::kPrivate);

    auto* mem = MemberAccessor("my_struct", "second_member");
    WrapInFunction(mem);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(mem), nullptr);
    ASSERT_TRUE(TypeOf(mem)->Is());

    auto* ref = TypeOf(mem)->As();
    EXPECT_TRUE(ref->StoreType()->Is());
    auto* sma = Sem().Get(mem)->As();
    ASSERT_NE(sma, nullptr);
    EXPECT_TRUE(sma->Member()->Type()->Is());
    EXPECT_EQ(sma->Object()->Declaration(), mem->structure);
    EXPECT_EQ(sma->Member()->Index(), 1u);
    EXPECT_EQ(sma->Member()->Declaration()->symbol, Symbols().Get("second_member"));
}

// Same as Expr_MemberAccessor_Struct, but the global is declared through a type
// alias of the structure.
TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
    auto* st = Structure(
        "S", utils::Vector{Member("first_member", ty.i32()), Member("second_member", ty.f32())});
    auto* alias = Alias("alias", ty.Of(st));
    GlobalVar("my_struct", ty.Of(alias), ast::AddressSpace::kPrivate);

    auto* mem = MemberAccessor("my_struct", "second_member");
    WrapInFunction(mem);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(mem), nullptr);
    ASSERT_TRUE(TypeOf(mem)->Is());

    auto* ref = TypeOf(mem)->As();
    EXPECT_TRUE(ref->StoreType()->Is());
    auto* sma = Sem().Get(mem)->As();
    ASSERT_NE(sma, nullptr);
    EXPECT_EQ(sma->Object()->Declaration(), mem->structure);
    EXPECT_TRUE(sma->Member()->Type()->Is());
    EXPECT_EQ(sma->Member()->Index(), 1u);
}

// A four-component swizzle ("xzyw") must yield a 4-wide vector and the
// component indices 0, 2, 1, 3.
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
    GlobalVar("my_vec", ty.vec4(), ast::AddressSpace::kPrivate);

    auto* mem = MemberAccessor("my_vec", "xzyw");
    WrapInFunction(mem);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(mem), nullptr);
    ASSERT_TRUE(TypeOf(mem)->Is());
    EXPECT_TRUE(TypeOf(mem)->As()->type()->Is());
    EXPECT_EQ(TypeOf(mem)->As()->Width(), 4u);
    auto* sma = Sem().Get(mem)->As();
    ASSERT_NE(sma, nullptr);
    EXPECT_EQ(sma->Object()->Declaration(), mem->structure);
    EXPECT_THAT(sma->As()->Indices(), ElementsAre(0, 2, 1, 3));
}

// A single-component swizzle using the rgba set ("b") selects index 2.
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
    GlobalVar("my_vec", ty.vec3(), ast::AddressSpace::kPrivate);

    auto* mem = MemberAccessor("my_vec", "b");
    WrapInFunction(mem);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(mem), nullptr);
    ASSERT_TRUE(TypeOf(mem)->Is());

    auto* ref = TypeOf(mem)->As();
    ASSERT_TRUE(ref->StoreType()->Is());
    auto* sma = Sem().Get(mem)->As();
    ASSERT_NE(sma, nullptr);
    EXPECT_EQ(sma->Object()->Declaration(), mem->structure);
    EXPECT_THAT(Sem().Get(mem)->As()->Indices(), ElementsAre(2));
}

TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
    // struct B {
    //   vec4 foo
    // }
    // struct A {
    //   array mem
    // }
    // var c : A
    // c.mem[0].foo.yx
    //   -> vec2
    //
    // fn f() {
    //   c.mem[0].foo.yx
    // }
    //

    auto* stB = Structure("B", utils::Vector{Member("foo", ty.vec4())});
    auto* stA = Structure("A", utils::Vector{Member("mem", ty.array(ty.Of(stB), 3_i))});
    GlobalVar("c", ty.Of(stA), ast::AddressSpace::kPrivate);

    // Member access -> index access -> member access -> two-component swizzle.
    auto* mem =
        MemberAccessor(MemberAccessor(IndexAccessor(MemberAccessor("c", "mem"), 0_i), "foo"), "yx");
    WrapInFunction(mem);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(mem), nullptr);
    ASSERT_TRUE(TypeOf(mem)->Is());
    EXPECT_TRUE(TypeOf(mem)->As()->type()->Is());
    EXPECT_EQ(TypeOf(mem)->As()->Width(), 2u);
    ASSERT_TRUE(Sem().Get(mem)->Is());
}

// Member accesses used as both operands of a binary `+` must resolve.
TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
    auto* st = Structure(
        "S", utils::Vector{Member("first_member", ty.f32()), Member("second_member", ty.f32())});
    GlobalVar("my_struct", ty.Of(st), ast::AddressSpace::kPrivate);

    auto* expr = Add(MemberAccessor("my_struct", "first_member"),
                     MemberAccessor("my_struct", "second_member"));
    WrapInFunction(expr);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(expr), nullptr);
    EXPECT_TRUE(TypeOf(expr)->Is());
}

namespace ExprBinaryTest {

// Aliased<...>::type is presumably the aliased counterpart of a type: scalars
// become an alias type, and the vec/mat partial specializations below keep the
// shape while aliasing the element type. TODO confirm against the helper header.
template 
struct Aliased {
    using type = alias;
};

template 
struct Aliased, ID> {
    using type = vec>;
};

template 
struct Aliased, ID> {
    using type = mat>;
};

// One `lhs op rhs` test case, described by type-factory callbacks rather than
// concrete types so the table below can be constexpr.
struct Params {
    ast::BinaryOp op;
    builder::ast_type_func_ptr create_lhs_type;
    builder::ast_type_func_ptr create_rhs_type;
    // Alias variants of the operand types, used by Expr_Binary_Test_WithAlias_Valid.
    builder::ast_type_func_ptr create_lhs_alias_type;
    builder::ast_type_func_ptr create_rhs_alias_type;
    // Expected semantic type of the whole binary expression.
    builder::sem_type_func_ptr create_result_type;
};

// Builds a Params entry from the operand/result types given as template
// arguments; the alias factories come from Aliased<...>::type.
template 
constexpr Params ParamsFor(ast::BinaryOp op) {
    return Params{op,
                  DataType::AST,
                  DataType::AST,
                  DataType::type>::AST,
                  DataType::type>::AST,
                  DataType::Sem};
}

// Every binary operator. Crossed with all_create_type_funcs to enumerate
// candidate invalid expressions.
static constexpr ast::BinaryOp all_ops[] = {
    ast::BinaryOp::kAnd,
    ast::BinaryOp::kOr,
    ast::BinaryOp::kXor,
    ast::BinaryOp::kLogicalAnd,
    ast::BinaryOp::kLogicalOr,
    ast::BinaryOp::kEqual,
    ast::BinaryOp::kNotEqual,
    ast::BinaryOp::kLessThan,
    ast::BinaryOp::kGreaterThan,
    ast::BinaryOp::kLessThanEqual,
    ast::BinaryOp::kGreaterThanEqual,
    ast::BinaryOp::kShiftLeft,
    ast::BinaryOp::kShiftRight,
    ast::BinaryOp::kAdd,
    ast::BinaryOp::kSubtract,
    ast::BinaryOp::kMultiply,
    ast::BinaryOp::kDivide,
    ast::BinaryOp::kModulo,
};

// Operand types used when enumerating candidate invalid expressions.
// The empty trailing `//` comments keep clang-format from packing entries.
static constexpr builder::ast_type_func_ptr all_create_type_funcs[] = {
    DataType::AST,   //
    DataType::AST,   //
    DataType::AST,   //
    DataType::AST,   //
    DataType>::AST,  //
    DataType>::AST,  //
    DataType>::AST,  //
    DataType>::AST,  //
    DataType>::AST,  //
    DataType>::AST,  //
    DataType>::AST   //
};

// A list of all valid test cases for 'lhs op rhs'. For vecN and matNxN we mostly
// test N=3; the matrix cases also cover some 2x3 / 3x2 shapes.
static constexpr Params all_valid_cases[] = {
    // Logical expressions
    // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr

    // Binary logical expressions
    ParamsFor(Op::kLogicalAnd),
    ParamsFor(Op::kLogicalOr),

    ParamsFor(Op::kAnd),
    ParamsFor(Op::kOr),
    ParamsFor, vec3, vec3>(Op::kAnd),
    ParamsFor, vec3, vec3>(Op::kOr),

    // Arithmetic expressions
    // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr

    // Binary arithmetic expressions over scalars
    ParamsFor(Op::kAdd),
    ParamsFor(Op::kSubtract),
    ParamsFor(Op::kMultiply),
    ParamsFor(Op::kDivide),
    ParamsFor(Op::kModulo),

    ParamsFor(Op::kAdd),
    ParamsFor(Op::kSubtract),
    ParamsFor(Op::kMultiply),
    ParamsFor(Op::kDivide),
    ParamsFor(Op::kModulo),

    ParamsFor(Op::kAdd),
    ParamsFor(Op::kSubtract),
    ParamsFor(Op::kMultiply),
    ParamsFor(Op::kDivide),
    ParamsFor(Op::kModulo),

    // Binary arithmetic expressions over vectors
    ParamsFor, vec3, vec3>(Op::kAdd),
    ParamsFor, vec3, vec3>(Op::kSubtract),
    ParamsFor, vec3, vec3>(Op::kMultiply),
    ParamsFor, vec3, vec3>(Op::kDivide),
    ParamsFor, vec3, vec3>(Op::kModulo),

    ParamsFor, vec3, vec3>(Op::kAdd),
    ParamsFor, vec3, vec3>(Op::kSubtract),
    ParamsFor, vec3, vec3>(Op::kMultiply),
    ParamsFor, vec3, vec3>(Op::kDivide),
    ParamsFor, vec3, vec3>(Op::kModulo),

    ParamsFor, vec3, vec3>(Op::kAdd),
    ParamsFor, vec3, vec3>(Op::kSubtract),
    ParamsFor, vec3, vec3>(Op::kMultiply),
    ParamsFor, vec3, vec3>(Op::kDivide),
    ParamsFor, vec3, vec3>(Op::kModulo),

    // Binary arithmetic expressions with mixed scalar and vector operands
    ParamsFor, i32, vec3>(Op::kAdd),
    ParamsFor, i32, vec3>(Op::kSubtract),
    ParamsFor, i32, vec3>(Op::kMultiply),
    ParamsFor, i32, vec3>(Op::kDivide),
    ParamsFor, i32, vec3>(Op::kModulo),

    ParamsFor, vec3>(Op::kAdd),
    ParamsFor, vec3>(Op::kSubtract),
    ParamsFor, vec3>(Op::kMultiply),
    ParamsFor, vec3>(Op::kDivide),
    ParamsFor, vec3>(Op::kModulo),

    ParamsFor, u32, vec3>(Op::kAdd),
    ParamsFor, u32, vec3>(Op::kSubtract),
    ParamsFor, u32, vec3>(Op::kMultiply),
    ParamsFor, u32, vec3>(Op::kDivide),
    ParamsFor, u32, vec3>(Op::kModulo),

    ParamsFor, vec3>(Op::kAdd),
    ParamsFor, vec3>(Op::kSubtract),
    ParamsFor, vec3>(Op::kMultiply),
    ParamsFor, vec3>(Op::kDivide),
    ParamsFor, vec3>(Op::kModulo),

    ParamsFor, f32, vec3>(Op::kAdd),
    ParamsFor, f32, vec3>(Op::kSubtract),
    ParamsFor, f32, vec3>(Op::kMultiply),
    ParamsFor, f32, vec3>(Op::kDivide),
    ParamsFor, f32, vec3>(Op::kModulo),

    ParamsFor, vec3>(Op::kAdd),
    ParamsFor, vec3>(Op::kSubtract),
    ParamsFor, vec3>(Op::kMultiply),
    ParamsFor, vec3>(Op::kDivide),
    ParamsFor, vec3>(Op::kModulo),

    // Matrix arithmetic
    ParamsFor, f32, mat2x3>(Op::kMultiply),
    ParamsFor, f32, mat3x2>(Op::kMultiply),
    ParamsFor, f32, mat3x3>(Op::kMultiply),

    ParamsFor, mat2x3>(Op::kMultiply),
    ParamsFor, mat3x2>(Op::kMultiply),
    ParamsFor, mat3x3>(Op::kMultiply),

    ParamsFor, mat2x3, vec2>(Op::kMultiply),
    ParamsFor, mat3x2, vec3>(Op::kMultiply),
    ParamsFor, mat3x3, vec3>(Op::kMultiply),

    ParamsFor, vec3, vec2>(Op::kMultiply),
    ParamsFor, vec2, vec3>(Op::kMultiply),
    ParamsFor, vec3, vec3>(Op::kMultiply),

    ParamsFor, mat3x2, mat3x3>(Op::kMultiply),
    ParamsFor, mat2x3, mat2x2>(Op::kMultiply),
    ParamsFor, mat3x3, mat3x2>(Op::kMultiply),
    ParamsFor, mat3x3, mat3x3>(Op::kMultiply),
    ParamsFor, mat2x3, mat2x3>(Op::kMultiply),

    ParamsFor, mat2x3, mat2x3>(Op::kAdd),
    ParamsFor, mat3x2, mat3x2>(Op::kAdd),
    ParamsFor, mat3x3, mat3x3>(Op::kAdd),

    ParamsFor, mat2x3, mat2x3>(Op::kSubtract),
    ParamsFor, mat3x2, mat3x2>(Op::kSubtract),
    ParamsFor, mat3x3, mat3x3>(Op::kSubtract),

    // Comparison expressions
    // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr

    // Comparisons over scalars
    ParamsFor(Op::kEqual),
    ParamsFor(Op::kNotEqual),

    ParamsFor(Op::kEqual),
    ParamsFor(Op::kNotEqual),
    ParamsFor(Op::kLessThan),
    ParamsFor(Op::kLessThanEqual),
    ParamsFor(Op::kGreaterThan),
    ParamsFor(Op::kGreaterThanEqual),

    ParamsFor(Op::kEqual),
    ParamsFor(Op::kNotEqual),
    ParamsFor(Op::kLessThan),
    ParamsFor(Op::kLessThanEqual),
    ParamsFor(Op::kGreaterThan),
    ParamsFor(Op::kGreaterThanEqual),

    ParamsFor(Op::kEqual),
    ParamsFor(Op::kNotEqual),
    ParamsFor(Op::kLessThan),
    ParamsFor(Op::kLessThanEqual),
    ParamsFor(Op::kGreaterThan),
    ParamsFor(Op::kGreaterThanEqual),

    // Comparisons over vectors
    ParamsFor, vec3, vec3>(Op::kEqual),
    ParamsFor, vec3, vec3>(Op::kNotEqual),

    ParamsFor, vec3, vec3>(Op::kEqual),
    ParamsFor, vec3, vec3>(Op::kNotEqual),
    ParamsFor, vec3, vec3>(Op::kLessThan),
    ParamsFor, vec3, vec3>(Op::kLessThanEqual),
    ParamsFor, vec3, vec3>(Op::kGreaterThan),
    ParamsFor, vec3, vec3>(Op::kGreaterThanEqual),

    ParamsFor, vec3, vec3>(Op::kEqual),
    ParamsFor, vec3, vec3>(Op::kNotEqual),
    ParamsFor, vec3, vec3>(Op::kLessThan),
    ParamsFor, vec3, vec3>(Op::kLessThanEqual),
    ParamsFor, vec3, vec3>(Op::kGreaterThan),
    ParamsFor, vec3, vec3>(Op::kGreaterThanEqual),

    ParamsFor, vec3, vec3>(Op::kEqual),
    ParamsFor, vec3, vec3>(Op::kNotEqual),
    ParamsFor, vec3, vec3>(Op::kLessThan),
    ParamsFor, vec3, vec3>(Op::kLessThanEqual),
    ParamsFor, vec3, vec3>(Op::kGreaterThan),
    ParamsFor, vec3, vec3>(Op::kGreaterThanEqual),

    // Binary bitwise operations
    ParamsFor(Op::kOr),
    ParamsFor(Op::kAnd),
    ParamsFor(Op::kXor),

    ParamsFor(Op::kOr),
    ParamsFor(Op::kAnd),
    ParamsFor(Op::kXor),

    ParamsFor, vec3, vec3>(Op::kOr),
    ParamsFor, vec3, vec3>(Op::kAnd),
    ParamsFor, vec3, vec3>(Op::kXor),

    ParamsFor, vec3, vec3>(Op::kOr),
    ParamsFor, vec3, vec3>(Op::kAnd),
    ParamsFor, vec3, vec3>(Op::kXor),

    // Bit shift expressions
    ParamsFor(Op::kShiftLeft),
    ParamsFor, vec3, vec3>(Op::kShiftLeft),

    ParamsFor(Op::kShiftLeft),
    ParamsFor, vec3, vec3>(Op::kShiftLeft),

    ParamsFor(Op::kShiftRight),
    ParamsFor, vec3, vec3>(Op::kShiftRight),

    ParamsFor(Op::kShiftRight),
    ParamsFor, vec3, vec3>(Op::kShiftRight),
};

// Each valid case must resolve, and the expression's semantic type must be the
// very same object that create_result_type returns (pointer comparison).
using Expr_Binary_Test_Valid = ResolverTestWithParam;
TEST_P(Expr_Binary_Test_Valid, All) {
    auto& params = GetParam();

    auto* lhs_type = params.create_lhs_type(*this);
    auto* rhs_type = params.create_rhs_type(*this);
    auto* result_type = params.create_result_type(*this);

    std::stringstream ss;
    ss << FriendlyName(lhs_type) << " " << params.op << " " << FriendlyName(rhs_type);
    SCOPED_TRACE(ss.str());

    GlobalVar("lhs", lhs_type, ast::AddressSpace::kPrivate);
    GlobalVar("rhs", rhs_type, ast::AddressSpace::kPrivate);

    auto* expr = create(params.op, Expr("lhs"), Expr("rhs"));
    WrapInFunction(expr);

    ASSERT_TRUE(r()->Resolve()) << r()->error();
    ASSERT_NE(TypeOf(expr), nullptr);
    ASSERT_TRUE(TypeOf(expr) == result_type);
}
INSTANTIATE_TEST_SUITE_P(ResolverTest, Expr_Binary_Test_Valid, testing::ValuesIn(all_valid_cases));

// Which operand(s) of the binary expression get declared through an alias type.
enum class BinaryExprSide { Left, Right, Both };

// Re-runs every valid case with the left, right, or both operands declared via
// their alias type; the expression must still resolve.
using Expr_Binary_Test_WithAlias_Valid =
    ResolverTestWithParam>;
TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
    const Params& params = std::get<0>(GetParam());
    BinaryExprSide side = std::get<1>(GetParam());

    auto* create_lhs_type =
        (side == BinaryExprSide::Left || side == BinaryExprSide::Both)
            ? params.create_lhs_alias_type
            : params.create_lhs_type;
    auto* create_rhs_type =
        (side == BinaryExprSide::Right || side == BinaryExprSide::Both)
            ? params.create_rhs_alias_type
            : params.create_rhs_type;

    auto* lhs_type = create_lhs_type(*this);
    auto* rhs_type = create_rhs_type(*this);

    // NOTE(review): lhs_type/rhs_type are already the (possibly) aliased types,
    // so both halves of this trace print the same names. Building the unaliased
    // types only for the trace would create AST nodes the resolver never reaches
    // (see ASTNodeNotReached), which is presumably why it is left this way.
    std::stringstream ss;
    ss << FriendlyName(lhs_type) << " " << params.op << " " << FriendlyName(rhs_type);

    ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op << " "
       << FriendlyName(rhs_type);
    SCOPED_TRACE(ss.str());

    GlobalVar("lhs", lhs_type, ast::AddressSpace::kPrivate);
    GlobalVar("rhs", rhs_type, ast::AddressSpace::kPrivate);

    auto* expr = create(params.op, Expr("lhs"), Expr("rhs"));
    WrapInFunction(expr);

    ASSERT_TRUE(r()->Resolve()) << r()->error();
    ASSERT_NE(TypeOf(expr), nullptr);
    // TODO(amaiorano): Bring this back once we have a way to get the canonical
    // type
    // auto* result_type = params.create_result_type(*this);
    // ASSERT_TRUE(TypeOf(expr) == result_type);
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
                         Expr_Binary_Test_WithAlias_Valid,
                         testing::Combine(testing::ValuesIn(all_valid_cases),
                                          testing::Values(BinaryExprSide::Left,
                                                          BinaryExprSide::Right,
                                                          BinaryExprSide::Both)));

// This test works by taking the cartesian product of all possible
// (type * type * op), and processing only the triplets that are not found in
// the `all_valid_cases` table.
using Expr_Binary_Test_Invalid = ResolverTestWithParam<
    std::tuple>;
TEST_P(Expr_Binary_Test_Invalid, All) {
    const builder::ast_type_func_ptr& lhs_create_type_func = std::get<0>(GetParam());
    const builder::ast_type_func_ptr& rhs_create_type_func = std::get<1>(GetParam());
    const ast::BinaryOp op = std::get<2>(GetParam());

    // Skip if valid case: the early return makes this instance pass trivially.
    // Matching is by factory function pointer, not by resulting type.
    // TODO(amaiorano): replace linear lookup with O(1) if too slow
    for (auto& c : all_valid_cases) {
        if (c.create_lhs_type == lhs_create_type_func && c.create_rhs_type == rhs_create_type_func &&
            c.op == op) {
            return;
        }
    }

    auto* lhs_type = lhs_create_type_func(*this);
    auto* rhs_type = rhs_create_type_func(*this);

    std::stringstream ss;
    ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type);
    SCOPED_TRACE(ss.str());

    GlobalVar("lhs", lhs_type, ast::AddressSpace::kPrivate);
    GlobalVar("rhs", rhs_type, ast::AddressSpace::kPrivate);

    auto* expr = create(Source{{12, 34}}, op, Expr("lhs"), Expr("rhs"));
    WrapInFunction(expr);

    ASSERT_FALSE(r()->Resolve());
    EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching overload for operator "));
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
                         Expr_Binary_Test_Invalid,
                         testing::Combine(testing::ValuesIn(all_create_type_funcs),
                                          testing::ValuesIn(all_create_type_funcs),
                                          testing::ValuesIn(all_ops)));

// Despite the name, this covers both valid and invalid vector*matrix and
// matrix*vector products. Parameters: (vec_by_mat, vec_size, mat_rows, mat_cols).
//   vec * mat: valid iff vec_size == mat_rows; result width is mat_cols.
//   mat * vec: valid iff vec_size == mat_cols; result width is mat_rows.
using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
    ResolverTestWithParam>;
TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
    bool vec_by_mat = std::get<0>(GetParam());
    uint32_t vec_size = std::get<1>(GetParam());
    uint32_t mat_rows = std::get<2>(GetParam());
    uint32_t mat_cols = std::get<3>(GetParam());

    const ast::Type* lhs_type = nullptr;
    const ast::Type* rhs_type = nullptr;
    const sem::Type* result_type = nullptr;
    bool is_valid_expr;  // Assigned in both branches below.

    if (vec_by_mat) {
        lhs_type = ty.vec(vec_size);
        rhs_type = ty.mat(mat_cols, mat_rows);
        result_type = create(create(), mat_cols);
        is_valid_expr = vec_size == mat_rows;
    } else {
        lhs_type = ty.mat(mat_cols, mat_rows);
        rhs_type = ty.vec(vec_size);
        result_type = create(create(), mat_rows);
        is_valid_expr = vec_size == mat_cols;
    }

    GlobalVar("lhs", lhs_type, ast::AddressSpace::kPrivate);
    GlobalVar("rhs", rhs_type, ast::AddressSpace::kPrivate);

    auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
    WrapInFunction(expr);

    if (is_valid_expr) {
        ASSERT_TRUE(r()->Resolve()) << r()->error();
        ASSERT_TRUE(TypeOf(expr) == result_type);
    } else {
        ASSERT_FALSE(r()->Resolve());
        EXPECT_THAT(r()->error(), HasSubstr("no matching overload for operator *"));
    }
}
// Vector / matrix dimensions exercised by the multiply tests.
auto all_dimension_values = testing::Values(2u, 3u, 4u);
INSTANTIATE_TEST_SUITE_P(ResolverTest,
                         Expr_Binary_Test_Invalid_VectorMatrixMultiply,
                         testing::Combine(testing::Values(true, false),
                                          all_dimension_values,
                                          all_dimension_values,
                                          all_dimension_values));

// Matrix*matrix products for every combination of shapes. Parameters:
// (lhs_rows, lhs_cols, rhs_rows, rhs_cols). Valid iff lhs_cols == rhs_rows;
// the result has lhs_rows rows and rhs_cols columns.
using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
    ResolverTestWithParam>;
TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
    uint32_t lhs_mat_rows = std::get<0>(GetParam());
    uint32_t lhs_mat_cols = std::get<1>(GetParam());
    uint32_t rhs_mat_rows = std::get<2>(GetParam());
    uint32_t rhs_mat_cols = std::get<3>(GetParam());

    auto* lhs_type = ty.mat(lhs_mat_cols, lhs_mat_rows);

    auto* rhs_type = ty.mat(rhs_mat_cols, rhs_mat_rows);

    auto* f32 = create();
    auto* col = create(f32, lhs_mat_rows);
    auto* result_type = create(col, rhs_mat_cols);

    GlobalVar("lhs", lhs_type, ast::AddressSpace::kPrivate);
    GlobalVar("rhs", rhs_type, ast::AddressSpace::kPrivate);

    auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
    WrapInFunction(expr);

    bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
    if (is_valid_expr) {
        ASSERT_TRUE(r()->Resolve()) << r()->error();
        ASSERT_TRUE(TypeOf(expr) == result_type);
    } else {
        ASSERT_FALSE(r()->Resolve());
        EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching overload for operator * "));
    }
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
                         Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
                         testing::Combine(all_dimension_values,
                                          all_dimension_values,
                                          all_dimension_values,
                                          all_dimension_values));

}  // namespace ExprBinaryTest

// Applies each unary operator to a 4-wide vector global and checks that the
// result is a 4-wide vector.
using UnaryOpExpressionTest = ResolverTestWithParam;
TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {
    auto op = GetParam();

    // NOTE(review): the final `else` branches here and below are unreachable
    // for the three instantiated operators.
    if (op == ast::UnaryOp::kNot) {
        GlobalVar("ident", ty.vec4(), ast::AddressSpace::kPrivate);
    } else if (op == ast::UnaryOp::kNegation || op == ast::UnaryOp::kComplement) {
        GlobalVar("ident", ty.vec4(), ast::AddressSpace::kPrivate);
    } else {
        GlobalVar("ident", ty.vec4(), ast::AddressSpace::kPrivate);
    }
    auto* der = create(op, Expr("ident"));
    WrapInFunction(der);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    ASSERT_NE(TypeOf(der), nullptr);
    ASSERT_TRUE(TypeOf(der)->Is());
    if (op == ast::UnaryOp::kNot) {
        EXPECT_TRUE(TypeOf(der)->As()->type()->Is());
    } else if (op == ast::UnaryOp::kNegation || op == ast::UnaryOp::kComplement) {
        EXPECT_TRUE(TypeOf(der)->As()->type()->Is());
    } else {
        EXPECT_TRUE(TypeOf(der)->As()->type()->Is());
    }
    EXPECT_EQ(TypeOf(der)->As()->Width(), 4u);
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
                         UnaryOpExpressionTest,
                         testing::Values(ast::UnaryOp::kComplement,
                                         ast::UnaryOp::kNegation,
                                         ast::UnaryOp::kNot));

// A function-scope `var` declared without an address space gets kFunction.
TEST_F(ResolverTest, AddressSpace_SetsIfMissing) {
    auto* var = Var("var", ty.i32());

    auto* stmt = Decl(var);
    Func("func", utils::Empty, ty.void_(), utils::Vector{stmt});

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    EXPECT_EQ(Sem().Get(var)->AddressSpace(), ast::AddressSpace::kFunction);
}

// A sampler global declared without an address space gets kHandle.
TEST_F(ResolverTest, AddressSpace_SetForSampler) {
    auto* t = ty.sampler(ast::SamplerKind::kSampler);
    auto* var = GlobalVar("var", t, Binding(0_a), Group(0_a));

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    EXPECT_EQ(Sem().Get(var)->AddressSpace(), ast::AddressSpace::kHandle);
}

// A texture global declared without an address space gets kHandle.
TEST_F(ResolverTest, AddressSpace_SetForTexture) {
    auto* t = ty.sampled_texture(ast::TextureDimension::k1d, ty.f32());
    auto* var = GlobalVar("var", t, Binding(0_a), Group(0_a));

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    EXPECT_EQ(Sem().Get(var)->AddressSpace(), ast::AddressSpace::kHandle);
}

// A `let` declaration (despite the test name's "Const") gets no address space.
TEST_F(ResolverTest, AddressSpace_DoesNotSetOnConst) {
    auto* var = Let("var", ty.i32(), Construct(ty.i32()));
    auto* stmt = Decl(var);
    Func("func", utils::Empty, ty.void_(), utils::Vector{stmt});

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    EXPECT_EQ(Sem().Get(var)->AddressSpace(), ast::AddressSpace::kNone);
}

TEST_F(ResolverTest, Access_SetForStorageBuffer) {
    // struct S { x : i32 };
    // var<storage> g : S;
    //
    // No access mode is given; the resolved access is expected to be kRead.
    auto* s = Structure("S", utils::Vector{Member(Source{{12, 34}}, "x", ty.i32())});
    auto* var = GlobalVar(Source{{56, 78}}, "g", ty.Of(s), ast::AddressSpace::kStorage,
                          Binding(0_a), Group(0_a));

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    EXPECT_EQ(Sem().Get(var)->Access(), ast::Access::kRead);
}

TEST_F(ResolverTest, BindingPoint_SetForResources) {
    // @group(1) @binding(2) var s1 : sampler;
    // @group(3) @binding(4) var s2 : sampler;
    auto* s1 = GlobalVar(Sym(), ty.sampler(ast::SamplerKind::kSampler), Group(1_a), Binding(2_a));
    auto* s2 = GlobalVar(Sym(), ty.sampler(ast::SamplerKind::kSampler), Group(3_a), Binding(4_a));

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    // BindingPoint is {group, binding}.
    EXPECT_EQ(Sem().Get(s1)->BindingPoint(), (sem::BindingPoint{1u, 2u}));
    EXPECT_EQ(Sem().Get(s2)->BindingPoint(), (sem::BindingPoint{3u, 4u}));
}

TEST_F(ResolverTest, Function_EntryPoints_StageAttribute) {
    // fn b() {}
    // fn c() { b(); }
    // fn a() { c(); }
    // fn ep_1() { a(); b(); }
    // fn ep_2() { c();}
    //
    // Expected ancestor entry points of each function:
    // c -> {ep_1, ep_2}
    // a -> {ep_1}
    // b -> {ep_1, ep_2}
    // ep_1 -> {}
    // ep_2 -> {}
    //
    // In the built program each non-entry-point helper returns 0.0 and each
    // call's result is stored into one of the private globals below.
    GlobalVar("first", ty.f32(), ast::AddressSpace::kPrivate);
    GlobalVar("second", ty.f32(), ast::AddressSpace::kPrivate);
    GlobalVar("call_a", ty.f32(), ast::AddressSpace::kPrivate);
    GlobalVar("call_b", ty.f32(), ast::AddressSpace::kPrivate);
    GlobalVar("call_c", ty.f32(), ast::AddressSpace::kPrivate);

    auto* func_b = Func("b", utils::Empty, ty.f32(),
                        utils::Vector{
                            Return(0_f),
                        });
    auto* func_c = Func("c", utils::Empty, ty.f32(),
                        utils::Vector{
                            Assign("second", Call("b")),
                            Return(0_f),
                        });

    auto* func_a = Func("a", utils::Empty, ty.f32(),
                        utils::Vector{
                            Assign("first", Call("c")),
                            Return(0_f),
                        });

    auto* ep_1 = Func("ep_1", utils::Empty, ty.void_(),
                      utils::Vector{
                          Assign("call_a", Call("a")),
                          Assign("call_b", Call("b")),
                      },
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize(1_i),
                      });

    auto* ep_2 = Func("ep_2", utils::Empty, ty.void_(),
                      utils::Vector{
                          Assign("call_c", Call("c")),
                      },
                      utils::Vector{
                          Stage(ast::PipelineStage::kCompute),
                          WorkgroupSize(1_i),
                      });

    ASSERT_TRUE(r()->Resolve()) << r()->error();

    auto* func_b_sem = Sem().Get(func_b);
    auto* func_a_sem = Sem().Get(func_a);
    auto* func_c_sem = Sem().Get(func_c);
    auto* ep_1_sem = Sem().Get(ep_1);
    auto* ep_2_sem = Sem().Get(ep_2);
    ASSERT_NE(func_b_sem, nullptr);
    ASSERT_NE(func_a_sem, nullptr);
    ASSERT_NE(func_c_sem, nullptr);
    ASSERT_NE(ep_1_sem, nullptr);
    ASSERT_NE(ep_2_sem, nullptr);

    EXPECT_EQ(func_b_sem->Parameters().Length(), 0u);
    EXPECT_EQ(func_a_sem->Parameters().Length(), 0u);
    EXPECT_EQ(func_c_sem->Parameters().Length(), 0u);

    const auto& b_eps = func_b_sem->AncestorEntryPoints();
    ASSERT_EQ(2u, b_eps.size());
    EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]->Declaration()->symbol);
    EXPECT_EQ(Symbols().Register("ep_2"), b_eps[1]->Declaration()->symbol);

    const auto& a_eps = func_a_sem->AncestorEntryPoints();
    ASSERT_EQ(1u, a_eps.size());
    EXPECT_EQ(Symbols().Register("ep_1"), a_eps[0]->Declaration()->symbol);

    const auto& c_eps = func_c_sem->AncestorEntryPoints();
    ASSERT_EQ(2u, c_eps.size());
    EXPECT_EQ(Symbols().Register("ep_1"), c_eps[0]->Declaration()->symbol);
    EXPECT_EQ(Symbols().Register("ep_2"), c_eps[1]->Declaration()->symbol);

    EXPECT_TRUE(ep_1_sem->AncestorEntryPoints().empty());
    EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());
}

// Check for linear-time traversal of functions reachable from entry points.
// See: crbug.com/tint/245 TEST_F(ResolverTest, Function_EntryPoints_LinearTime) { // fn lNa() { } // fn lNb() { } // ... // fn l2a() { l3a(); l3b(); } // fn l2b() { l3a(); l3b(); } // fn l1a() { l2a(); l2b(); } // fn l1b() { l2a(); l2b(); } // fn main() { l1a(); l1b(); } static constexpr int levels = 64; auto fn_a = [](int level) { return "l" + std::to_string(level + 1) + "a"; }; auto fn_b = [](int level) { return "l" + std::to_string(level + 1) + "b"; }; Func(fn_a(levels), utils::Empty, ty.void_(), utils::Empty); Func(fn_b(levels), utils::Empty, ty.void_(), utils::Empty); for (int i = levels - 1; i >= 0; i--) { Func(fn_a(i), utils::Empty, ty.void_(), utils::Vector{ CallStmt(Call(fn_a(i + 1))), CallStmt(Call(fn_b(i + 1))), }, utils::Empty); Func(fn_b(i), utils::Empty, ty.void_(), utils::Vector{ CallStmt(Call(fn_a(i + 1))), CallStmt(Call(fn_b(i + 1))), }, utils::Empty); } Func("main", utils::Empty, ty.void_(), utils::Vector{ CallStmt(Call(fn_a(0))), CallStmt(Call(fn_b(0))), }, utils::Vector{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_i)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } // Test for crbug.com/tint/728 TEST_F(ResolverTest, ASTNodesAreReached) { Structure("A", utils::Vector{Member("x", ty.array(4))}); Structure("B", utils::Vector{Member("x", ty.array(4))}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } TEST_F(ResolverTest, ASTNodeNotReached) { EXPECT_FATAL_FAILURE( { ProgramBuilder b; b.Expr("expr"); Resolver(&b).Resolve(); }, "internal compiler error: AST node 'tint::ast::IdentifierExpression' was not reached by " "the resolver"); } TEST_F(ResolverTest, ASTNodeReachedTwice) { EXPECT_FATAL_FAILURE( { ProgramBuilder b; auto* expr = b.Expr(1_i); b.GlobalVar("a", b.ty.i32(), ast::AddressSpace::kPrivate, expr); b.GlobalVar("b", b.ty.i32(), ast::AddressSpace::kPrivate, expr); Resolver(&b).Resolve(); }, "internal compiler error: AST node 'tint::ast::IntLiteralExpression' was encountered twice " "in the same AST of a Program"); } 
// `!` applied to a vector type with no matching overload must be rejected.
TEST_F(ResolverTest, UnaryOp_Not) {
    GlobalVar("ident", ty.vec4(), ast::AddressSpace::kPrivate);
    auto* der = create(ast::UnaryOp::kNot, Expr(Source{{12, 34}}, "ident"));
    WrapInFunction(der);

    EXPECT_FALSE(r()->Resolve());
    EXPECT_THAT(r()->error(), HasSubstr("error: no matching overload for operator ! (vec4)"));
}

// `~` applied to a vector type with no matching overload must be rejected.
TEST_F(ResolverTest, UnaryOp_Complement) {
    GlobalVar("ident", ty.vec4(), ast::AddressSpace::kPrivate);
    auto* der =
        create(ast::UnaryOp::kComplement, Expr(Source{{12, 34}}, "ident"));
    WrapInFunction(der);

    EXPECT_FALSE(r()->Resolve());
    EXPECT_THAT(r()->error(), HasSubstr("error: no matching overload for operator ~ (vec4)"));
}

// Unary `-` on a u32 must be rejected (no matching overload).
TEST_F(ResolverTest, UnaryOp_Negation) {
    GlobalVar("ident", ty.u32(), ast::AddressSpace::kPrivate);
    auto* der =
        create(ast::UnaryOp::kNegation, Expr(Source{{12, 34}}, "ident"));
    WrapInFunction(der);

    EXPECT_FALSE(r()->Resolve());
    EXPECT_THAT(r()->error(), HasSubstr("error: no matching overload for operator - (u32)"));
}

// A direct textureSample(t, s, ...) call records exactly one (texture, sampler)
// pair on the calling function, with both halves set.
TEST_F(ResolverTest, TextureSampler_TextureSample) {
    GlobalVar("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(1_a),
              Binding(1_a));
    GlobalVar("s", ty.sampler(ast::SamplerKind::kSampler), Group(1_a), Binding(2_a));

    auto* call = CallStmt(Call("textureSample", "t", "s", vec2(1_f, 2_f)));
    const ast::Function* f = Func("test_function", utils::Empty, ty.void_(), utils::Vector{call},
                                  utils::Vector{Stage(ast::PipelineStage::kFragment)});

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    const sem::Function* sf = Sem().Get(f);
    auto pairs = sf->TextureSamplerPairs();
    ASSERT_EQ(pairs.Length(), 1u);
    EXPECT_TRUE(pairs[0].first != nullptr);
    EXPECT_TRUE(pairs[0].second != nullptr);
}

// A pair recorded in a helper function is also reported on the entry point
// that calls the helper.
TEST_F(ResolverTest, TextureSampler_TextureSampleInFunction) {
    GlobalVar("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(1_a),
              Binding(1_a));
    GlobalVar("s", ty.sampler(ast::SamplerKind::kSampler), Group(1_a), Binding(2_a));

    auto* inner_call = CallStmt(Call("textureSample", "t", "s", vec2(1_f, 2_f)));
    const ast::Function* inner_func =
        Func("inner_func", utils::Empty, ty.void_(), utils::Vector{inner_call});
    auto* outer_call = CallStmt(Call("inner_func"));
    const ast::Function* outer_func =
        Func("outer_func", utils::Empty, ty.void_(), utils::Vector{outer_call},
             utils::Vector{Stage(ast::PipelineStage::kFragment)});

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto inner_pairs = Sem().Get(inner_func)->TextureSamplerPairs();
    ASSERT_EQ(inner_pairs.Length(), 1u);
    EXPECT_TRUE(inner_pairs[0].first != nullptr);
    EXPECT_TRUE(inner_pairs[0].second != nullptr);

    auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
    ASSERT_EQ(outer_pairs.Length(), 1u);
    EXPECT_TRUE(outer_pairs[0].first != nullptr);
    EXPECT_TRUE(outer_pairs[0].second != nullptr);
}

// Two helpers sample the same (t, s) pair; the entry point calling both must
// report that pair only once.
TEST_F(ResolverTest, TextureSampler_TextureSampleFunctionDiamondSameVariables) {
    GlobalVar("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(1_a),
              Binding(1_a));
    GlobalVar("s", ty.sampler(ast::SamplerKind::kSampler), Group(1_a), Binding(2_a));

    auto* inner_call_1 = CallStmt(Call("textureSample", "t", "s", vec2(1_f, 2_f)));
    const ast::Function* inner_func_1 =
        Func("inner_func_1", utils::Empty, ty.void_(), utils::Vector{inner_call_1});
    auto* inner_call_2 = CallStmt(Call("textureSample", "t", "s", vec2(3_f, 4_f)));
    const ast::Function* inner_func_2 =
        Func("inner_func_2", utils::Empty, ty.void_(), utils::Vector{inner_call_2});
    auto* outer_call_1 = CallStmt(Call("inner_func_1"));
    auto* outer_call_2 = CallStmt(Call("inner_func_2"));
    const ast::Function* outer_func =
        Func("outer_func", utils::Empty, ty.void_(), utils::Vector{outer_call_1, outer_call_2},
             utils::Vector{Stage(ast::PipelineStage::kFragment)});

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto inner_pairs_1 = Sem().Get(inner_func_1)->TextureSamplerPairs();
    ASSERT_EQ(inner_pairs_1.Length(), 1u);
    EXPECT_TRUE(inner_pairs_1[0].first != nullptr);
    EXPECT_TRUE(inner_pairs_1[0].second != nullptr);

    auto inner_pairs_2 = Sem().Get(inner_func_2)->TextureSamplerPairs();
    // Guard the list that is indexed next: checking inner_pairs_1 here would let
    // an empty inner_pairs_2 be read out of bounds instead of failing cleanly.
    ASSERT_EQ(inner_pairs_2.Length(), 1u);
    EXPECT_TRUE(inner_pairs_2[0].first != nullptr);
    EXPECT_TRUE(inner_pairs_2[0].second != nullptr);

    auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
    ASSERT_EQ(outer_pairs.Length(), 1u);
    EXPECT_TRUE(outer_pairs[0].first != nullptr);
    EXPECT_TRUE(outer_pairs[0].second != nullptr);
}

// Two helpers sample different textures with the same sampler; the entry point
// must report both pairs, in the order the helpers are called.
TEST_F(ResolverTest, TextureSampler_TextureSampleFunctionDiamondDifferentVariables) {
    GlobalVar("t1", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(1_a),
              Binding(1_a));
    GlobalVar("t2", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(1_a),
              Binding(2_a));
    GlobalVar("s", ty.sampler(ast::SamplerKind::kSampler), Group(1_a), Binding(3_a));

    auto* inner_call_1 = CallStmt(Call("textureSample", "t1", "s", vec2(1_f, 2_f)));
    const ast::Function* inner_func_1 =
        Func("inner_func_1", utils::Empty, ty.void_(), utils::Vector{inner_call_1});
    auto* inner_call_2 = CallStmt(Call("textureSample", "t2", "s", vec2(3_f, 4_f)));
    const ast::Function* inner_func_2 =
        Func("inner_func_2", utils::Empty, ty.void_(), utils::Vector{inner_call_2});
    auto* outer_call_1 = CallStmt(Call("inner_func_1"));
    auto* outer_call_2 = CallStmt(Call("inner_func_2"));
    const ast::Function* outer_func =
        Func("outer_func", utils::Empty, ty.void_(), utils::Vector{outer_call_1, outer_call_2},
             utils::Vector{Stage(ast::PipelineStage::kFragment)});

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    auto inner_pairs_1 = Sem().Get(inner_func_1)->TextureSamplerPairs();
    ASSERT_EQ(inner_pairs_1.Length(), 1u);
    EXPECT_TRUE(inner_pairs_1[0].first != nullptr);
    EXPECT_TRUE(inner_pairs_1[0].second != nullptr);

    auto inner_pairs_2 = Sem().Get(inner_func_2)->TextureSamplerPairs();
    ASSERT_EQ(inner_pairs_2.Length(), 1u);
    EXPECT_TRUE(inner_pairs_2[0].first != nullptr);
    EXPECT_TRUE(inner_pairs_2[0].second != nullptr);

    auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
    ASSERT_EQ(outer_pairs.Length(), 2u);
    EXPECT_TRUE(outer_pairs[0].first == inner_pairs_1[0].first);
    EXPECT_TRUE(outer_pairs[0].second == inner_pairs_1[0].second);
    EXPECT_TRUE(outer_pairs[1].first == inner_pairs_2[0].first);
    EXPECT_TRUE(outer_pairs[1].second == inner_pairs_2[0].second);
}

// textureDimensions uses a texture but no sampler, so the recorded pair has a
// null sampler half.
TEST_F(ResolverTest, TextureSampler_TextureDimensions) {
    GlobalVar("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(1_a),
              Binding(2_a));

    auto* call = Call("textureDimensions", "t");
    const ast::Function* f = WrapInFunction(call);

    EXPECT_TRUE(r()->Resolve()) << r()->error();

    const sem::Function* sf = Sem().Get(f);
    auto pairs = sf->TextureSamplerPairs();
    ASSERT_EQ(pairs.Length(), 1u);
    EXPECT_TRUE(pairs[0].first != nullptr);
    EXPECT_TRUE(pairs[0].second == nullptr);
}

TEST_F(ResolverTest, TextureSampler_Bug1715) {
    // crbug.com/tint/1715
    // @binding(0) @group(0) var s: sampler;
    // @binding(1) @group(0) var t: texture_2d;
    // @binding(2) @group(0) var<uniform> c: vec2;
    //
    // @fragment
    // fn main() -> @location(0) vec4 {
    //     return helper(&s, &t);
    // }
    //
    // fn helper(sl: ptr<function, sampler>, tl: ptr<function, texture_2d>) -> vec4 {
    //     return textureSampleLevel(*tl, *sl, c, 0.0);
    // }
    //
    // Taking the address of the handle-space globals `s` and `t` must be
    // rejected with the exact error checked at the end.
    GlobalVar("s", ty.sampler(ast::SamplerKind::kSampler), Group(0_a), Binding(0_a));
    GlobalVar("t", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Group(0_a),
              Binding(1_a));
    GlobalVar("c", ty.vec2(), ast::AddressSpace::kUniform, Group(0_a), Binding(2_a));

    Func("main", utils::Empty, ty.vec4(),
         utils::Vector{
             Return(Call("helper", AddressOf("s"), AddressOf("t"))),
         },
         utils::Vector{
             Stage(ast::PipelineStage::kFragment),
         },
         utils::Vector{
             Location(0_u),
         });

    Func("helper",
         utils::Vector{
             Param("sl", ty.pointer(ty.sampler(ast::SamplerKind::kSampler),
                                    ast::AddressSpace::kFunction)),
             Param("tl", ty.pointer(ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
                                    ast::AddressSpace::kFunction)),
         },
         ty.vec4(),
         utils::Vector{
             Return(Call("textureSampleLevel", Deref("tl"), Deref("sl"), "c", 0.0_a)),
         });

    ASSERT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), "error: cannot take the address of expression in handle address space");
}
TEST_F(ResolverTest, ModuleDependencyOrderedDeclarations) { auto* f0 = Func("f0", utils::Empty, ty.void_(), utils::Empty); auto* v0 = GlobalVar("v0", ty.i32(), ast::AddressSpace::kPrivate); auto* a0 = Alias("a0", ty.i32()); auto* s0 = Structure("s0", utils::Vector{Member("m", ty.i32())}); auto* f1 = Func("f1", utils::Empty, ty.void_(), utils::Empty); auto* v1 = GlobalVar("v1", ty.i32(), ast::AddressSpace::kPrivate); auto* a1 = Alias("a1", ty.i32()); auto* s1 = Structure("s1", utils::Vector{Member("m", ty.i32())}); auto* f2 = Func("f2", utils::Empty, ty.void_(), utils::Empty); auto* v2 = GlobalVar("v2", ty.i32(), ast::AddressSpace::kPrivate); auto* a2 = Alias("a2", ty.i32()); auto* s2 = Structure("s2", utils::Vector{Member("m", ty.i32())}); EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(Sem().Module(), nullptr); EXPECT_THAT(Sem().Module()->DependencyOrderedDeclarations(), ElementsAre(f0, v0, a0, s0, f1, v1, a1, s1, f2, v2, a2, s2)); } constexpr size_t kMaxExpressionDepth = 512U; TEST_F(ResolverTest, MaxExpressionDepth_Pass) { auto* b = Var("b", ty.i32()); const ast::Expression* chain = nullptr; for (size_t i = 0; i < kMaxExpressionDepth; ++i) { chain = Add(chain ? chain : Expr("b"), Expr("b")); } auto* a = Let("a", chain); WrapInFunction(b, a); EXPECT_TRUE(r()->Resolve()) << r()->error(); } TEST_F(ResolverTest, MaxExpressionDepth_Fail) { auto* b = Var("b", ty.i32()); const ast::Expression* chain = nullptr; for (size_t i = 0; i < kMaxExpressionDepth + 1; ++i) { chain = Add(chain ? 
chain : Expr("b"), Expr("b")); } auto* a = Let("a", chain); WrapInFunction(b, a); EXPECT_FALSE(r()->Resolve()); EXPECT_THAT(r()->error(), HasSubstr("error: reached max expression depth of " + std::to_string(kMaxExpressionDepth))); } TEST_F(ResolverTest, Literal_F16WithoutExtension) { // fn test() {_ = 1.23h;} WrapInFunction(Ignore(Expr(f16(1.23f)))); EXPECT_FALSE(r()->Resolve()); EXPECT_THAT(r()->error(), HasSubstr("error: f16 literal used without 'f16' extension enabled")); } TEST_F(ResolverTest, Literal_F16WithExtension) { // enable f16; // fn test() {_ = 1.23h;} Enable(ast::Extension::kF16); WrapInFunction(Ignore(Expr(f16(1.23f)))); EXPECT_TRUE(r()->Resolve()); } // Windows debug builds have significantly smaller stack than other builds, and these tests will stack // overflow. #if !defined(NDEBUG) TEST_F(ResolverTest, ScopeDepth_NestedBlocks) { const ast::Statement* stmt = Return(); for (size_t i = 0; i < 150; i++) { stmt = Block(Source{{i, 1}}, stmt); } WrapInFunction(stmt); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "23:1 error: statement nesting depth / chaining length exceeds limit of 127"); } TEST_F(ResolverTest, ScopeDepth_NestedIf) { const ast::Statement* stmt = Return(); for (size_t i = 0; i < 150; i++) { stmt = If(Source{{i, 1}}, false, Block(Source{{i, 2}}, stmt)); } WrapInFunction(stmt); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "86:1 error: statement nesting depth / chaining length exceeds limit of 127"); } TEST_F(ResolverTest, ScopeDepth_IfElseChain) { const ast::Statement* stmt = nullptr; for (size_t i = 0; i < 150; i++) { stmt = If(Source{{i, 1}}, false, Block(Source{{i, 2}}), Else(stmt)); } WrapInFunction(stmt); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "24:2 error: statement nesting depth / chaining length exceeds limit of 127"); } #endif // !defined(NDEBUG) } // namespace } // namespace tint::resolver