diff --git a/BUILD.gn b/BUILD.gn index 079bbd17f9..23d89b1c07 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -371,10 +371,10 @@ source_set("libtint_core_src") { "src/inspector/scalar.h", "src/namer.cc", "src/namer.h", - "src/program_builder.cc", - "src/program_builder.h", "src/program.cc", "src/program.h", + "src/program_builder.cc", + "src/program_builder.h", "src/reader/reader.cc", "src/reader/reader.h", "src/scope_stack.h", @@ -404,10 +404,6 @@ source_set("libtint_core_src") { "src/transform/transform.h", "src/transform/vertex_pulling.cc", "src/transform/vertex_pulling.h", - "src/type_determiner.cc", - "src/type_determiner.h", - "src/validator/validator.cc", - "src/validator/validator.h", "src/type/access_control_type.cc", "src/type/access_control_type.h", "src/type/alias_type.cc", @@ -448,6 +444,10 @@ source_set("libtint_core_src") { "src/type/vector_type.h", "src/type/void_type.cc", "src/type/void_type.h", + "src/type_determiner.cc", + "src/type_determiner.h", + "src/validator/validator.cc", + "src/validator/validator.h", "src/validator/validator_impl.cc", "src/validator/validator_impl.h", "src/validator/validator_test_helper.cc", @@ -824,8 +824,8 @@ source_set("tint_unittests_core_src") { "src/diagnostic/printer_test.cc", "src/inspector/inspector_test.cc", "src/namer_test.cc", - "src/program_test.cc", "src/program_builder_test.cc", + "src/program_test.cc", "src/scope_stack_test.cc", "src/symbol_table_test.cc", "src/symbol_test.cc", @@ -835,7 +835,6 @@ source_set("tint_unittests_core_src") { "src/transform/first_index_offset_test.cc", "src/transform/test_helper.h", "src/transform/vertex_pulling_test.cc", - "src/type_determiner_test.cc", "src/type/access_control_type_test.cc", "src/type/alias_type_test.cc", "src/type/array_type_test.cc", @@ -854,6 +853,8 @@ source_set("tint_unittests_core_src") { "src/type/type_manager_test.cc", "src/type/u32_type_test.cc", "src/type/vector_type_test.cc", + "src/type_determiner_test.cc", + "src/validator/validator_builtins_test.cc", "src/validator/validator_control_block_test.cc", "src/validator/validator_function_test.cc", "src/validator/validator_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 465e16fe4f..2ab4a49695 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -478,6 +478,7 @@ if(${TINT_BUILD_TESTS}) type/type_manager_test.cc type/u32_type_test.cc type/vector_type_test.cc + validator/validator_builtins_test.cc validator/validator_control_block_test.cc validator/validator_function_test.cc validator/validator_test.cc diff --git a/src/type/type.cc b/src/type/type.cc index 623d07616c..3dfdeb03a5 100644 --- a/src/type/type.cc +++ b/src/type/type.cc @@ -75,47 +75,47 @@ uint64_t Type::BaseAlignment(MemoryLayout) const { return 0; } -bool Type::is_scalar() { +bool Type::is_scalar() const { return is_float_scalar() || is_integer_scalar() || Is(); } -bool Type::is_float_scalar() { +bool Type::is_float_scalar() const { return Is(); } -bool Type::is_float_matrix() { +bool Type::is_float_matrix() const { return Is() && As()->type()->is_float_scalar(); } -bool Type::is_float_vector() { +bool Type::is_float_vector() const { return Is() && As()->type()->is_float_scalar(); } -bool Type::is_float_scalar_or_vector() { +bool Type::is_float_scalar_or_vector() const { return is_float_scalar() || is_float_vector(); } -bool Type::is_integer_scalar() { +bool Type::is_integer_scalar() const { return Is() || Is(); } -bool Type::is_unsigned_integer_vector() { +bool Type::is_unsigned_integer_vector() const { return Is() && As()->type()->Is(); } -bool Type::is_signed_integer_vector() { +bool Type::is_signed_integer_vector() const { return Is() && As()->type()->Is(); } -bool Type::is_unsigned_scalar_or_vector() { +bool Type::is_unsigned_scalar_or_vector() const { return Is() || (Is() && As()->type()->Is()); } -bool Type::is_signed_scalar_or_vector() { +bool Type::is_signed_scalar_or_vector() const { return Is() || (Is() && As()->type()->Is()); } -bool Type::is_integer_scalar_or_vector() { +bool Type::is_integer_scalar_or_vector() const { return is_unsigned_scalar_or_vector() || is_signed_scalar_or_vector(); } diff --git a/src/type/type.h b/src/type/type.h index ae8d245c0b..126486555a 100644 --- a/src/type/type.h +++ b/src/type/type.h @@ -74,27 +74,27 @@ class Type : public Castable { Type* UnwrapAll(); /// @returns true if this type is a scalar - bool is_scalar(); + bool is_scalar() const; /// @returns true if this type is a float scalar - bool is_float_scalar(); + bool is_float_scalar() const; /// @returns true if this type is a float matrix - bool is_float_matrix(); + bool is_float_matrix() const; /// @returns true if this type is a float vector - bool is_float_vector(); + bool is_float_vector() const; /// @returns true if this type is a float scalar or vector - bool is_float_scalar_or_vector(); - /// @returns ture if this type is an integer scalar - bool is_integer_scalar(); + bool is_float_scalar_or_vector() const; + /// @returns true if this type is an integer scalar + bool is_integer_scalar() const; /// @returns true if this type is a signed integer vector - bool is_signed_integer_vector(); + bool is_signed_integer_vector() const; /// @returns true if this type is an unsigned vector - bool is_unsigned_integer_vector(); + bool is_unsigned_integer_vector() const; /// @returns true if this type is an unsigned scalar or vector - bool is_unsigned_scalar_or_vector(); + bool is_unsigned_scalar_or_vector() const; /// @returns true if this type is a signed scalar or vector - bool is_signed_scalar_or_vector(); + bool is_signed_scalar_or_vector() const; /// @returns true if this type is an integer scalar or vector - bool is_integer_scalar_or_vector(); + bool is_integer_scalar_or_vector() const; protected: Type(); diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 4256651c5f..32df8700b9 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -448,70 +448,65 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) { namespace { enum class IntrinsicDataType { - kFloatOrIntScalarOrVector, - kFloatScalarOrVector, - kIntScalarOrVector, - kFloatVector, - kMatrix, + kDependent, + kSignedInteger, + kUnsignedInteger, + kFloat, }; + struct IntrinsicData { ast::Intrinsic intrinsic; - uint8_t param_count; - IntrinsicDataType data_type; - uint8_t vector_size; + IntrinsicDataType result_type; + uint8_t result_vector_width; + uint8_t param_for_result_type; }; // Note, this isn't all the intrinsics. Some are handled specially before // we get to the generic code. See the DetermineIntrinsic code below. constexpr const IntrinsicData kIntrinsicData[] = { - {ast::Intrinsic::kAbs, 1, IntrinsicDataType::kFloatOrIntScalarOrVector, 0}, - {ast::Intrinsic::kAcos, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kAsin, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kAtan, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kAtan2, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kCeil, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kClamp, 3, IntrinsicDataType::kFloatOrIntScalarOrVector, - 0}, - {ast::Intrinsic::kCos, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kCosh, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kCountOneBits, 1, IntrinsicDataType::kIntScalarOrVector, - 0}, - {ast::Intrinsic::kCross, 2, IntrinsicDataType::kFloatVector, 3}, - {ast::Intrinsic::kDeterminant, 1, IntrinsicDataType::kMatrix, 0}, - {ast::Intrinsic::kDistance, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kExp, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kExp2, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kFaceForward, 3, IntrinsicDataType::kFloatScalarOrVector, - 0}, - {ast::Intrinsic::kFloor, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kFma, 3, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kFract, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kFrexp, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kInverseSqrt, 1, IntrinsicDataType::kFloatScalarOrVector, - 0}, - {ast::Intrinsic::kLdexp, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kLength, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kLog, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kLog2, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kMax, 2, IntrinsicDataType::kFloatOrIntScalarOrVector, 0}, - {ast::Intrinsic::kMin, 2, IntrinsicDataType::kFloatOrIntScalarOrVector, 0}, - {ast::Intrinsic::kMix, 3, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kModf, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kNormalize, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kPow, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kReflect, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kReverseBits, 1, IntrinsicDataType::kIntScalarOrVector, 0}, - {ast::Intrinsic::kRound, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kSign, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kSin, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kSinh, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kSmoothStep, 3, IntrinsicDataType::kFloatScalarOrVector, - 0}, - {ast::Intrinsic::kSqrt, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kStep, 2, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kTan, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kTanh, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, - {ast::Intrinsic::kTrunc, 1, IntrinsicDataType::kFloatScalarOrVector, 0}, + {ast::Intrinsic::kAbs, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kAcos, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kAsin, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kAtan, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kAtan2, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kCeil, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kClamp, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kCos, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kCosh, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kCountOneBits, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kCross, IntrinsicDataType::kFloat, 3, 0}, + {ast::Intrinsic::kDeterminant, IntrinsicDataType::kFloat, 1, 0}, + {ast::Intrinsic::kDistance, IntrinsicDataType::kFloat, 1, 0}, + {ast::Intrinsic::kExp, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kExp2, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kFaceForward, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kFloor, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kFma, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kFract, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kFrexp, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kInverseSqrt, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kLdexp, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kLength, IntrinsicDataType::kFloat, 1, 0}, + {ast::Intrinsic::kLog, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kLog2, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kMax, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kMin, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kMix, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kModf, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kNormalize, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kPow, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kReflect, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kReverseBits, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kRound, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kSign, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kSin, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kSinh, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kSmoothStep, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kSqrt, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kStep, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kTan, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kTanh, IntrinsicDataType::kDependent, 0, 0}, + {ast::Intrinsic::kTrunc, IntrinsicDataType::kDependent, 0, 0}, }; constexpr const uint32_t kIntrinsicDataCount = @@ -780,105 +775,42 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident, return false; } - if (expr->params().size() != data->param_count) { - set_error(expr->source(), "incorrect number of parameters for " + - builder_->Symbols().NameFor(ident->symbol()) + - ". Expected " + - std::to_string(data->param_count) + " got " + - std::to_string(expr->params().size())); - return false; - } - - std::vector result_types; - for (uint32_t i = 0; i < data->param_count; ++i) { - result_types.push_back(TypeOf(expr->params()[i])->UnwrapPtrIfNeeded()); - - switch (data->data_type) { - case IntrinsicDataType::kFloatOrIntScalarOrVector: - if (!result_types.back()->is_float_scalar_or_vector() && - !result_types.back()->is_integer_scalar_or_vector()) { - set_error(expr->source(), - "incorrect type for " + - builder_->Symbols().NameFor(ident->symbol()) + ". " + - "Requires float or int, scalar or vector values"); - return false; - } - break; - case IntrinsicDataType::kFloatScalarOrVector: - if (!result_types.back()->is_float_scalar_or_vector()) { - set_error(expr->source(), - "incorrect type for " + - builder_->Symbols().NameFor(ident->symbol()) + ". " + - "Requires float scalar or float vector values"); - return false; - } - - break; - case IntrinsicDataType::kIntScalarOrVector: - if (!result_types.back()->is_integer_scalar_or_vector()) { - set_error(expr->source(), - "incorrect type for " + - builder_->Symbols().NameFor(ident->symbol()) + ". " + - "Requires integer scalar or integer vector values"); - return false; - } - break; - case IntrinsicDataType::kFloatVector: - if (!result_types.back()->is_float_vector()) { - set_error(expr->source(), - "incorrect type for " + - builder_->Symbols().NameFor(ident->symbol()) + ". " + - "Requires float vector values"); - return false; - } - if (data->vector_size > 0 && - result_types.back()->As()->size() != - data->vector_size) { - set_error(expr->source(), - "incorrect vector size for " + - builder_->Symbols().NameFor(ident->symbol()) + ". " + - "Requires " + std::to_string(data->vector_size) + - " elements"); - return false; - } - break; - case IntrinsicDataType::kMatrix: - if (!result_types.back()->Is()) { - set_error(expr->source(), - "incorrect type for " + - builder_->Symbols().NameFor(ident->symbol()) + - ". Requires matrix value"); - return false; - } - break; - } - } - - // Verify all the parameter types match - for (size_t i = 1; i < data->param_count; ++i) { - if (result_types[0] != result_types[i]) { + if (data->result_type == IntrinsicDataType::kDependent) { + const auto param_idx = data->param_for_result_type; + if (expr->params().size() <= param_idx) { set_error(expr->source(), - "mismatched parameter types for " + + "missing parameter " + std::to_string(param_idx) + + " required for type determination in builtin " + builder_->Symbols().NameFor(ident->symbol())); return false; } + SetType(expr->func(), + TypeOf(expr->params()[param_idx])->UnwrapPtrIfNeeded()); + } else { + // The result type is not dependent on the parameter types. + type::Type* type = nullptr; + switch (data->result_type) { + case IntrinsicDataType::kSignedInteger: + type = builder_->create(); + break; + case IntrinsicDataType::kUnsignedInteger: + type = builder_->create(); + break; + case IntrinsicDataType::kFloat: + type = builder_->create(); + break; + default: + error_ = "unhandled intrinsic data type for " + + builder_->Symbols().NameFor(ident->symbol()); + return false; + } + + if (data->result_vector_width > 1) { + type = builder_->create(type, data->result_vector_width); + } + SetType(expr->func(), type); } - // Handle functions which aways return the type, even if a vector is - // provided. - if (ident->intrinsic() == ast::Intrinsic::kLength || - ident->intrinsic() == ast::Intrinsic::kDistance) { - SetType(expr->func(), result_types[0]->is_float_scalar() - ? result_types[0] - : result_types[0]->As()->type()); - return true; - } - // The determinant returns the component type of the columns - if (ident->intrinsic() == ast::Intrinsic::kDeterminant) { - SetType(expr->func(), result_types[0]->As()->type()); - return true; - } - SetType(expr->func(), result_types[0]); return true; } diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 29daee7e72..3fe802aacd 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -1735,35 +1735,15 @@ TEST_P(ImportData_SingleParamTest, Vector) { EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } -TEST_P(ImportData_SingleParamTest, Error_Integer) { - auto param = GetParam(); - - auto* call = Call(param.name, 1); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires float scalar or float vector values"); -} - TEST_P(ImportData_SingleParamTest, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 0"); -} - -TEST_P(ImportData_SingleParamTest, Error_MultipleParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 3"); + EXPECT_EQ(td()->error(), + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } INSTANTIATE_TEST_SUITE_P( @@ -1874,35 +1854,15 @@ TEST_P(ImportData_SingleParam_FloatOrInt_Test, Uint_Vector) { EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } -TEST_P(ImportData_SingleParam_FloatOrInt_Test, Error_Bool) { - auto param = GetParam(); - - auto* call = Call(param.name, false); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires float or int, scalar or vector values"); -} - TEST_P(ImportData_SingleParam_FloatOrInt_Test, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 0"); -} - -TEST_P(ImportData_SingleParam_FloatOrInt_Test, Error_MultipleParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 3"); + EXPECT_EQ(td()->error(), + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, @@ -1933,36 +1893,6 @@ TEST_F(TypeDeterminerTest, ImportData_Length_FloatVector) { EXPECT_TRUE(TypeOf(ident)->is_float_scalar()); } -TEST_F(TypeDeterminerTest, ImportData_Length_Error_Integer) { - ast::ExpressionList params; - params.push_back(Expr(1)); - - auto* call = Call("length", params); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect type for length. Requires float scalar or float vector " - "values"); -} - -TEST_F(TypeDeterminerTest, ImportData_Length_Error_NoParams) { - ast::ExpressionList params; - - auto* call = Call("length"); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for length. Expected 1 got 0"); -} - -TEST_F(TypeDeterminerTest, ImportData_Length_Error_MultipleParams) { - auto* call = Call("length", 1.f, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for length. Expected 1 got 3"); -} - using ImportData_TwoParamTest = TypeDeterminerTestWithParam; TEST_P(ImportData_TwoParamTest, Scalar) { auto param = GetParam(); @@ -1987,69 +1917,16 @@ TEST_P(ImportData_TwoParamTest, Vector) { EXPECT_TRUE(TypeOf(ident)->is_float_vector()); EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } - -TEST_P(ImportData_TwoParamTest, Error_Integer) { - auto param = GetParam(); - - auto* call = Call(param.name, 1, 2); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires float scalar or float vector values"); -} - TEST_P(ImportData_TwoParamTest, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 2 got 0"); -} - -TEST_P(ImportData_TwoParamTest, Error_OneParam) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 2 got 1"); -} - -TEST_P(ImportData_TwoParamTest, Error_MismatchedParamCount) { - auto param = GetParam(); - - auto* call = - Call(param.name, vec2(1.0f, 1.0f), vec3(1.0f, 1.0f, 3.0f)); - EXPECT_FALSE(td()->DetermineResultType(call)); EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } - -TEST_P(ImportData_TwoParamTest, Error_MismatchedParamType) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.0f, vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_TwoParamTest, Error_TooManyParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 2 got 3"); -} - INSTANTIATE_TEST_SUITE_P( TypeDeterminerTest, ImportData_TwoParamTest, @@ -2079,54 +1956,6 @@ TEST_F(TypeDeterminerTest, ImportData_Distance_Vector) { EXPECT_TRUE(TypeOf(ident)->Is()); } -TEST_F(TypeDeterminerTest, ImportData_Distance_Error_Integer) { - auto* call = Call("distance", 1, 2); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect type for distance. Requires float scalar or float " - "vector values"); -} - -TEST_F(TypeDeterminerTest, ImportData_Distance_Error_NoParams) { - auto* call = Call("distance"); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for distance. Expected 2 got 0"); -} - -TEST_F(TypeDeterminerTest, ImportData_Distance_Error_OneParam) { - auto* call = Call("distance", 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for distance. Expected 2 got 1"); -} - -TEST_F(TypeDeterminerTest, ImportData_Distance_Error_MismatchedParamCount) { - auto* call = - Call("distance", vec2(1.0f, 1.0f), vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), "mismatched parameter types for distance"); -} - -TEST_F(TypeDeterminerTest, ImportData_Distance_Error_MismatchedParamType) { - auto* call = Call("distance", Expr(1.0f), vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), "mismatched parameter types for distance"); -} - -TEST_F(TypeDeterminerTest, ImportData_Distance_Error_TooManyParams) { - auto* call = Call("distance", Expr(1.f), Expr(1.f), Expr(1.f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for distance. Expected 2 got 3"); -} - TEST_F(TypeDeterminerTest, ImportData_Cross) { auto* ident = Expr("cross"); @@ -2139,45 +1968,15 @@ TEST_F(TypeDeterminerTest, ImportData_Cross) { EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } -TEST_F(TypeDeterminerTest, ImportData_Cross_Error_Scalar) { - auto* call = Call("cross", 1.0f, 1.0f); +TEST_F(TypeDeterminerTest, ImportData_Cross_AutoType) { + auto* ident = Expr("cross"); - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect type for cross. Requires float vector values"); -} + auto* call = Call(ident); -TEST_F(TypeDeterminerTest, ImportData_Cross_Error_IntType) { - auto* call = Call("cross", vec3(1, 1, 3), vec3(1, 1, 3)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect type for cross. Requires float vector values"); -} - -TEST_F(TypeDeterminerTest, ImportData_Cross_Error_MissingParams) { - auto* call = Call("cross"); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for cross. Expected 2 got 0"); -} - -TEST_F(TypeDeterminerTest, ImportData_Cross_Error_TooFewParams) { - auto* call = Call("cross", vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for cross. Expected 2 got 1"); -} - -TEST_F(TypeDeterminerTest, ImportData_Cross_Error_TooManyParams) { - auto* call = Call("cross", vec3(1.0f, 1.0f, 3.0f), - vec3(1.0f, 1.0f, 3.0f), vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - "incorrect number of parameters for cross. Expected 2 got 3"); + EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error(); + ASSERT_NE(TypeOf(ident), nullptr); + EXPECT_TRUE(TypeOf(ident)->is_float_vector()); + EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } using ImportData_ThreeParamTest = TypeDeterminerTestWithParam; @@ -2204,77 +2003,15 @@ TEST_P(ImportData_ThreeParamTest, Vector) { EXPECT_TRUE(TypeOf(ident)->is_float_vector()); EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } - -TEST_P(ImportData_ThreeParamTest, Error_Integer) { - auto param = GetParam(); - - auto* call = Call(param.name, 1, 2, 3); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires float scalar or float vector values"); -} - TEST_P(ImportData_ThreeParamTest, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 0"); -} - -TEST_P(ImportData_ThreeParamTest, Error_OneParam) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 1"); -} - -TEST_P(ImportData_ThreeParamTest, Error_TwoParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 2"); -} - -TEST_P(ImportData_ThreeParamTest, Error_MismatchedParamCount) { - auto param = GetParam(); - - auto* call = Call(param.name, vec2(1.0f, 1.0f), - vec3(1.0f, 1.0f, 3.0f), vec3(1.0f, 1.0f, 3.0f)); - EXPECT_FALSE(td()->DetermineResultType(call)); EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_ThreeParamTest, Error_MismatchedParamType) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.0f, 1.0f, vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_ThreeParamTest, Error_TooManyParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 4"); + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } INSTANTIATE_TEST_SUITE_P( @@ -2360,76 +2097,15 @@ TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Uint_Vector) { EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } -TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_Bool) { - auto param = GetParam(); - - auto* call = Call(param.name, true, false, true); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires float or int, scalar or vector values"); -} - TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 0"); -} - -TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_OneParam) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 1"); -} - -TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_TwoParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 2"); -} - -TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_MismatchedParamCount) { - auto param = GetParam(); - - auto* call = Call(param.name, vec2(1.0f, 1.0f), - vec3(1.0f, 1.0f, 3.0f), vec3(1.0f, 1.0f, 3.0f)); - EXPECT_FALSE(td()->DetermineResultType(call)); EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_MismatchedParamType) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.0f, 1.0f, vec3(1.0f, 1.0f, 3.0f)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_TooManyParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f, 1.f, 1.f, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 3 got 4"); + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, @@ -2462,35 +2138,15 @@ TEST_P(ImportData_Int_SingleParamTest, Vector) { EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } -TEST_P(ImportData_Int_SingleParamTest, Error_Float) { - auto param = GetParam(); - - auto* call = Call(param.name, 1.f); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires integer scalar or integer vector values"); -} - TEST_P(ImportData_Int_SingleParamTest, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 0"); -} - -TEST_P(ImportData_Int_SingleParamTest, Error_MultipleParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1, 1, 1); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 3"); + EXPECT_EQ(td()->error(), + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } INSTANTIATE_TEST_SUITE_P( @@ -2571,65 +2227,15 @@ TEST_P(ImportData_FloatOrInt_TwoParamTest, Vector_Float) { EXPECT_EQ(TypeOf(ident)->As()->size(), 3u); } -TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_Bool) { - auto param = GetParam(); - - auto* call = Call(param.name, true, false); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("incorrect type for ") + param.name + - ". Requires float or int, scalar or vector values"); -} - TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_NoParams) { auto param = GetParam(); auto* call = Call(param.name); - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 2 got 0"); -} - -TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_OneParam) { - auto param = GetParam(); - - auto* call = Call(param.name, 1); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 2 got 1"); -} - -TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_MismatchedParamCount) { - auto param = GetParam(); - - auto* call = Call(param.name, vec2(1, 1), vec3(1, 1, 3)); - EXPECT_FALSE(td()->DetermineResultType(call)); EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_MismatchedParamType) { - auto param = GetParam(); - - auto* call = Call(param.name, Expr(1), vec3(1, 1, 3)); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), - std::string("mismatched parameter types for ") + param.name); -} - -TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_TooManyParams) { - auto param = GetParam(); - - auto* call = Call(param.name, 1, 1, 1); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 2 got 3"); + "missing parameter 0 required for type determination in builtin " + + std::string(param.name)); } INSTANTIATE_TEST_SUITE_P( @@ -2655,45 +2261,16 @@ TEST_F(TypeDeterminerTest, ImportData_GLSL_Determinant) { using ImportData_Matrix_OneParam_Test = TypeDeterminerTestWithParam; -TEST_P(ImportData_Matrix_OneParam_Test, Error_Float) { - auto param = GetParam(); - - auto* var = Var("var", ast::StorageClass::kFunction, ty.f32()); - AST().AddGlobalVariable(var); - - ASSERT_TRUE(td()->Determine()) << td()->error(); - - auto* call = Call(param.name, "var"); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name + - ". Requires matrix value"); -} TEST_P(ImportData_Matrix_OneParam_Test, NoParams) { auto param = GetParam(); auto* call = Call(param.name); - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 0"); + EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error(); + EXPECT_TRUE(TypeOf(call)->Is()); } -TEST_P(ImportData_Matrix_OneParam_Test, TooManyParams) { - auto param = GetParam(); - - auto* var = Var("var", ast::StorageClass::kFunction, ty.mat3x3()); - AST().AddGlobalVariable(var); - - ASSERT_TRUE(td()->Determine()) << td()->error(); - - auto* call = Call(param.name, "var", "var"); - - EXPECT_FALSE(td()->DetermineResultType(call)); - EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + - param.name + ". Expected 1 got 2"); -} INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, ImportData_Matrix_OneParam_Test, testing::Values(IntrinsicData{ diff --git a/src/validator/validator_builtins_test.cc b/src/validator/validator_builtins_test.cc new file mode 100644 index 0000000000..392a7acebd --- /dev/null +++ b/src/validator/validator_builtins_test.cc @@ -0,0 +1,1128 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "src/ast/array_accessor_expression.h" +#include "src/ast/struct.h" +#include "src/ast/struct_block_decoration.h" +#include "src/ast/struct_member.h" +#include "src/ast/struct_member_decoration.h" +#include "src/ast/type_constructor_expression.h" +#include "src/ast/variable_decl_statement.h" +#include "src/type/alias_type.h" +#include "src/type/array_type.h" +#include "src/type/f32_type.h" +#include "src/type/i32_type.h" +#include "src/type/struct_type.h" +#include "src/validator/validator_impl.h" +#include "src/validator/validator_test_helper.h" + +namespace tint { + +class ValidatorBuiltinsTest : public ValidatorTestHelper, + public testing::Test {}; + +TEST_F(ValidatorBuiltinsTest, Length_Float_Scalar) { + auto* builtin = Call("length", 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Length_Float_Vec2) { + auto* builtin = Call("length", vec2(1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Length_Float_Vec3) { + auto* builtin = Call("length", vec3(1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Length_Float_Vec4) { + auto* builtin = Call("length", vec4(1.0f, 1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Length_Integer_Scalar) { + auto* builtin = Call("length", 1); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect type for length. Requires float scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Length_Integer_Vec2) { + auto* builtin = Call("length", vec2(1, 1)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect type for length. Requires float scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Length_TooManyParams) { + auto* builtin = Call("length", 1.0f, 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for length expected 1 got 2"); +} + +TEST_F(ValidatorBuiltinsTest, Length_TooFewParams) { + auto* builtin = Call("length"); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for length expected 1 got 0"); +} + +TEST_F(ValidatorBuiltinsTest, Distance_Float_Scalar) { + auto* builtin = Call("distance", 1.0f, 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Distance_Float_Vec2) { + auto* builtin = + Call("distance", vec2(1.0f, 1.0f), vec2(1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Distance_Float_Vec3) { + auto* builtin = Call("distance", vec3(1.0f, 1.0f, 1.0f), + vec3(1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Distance_Float_Vec4) { + auto* builtin = Call("distance", vec4(1.0f, 1.0f, 1.0f, 1.0f), + vec4(1.0f, 1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Distance_Integer_Scalar) { + auto* builtin = Call("distance", 1, 1); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "incorrect type for distance. Requires float scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Distance_Integer_Vec2) { + auto* builtin = Call("distance", vec2(1, 1), vec2(1, 1)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "incorrect type for distance. Requires float scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Distance_TooManyParams) { + auto* builtin = Call("distance", 1.0f, 1.0f, 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for distance expected 2 got 3"); +} + +TEST_F(ValidatorBuiltinsTest, Distance_TooFewParams) { + auto* builtin = Call("distance", 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for distance expected 2 got 1"); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Mat2x2) { + auto* builtin = Call("determinant", mat2x2(vec2(1.0f, 1.0f), + vec2(1.0f, 1.0f))); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Mat3x3) { + auto* builtin = + Call("determinant", mat3x3(vec3(1.0f, 1.0f, 1.0f), + vec3(1.0f, 1.0f, 1.0f), + vec3(1.0f, 1.0f, 1.0f))); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Mat4x4) { + auto* builtin = + Call("determinant", mat4x4(vec4(1.0f, 1.0f, 1.0f, 1.0f), + vec4(1.0f, 1.0f, 1.0f, 1.0f), + vec4(1.0f, 1.0f, 1.0f, 1.0f), + vec4(1.0f, 1.0f, 1.0f, 1.0f))); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Mat3x2) { + auto* builtin = Call("determinant", mat3x2(vec2(1.0f, 1.0f), + vec2(1.0f, 1.0f), + vec2(1.0f, 1.0f))); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect type for determinant. Requires a square matrix"); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Float_Vec2) { + auto* builtin = Call("determinant", vec2(1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for determinant. Requires matrix value"); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Float_Scalar) { + auto* builtin = Call("determinant", 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for determinant. Requires matrix value"); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_Integer_Scalar) { + auto* builtin = Call("determinant", 1); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for determinant. Requires matrix value"); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_TooManyParams) { + auto* builtin = + Call("determinant", + mat2x2(vec2(1.0f, 1.0f), vec2(1.0f, 1.0f)), + mat2x2(vec2(1.0f, 1.0f), vec2(1.0f, 1.0f))); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for determinant expected 1 got 2"); +} + +TEST_F(ValidatorBuiltinsTest, Determinant_TooFewParams) { + auto* builtin = Call("determinant"); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for determinant expected 1 got 0"); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Scalar) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.i32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", 1.0f, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Vec2) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec2()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.i32(), 2), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", vec2(1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Vec3) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec3()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.i32(), 3), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", vec3(1.0f, 1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Vec4) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec4()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.i32(), 4), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", vec4(1.0f, 1.0f, 1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Integer_FirstParam) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.i32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", 1, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_FALSE(TypeOf(builtin)->Is()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect type for frexp. Requires float scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Float_SecondParam) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.f32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", 1.0f, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect type for frexp. Requires int scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_NotAPointer) { + auto* builtin = Call("frexp", 1.0f, 1); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for frexp. Requires pointer value"); +} + +TEST_F(ValidatorBuiltinsTest, Frexp_Scalar_Vector) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec2()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.i32(), 2), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("frexp", 1.0f, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect types for frexp. Parameters must be matched scalars or " + "vectors"); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Scalar) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.f32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", 1.0f, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Vec2) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec2()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.f32(), 2), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", vec2(1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Vec3) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec3()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.f32(), 3), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", vec3(1.0f, 1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Vec4) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec4()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.f32(), 4), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", vec4(1.0f, 1.0f, 1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Integer_FirstParam) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.f32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", 1, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_FALSE(TypeOf(builtin)->Is()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect type for modf. Requires float scalar or vector value"); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Integer_SecondParam) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.i32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", 1.0f, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 1's unwrapped type to match result type for modf"); +} + +TEST_F(ValidatorBuiltinsTest, Modf_NotAPointer) { + auto* builtin = Call("modf", 1.0f, 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for modf. Requires pointer value"); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Scalar_Vector) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec2()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.f32(), 2), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", 1.0f, Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 1's unwrapped type to match result type for modf"); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Vector_Scalar) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.f32()); + auto* b = + Const("b", ast::StorageClass::kWorkgroup, + ty.pointer(ast::StorageClass::kWorkgroup), Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", vec2(1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 1's unwrapped type to match result type for modf"); +} + +TEST_F(ValidatorBuiltinsTest, Modf_Vector_Vector_MismatchedSize) { + auto* a = Var("a", ast::StorageClass::kWorkgroup, ty.vec2()); + auto* b = Const("b", ast::StorageClass::kWorkgroup, + create(create(ty.f32(), 2), + ast::StorageClass::kWorkgroup), + Expr("a"), {}); + RegisterVariable(a); + RegisterVariable(b); + auto* builtin = Call("modf", vec3(1.0f, 1.0f, 1.0f), Expr("b")); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + EXPECT_TRUE(TypeOf(builtin->params()[1])->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 1's unwrapped type to match result type for modf"); +} + +TEST_F(ValidatorBuiltinsTest, Cross_Float_Vec3) { + auto* builtin = Call("cross", vec3(1.0f, 1.0f, 1.0f), + vec3(1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_F(ValidatorBuiltinsTest, Cross_Integer_Vec3) { + auto* builtin = Call("cross", vec3(1, 1, 1), vec3(1, 1, 1)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 0's unwrapped type to match result type for cross"); +} + +TEST_F(ValidatorBuiltinsTest, Cross_Float_Vec4) { + auto* builtin = Call("cross", vec4(1.0f, 1.0f, 1.0f, 1.0f), + vec4(1.0f, 1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 0's unwrapped type to match result type for cross"); +} + +TEST_F(ValidatorBuiltinsTest, Cross_Float_Vec2) { + auto* builtin = + Call("cross", vec2(1.0f, 1.0f), vec2(1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 0's unwrapped type to match result type for cross"); +} + +TEST_F(ValidatorBuiltinsTest, Cross_Float_Scalar) { + auto* builtin = Call("cross", 1.0f, 1.0f); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ( + v.error(), + "expected parameter 0's unwrapped type to match result type for cross"); +} + +TEST_F(ValidatorBuiltinsTest, Cross_TooManyParams) { + auto* builtin = + Call("cross", vec3(1.0f, 1.0f, 1.0f), + vec3(1.0f, 1.0f, 1.0f), vec3(1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for cross expected 2 got 3"); +} + +TEST_F(ValidatorBuiltinsTest, Cross_TooFewParams) { + auto* builtin = Call("cross", vec3(1.0f, 1.0f, 1.0f)); + + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "incorrect number of parameters for cross expected 2 got 1"); +} + +template +class ValidatorBuiltinsTestWithParams : public ValidatorTestHelper, + public testing::TestWithParam {}; + +using FloatAllMatching = + ValidatorBuiltinsTestWithParams>; + +TEST_P(FloatAllMatching, Scalar) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(Expr(1.0f)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(FloatAllMatching, Vec2) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec2(1.0f, 1.0f)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(FloatAllMatching, Vec3) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec3(1.0f, 1.0f, 1.0f)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(FloatAllMatching, Vec4) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec4(1.0f, 1.0f, 1.0f, 1.0f)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(FloatAllMatching, Param_TooManyParams) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params + 1; ++i) { + params.push_back(Expr(1.0f)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect number of parameters for " + name + + " expected " + std::to_string(num_params) + " got " + + std::to_string(num_params + 1)); +} + +TEST_P(FloatAllMatching, Param_TooFewParams) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Expr(1.0f)); + } + auto* builtin = Call(name, params); + // Most intrinsics require a parameter to determine the type so expect type + // determination to fail at zero parameters. + if (num_params > 1) { + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect number of parameters for " + name + + " expected " + std::to_string(num_params) + + " got " + std::to_string(num_params - 1)); + } else { + EXPECT_FALSE(td()->DetermineResultType(builtin)); + } +} + +TEST_P(FloatAllMatching, Param_Mismatch_Scalar) { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Expr(1.0f)); + } + // Can't mismatch single parameter types. + if (num_params > 1) { + std::string name = std::get<0>(GetParam()); + params.push_back(Expr(1)); + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "expected parameter " + std::to_string(num_params - 1) + + "'s unwrapped type to match result type for " + name); + } +} + +TEST_P(FloatAllMatching, Param_Mismatch_Vector) { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Expr(1.0f)); + } + // Can't mismatch single parameter types. + if (num_params > 1) { + std::string name = std::get<0>(GetParam()); + params.push_back(vec2(1.0f, 1.0f)); + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "expected parameter " + std::to_string(num_params - 1) + + "'s unwrapped type to match result type for " + name); + } +} + +TEST_P(FloatAllMatching, Param_Integer) { + std::string name = std::get<0>(GetParam()); + // abs, clamp, max and min also support integers. + if (name != "abs" && name != "clamp" && name != "max" && name != "min") { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(Expr(1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_FALSE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for " + name + + ". Requires float scalar or vector value"); + } +} + +INSTANTIATE_TEST_SUITE_P(ValidatorBuiltinsTest, + FloatAllMatching, + ::testing::Values(std::make_tuple("abs", 1), + std::make_tuple("acos", 1), + std::make_tuple("asin", 1), + std::make_tuple("atan", 1), + std::make_tuple("atan2", 2), + std::make_tuple("ceil", 1), + std::make_tuple("clamp", 3), + std::make_tuple("cos", 1), + std::make_tuple("cosh", 1), + std::make_tuple("exp", 1), + std::make_tuple("exp2", 1), + std::make_tuple("faceForward", 3), + std::make_tuple("floor", 1), + std::make_tuple("fma", 3), + std::make_tuple("fract", 1), + std::make_tuple("inverseSqrt", 1), + std::make_tuple("ldexp", 2), + std::make_tuple("log", 1), + std::make_tuple("log2", 1), + std::make_tuple("max", 2), + std::make_tuple("min", 2), + std::make_tuple("mix", 3), + std::make_tuple("pow", 2), + std::make_tuple("reflect", 2), + std::make_tuple("round", 1), + std::make_tuple("sign", 1), + std::make_tuple("sin", 1), + std::make_tuple("sinh", 1), + std::make_tuple("smoothStep", 3), + std::make_tuple("sqrt", 1), + std::make_tuple("step", 2), + std::make_tuple("tan", 1), + std::make_tuple("tanh", 1), + std::make_tuple("trunc", 1))); + +using IntegerAllMatching = + ValidatorBuiltinsTestWithParams>; + +TEST_P(IntegerAllMatching, ScalarUnsigned) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(Construct(1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Vec2Unsigned) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec2(1, 1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_unsigned_integer_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Vec3Unsigned) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec3(1, 1, 1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_unsigned_integer_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Vec4Unsigned) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec4(1, 1, 1, 1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_unsigned_integer_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, ScalarSigned) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(Construct(1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Vec2Signed) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec2(1, 1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_signed_integer_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Vec3Signed) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec3(1, 1, 1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_signed_integer_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Vec4Signed) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(vec4(1, 1, 1, 1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->is_signed_integer_vector()); + ValidatorImpl& v = Build(); + EXPECT_TRUE(v.ValidateCallExpr(builtin)) << v.error(); +} + +TEST_P(IntegerAllMatching, Param_TooManyParams) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params + 1; ++i) { + params.push_back(Construct(1)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect number of parameters for " + name + + " expected " + std::to_string(num_params) + " got " + + std::to_string(num_params + 1)); +} + +TEST_P(IntegerAllMatching, Param_TooFewParams) { + std::string name = std::get<0>(GetParam()); + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Construct(1)); + } + auto* builtin = Call(name, params); + // Most intrinsics require a parameter to determine the type so expect type + // determination to fail at zero parameters. + if (num_params > 1) { + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect number of parameters for " + name + + " expected " + std::to_string(num_params) + + " got " + std::to_string(num_params - 1)); + } else { + EXPECT_FALSE(td()->DetermineResultType(builtin)); + } +} + +TEST_P(IntegerAllMatching, Param_Mismatch_Scalar) { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Construct(1)); + } + // Can't mismatch single parameter types. + if (num_params > 1) { + std::string name = std::get<0>(GetParam()); + params.push_back(Expr(1)); + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "expected parameter " + std::to_string(num_params - 1) + + "'s unwrapped type to match result type for " + name); + } +} + +TEST_P(IntegerAllMatching, Param_Mismatch_Vector) { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Construct(1)); + } + // Can't mismatch single parameter types. + if (num_params > 1) { + std::string name = std::get<0>(GetParam()); + params.push_back(vec2(1, 1)); + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "expected parameter " + std::to_string(num_params - 1) + + "'s unwrapped type to match result type for " + name); + } +} + +TEST_P(IntegerAllMatching, Param_Mismatch_Sign) { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params - 1; ++i) { + params.push_back(Construct(1)); + } + // Can't mismatch single parameter types. + if (num_params > 1) { + std::string name = std::get<0>(GetParam()); + params.push_back(Construct(1)); + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), + "expected parameter " + std::to_string(num_params - 1) + + "'s unwrapped type to match result type for " + name); + } +} + +TEST_P(IntegerAllMatching, Param_Float) { + std::string name = std::get<0>(GetParam()); + // abs, clamp, max and min also support integers. + if (name != "abs" && name != "clamp" && name != "max" && name != "min") { + uint32_t num_params = std::get<1>(GetParam()); + + ast::ExpressionList params; + for (uint32_t i = 0; i < num_params; ++i) { + params.push_back(Expr(1.0f)); + } + auto* builtin = Call(name, params); + EXPECT_TRUE(td()->DetermineResultType(builtin)) << td()->error(); + EXPECT_TRUE(TypeOf(builtin)->Is()); + ValidatorImpl& v = Build(); + EXPECT_FALSE(v.ValidateCallExpr(builtin)); + EXPECT_EQ(v.error(), "incorrect type for " + name + + ". Requires int scalar or vector value"); + } +} + +INSTANTIATE_TEST_SUITE_P(ValidatorBuiltinsTest, + IntegerAllMatching, + ::testing::Values(std::make_tuple("abs", 1), + std::make_tuple("clamp", 3), + std::make_tuple("countOneBits", 1), + std::make_tuple("max", 2), + std::make_tuple("min", 2), + std::make_tuple("reverseBits", 1))); +} // namespace tint diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc index 818ac50dec..b3ce8f2f3e 100644 --- a/src/validator/validator_impl.cc +++ b/src/validator/validator_impl.cc @@ -33,6 +33,7 @@ #include "src/semantic/expression.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" +#include "src/type/f32_type.h" #include "src/type/i32_type.h" #include "src/type/matrix_type.h" #include "src/type/pointer_type.h" @@ -43,6 +44,175 @@ namespace tint { +namespace { + +enum class IntrinsicDataType { + kMixed, + kFloatOrIntScalarOrVector, + kFloatScalarOrVector, + kIntScalarOrVector, + kFloatVector, + kFloatScalar, + kMatrix, +}; + +struct IntrinsicData { + ast::Intrinsic intrinsic; + uint32_t param_count; + IntrinsicDataType data_type; + uint32_t vector_size; + bool all_types_match; +}; + +// Note, this isn't all the intrinsics. Some are handled specially before +// we get to the generic code. See the ValidateCallExpr code below. +constexpr const IntrinsicData kIntrinsicData[] = { + {ast::Intrinsic::kAbs, 1, IntrinsicDataType::kFloatOrIntScalarOrVector, 0, + true}, + {ast::Intrinsic::kAcos, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kAsin, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kAtan, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kAtan2, 2, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kCeil, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kClamp, 3, IntrinsicDataType::kFloatOrIntScalarOrVector, 0, + true}, + {ast::Intrinsic::kCos, 1, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kCosh, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kCountOneBits, 1, IntrinsicDataType::kIntScalarOrVector, 0, + true}, + {ast::Intrinsic::kCross, 2, IntrinsicDataType::kFloatVector, 3, true}, + {ast::Intrinsic::kDeterminant, 1, IntrinsicDataType::kMatrix, 0, false}, + {ast::Intrinsic::kDistance, 2, IntrinsicDataType::kFloatScalarOrVector, 0, + false}, + {ast::Intrinsic::kExp, 1, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kExp2, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kFaceForward, 3, IntrinsicDataType::kFloatScalarOrVector, + 0, true}, + {ast::Intrinsic::kFloor, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kFma, 3, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kFract, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kFrexp, 2, IntrinsicDataType::kMixed, 0, false}, + {ast::Intrinsic::kInverseSqrt, 1, IntrinsicDataType::kFloatScalarOrVector, + 0, true}, + {ast::Intrinsic::kLdexp, 2, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kLength, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + false}, + {ast::Intrinsic::kLog, 1, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kLog2, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kMax, 2, IntrinsicDataType::kFloatOrIntScalarOrVector, 0, + true}, + {ast::Intrinsic::kMin, 2, IntrinsicDataType::kFloatOrIntScalarOrVector, 0, + true}, + {ast::Intrinsic::kMix, 3, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kModf, 2, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kNormalize, 1, IntrinsicDataType::kFloatVector, 0, true}, + {ast::Intrinsic::kPow, 2, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kReflect, 2, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kReverseBits, 1, IntrinsicDataType::kIntScalarOrVector, 0, + true}, + {ast::Intrinsic::kRound, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kSign, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kSin, 1, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kSinh, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kSmoothStep, 3, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kSqrt, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kStep, 2, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kTan, 1, IntrinsicDataType::kFloatScalarOrVector, 0, true}, + {ast::Intrinsic::kTanh, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, + {ast::Intrinsic::kTrunc, 1, IntrinsicDataType::kFloatScalarOrVector, 0, + true}, +}; + +constexpr const uint32_t kIntrinsicDataCount = + sizeof(kIntrinsicData) / sizeof(IntrinsicData); + +bool IsValidType(type::Type* type, + const Source& source, + const std::string& name, + const IntrinsicDataType& data_type, + uint32_t vector_size, + ValidatorImpl* impl) { + type = type->UnwrapPtrIfNeeded(); + switch (data_type) { + case IntrinsicDataType::kFloatOrIntScalarOrVector: + if (!type->is_float_scalar_or_vector() && + !type->is_integer_scalar_or_vector()) { + impl->add_error(source, + "incorrect type for " + name + + ". Requires int or float, scalar or vector value"); + return false; + } + break; + case IntrinsicDataType::kFloatScalarOrVector: + if (!type->is_float_scalar_or_vector()) { + impl->add_error(source, "incorrect type for " + name + + ". Requires float scalar or vector value"); + return false; + } + break; + case IntrinsicDataType::kIntScalarOrVector: + if (!type->is_integer_scalar_or_vector()) { + impl->add_error(source, "incorrect type for " + name + + ". Requires int scalar or vector value"); + return false; + } + break; + case IntrinsicDataType::kFloatVector: + if (!type->is_float_vector()) { + impl->add_error(source, "incorrect type for " + name + + ". Requires float vector value"); + return false; + } + if (vector_size > 0 && vector_size != type->As()->size()) { + impl->add_error(source, "incorrect vector size for " + name + + ". Requires " + + std::to_string(vector_size) + " elements"); + return false; + } + break; + case IntrinsicDataType::kFloatScalar: + if (!type->Is()) { + impl->add_error(source, "incorrect type for " + name + + ". Requires float scalar value"); + return false; + } + break; + case IntrinsicDataType::kMatrix: + if (!type->Is()) { + impl->add_error( + source, "incorrect type for " + name + ". Requires matrix value"); + return false; + } + break; + default: + break; + } + + return true; +} + +} // namespace + ValidatorImpl::ValidatorImpl(const Program* program) : program_(program) {} ValidatorImpl::~ValidatorImpl() = default; @@ -406,10 +576,145 @@ bool ValidatorImpl::ValidateCallExpr(const ast::CallExpression* expr) { } if (auto* ident = expr->func()->As()) { + auto symbol = ident->symbol(); if (ident->IsIntrinsic()) { - // TODO(sarahM0): validate intrinsics - tied with type-determiner + const IntrinsicData* data = nullptr; + for (uint32_t i = 0; i < kIntrinsicDataCount; ++i) { + if (ident->intrinsic() == kIntrinsicData[i].intrinsic) { + data = &kIntrinsicData[i]; + break; + } + } + + if (data != nullptr) { + const auto builtin = program_->Symbols().NameFor(symbol); + if (expr->params().size() != data->param_count) { + add_error(expr->source(), + "incorrect number of parameters for " + builtin + + " expected " + std::to_string(data->param_count) + + " got " + std::to_string(expr->params().size())); + return false; + } + + if (data->all_types_match) { + // Check that the type is an acceptable one. + if (!IsValidType(program_->TypeOf(expr->func()), expr->source(), + builtin, data->data_type, data->vector_size, this)) { + return false; + } + + // Check that all params match the result type. + for (uint32_t i = 0; i < data->param_count; ++i) { + if (program_->TypeOf(expr->func())->UnwrapPtrIfNeeded() != + program_->TypeOf(expr->params()[i])->UnwrapPtrIfNeeded()) { + add_error(expr->params()[i]->source(), + "expected parameter " + std::to_string(i) + + "'s unwrapped type to match result type for " + + builtin); + return false; + } + } + } else { + if (data->data_type != IntrinsicDataType::kMixed) { + auto* p0 = expr->params()[0]; + if (!IsValidType(program_->TypeOf(p0), p0->source(), builtin, + data->data_type, data->vector_size, this)) { + return false; + } + + // Check that parameters are valid types. + for (uint32_t i = 1; i < expr->params().size(); ++i) { + if (program_->TypeOf(p0)->UnwrapPtrIfNeeded() != + program_->TypeOf(expr->params()[i])->UnwrapPtrIfNeeded()) { + add_error( + expr->source(), + "parameter " + std::to_string(i) + + "'s unwrapped type must match parameter 0's type"); + return false; + } + } + } else { + // Special cases. + if (data->intrinsic == ast::Intrinsic::kFrexp) { + auto* p0 = expr->params()[0]; + auto* p1 = expr->params()[1]; + auto* t0 = program_->TypeOf(p0)->UnwrapPtrIfNeeded(); + auto* t1 = program_->TypeOf(p1)->UnwrapPtrIfNeeded(); + if (!IsValidType(t0, p0->source(), builtin, + IntrinsicDataType::kFloatScalarOrVector, 0, + this)) { + return false; + } + if (!IsValidType(t1, p1->source(), builtin, + IntrinsicDataType::kIntScalarOrVector, 0, + this)) { + return false; + } + + if (t0->is_scalar()) { + if (!t1->is_scalar()) { + add_error( + expr->source(), + "incorrect types for " + builtin + + ". Parameters must be matched scalars or vectors"); + return false; + } + } else { + if (t1->is_integer_scalar()) { + add_error( + expr->source(), + "incorrect types for " + builtin + + ". Parameters must be matched scalars or vectors"); + return false; + } + const auto* v0 = t0->As(); + const auto* v1 = t1->As(); + if (v0->size() != v1->size()) { + add_error(expr->source(), + "incorrect types for " + builtin + + ". Parameter vector sizes must match"); + return false; + } + } + } + } + + // Result types don't match parameter types. + if (data->intrinsic == ast::Intrinsic::kLength || + data->intrinsic == ast::Intrinsic::kDistance || + data->intrinsic == ast::Intrinsic::kDeterminant) { + if (!IsValidType(program_->TypeOf(expr->func()), expr->source(), + builtin, IntrinsicDataType::kFloatScalar, 0, + this)) { + return false; + } + } + + // Must be a square matrix. + if (data->intrinsic == ast::Intrinsic::kDeterminant) { + const auto* matrix = + program_->TypeOf(expr->params()[0])->As(); + if (matrix->rows() != matrix->columns()) { + add_error(expr->params()[0]->source(), + "incorrect type for " + builtin + + ". Requires a square matrix"); + return false; + } + } + } + + // Last parameter must be a pointer. + if (data->intrinsic == ast::Intrinsic::kFrexp || + data->intrinsic == ast::Intrinsic::kModf) { + auto* last_param = expr->params()[data->param_count - 1]; + if (!program_->TypeOf(last_param)->Is()) { + add_error(last_param->source(), "incorrect type for " + builtin + + ". Requires pointer value"); + return false; + } + } + } } else { - auto symbol = ident->symbol(); if (!function_stack_.has(symbol)) { add_error(expr->source(), "v-0005", "function must be declared before use: '" +