Add support for GLSL cross.

This CL adds support for the GLSL cross method.

Change-Id: Ib2e83a2ef2e580c6ca257851a76f3f66fa377d6f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22842
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-06-08 23:48:15 +00:00 committed by dan sinclair
parent 92bb55777c
commit ee39225c0b
2 changed files with 254 additions and 53 deletions

View File

@ -628,68 +628,77 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
// * asinh, acosh, atanh
// * exp, exp2
// * log, log2
enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector };
enum class GlslDataType {
kFloatScalarOrVector,
kIntScalarOrVector,
kFloatVector
};
struct GlslData {
const char* name;
uint8_t param_count;
uint32_t op_id;
GlslDataType type;
uint8_t vector_count;
};
constexpr const GlslData kGlslData[] = {
{"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector},
{"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector},
{"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector},
{"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector},
{"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector},
{"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector},
{"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector},
{"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector},
{"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector},
{"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector},
{"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector},
{"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector},
{"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector},
{"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector},
{"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector},
{"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector, 0},
{"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector, 0},
{"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector, 0},
{"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector, 0},
{"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector, 0},
{"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector, 0},
{"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector, 0},
{"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector, 0},
{"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector, 0},
{"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector, 0},
{"cross", 2, GLSLstd450Cross, GlslDataType::kFloatVector, 3},
{"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector, 0},
{"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector, 0},
{"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector, 0},
{"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector, 0},
{"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector, 0},
{"faceforward", 3, GLSLstd450FaceForward,
GlslDataType::kFloatScalarOrVector},
{"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector},
{"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector},
{"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector},
{"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector},
{"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector},
{"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector},
{"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector},
{"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector},
GlslDataType::kFloatScalarOrVector, 0},
{"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector, 0},
{"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector, 0},
{"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector, 0},
{"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector, 0},
{"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector, 0},
{"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector, 0},
{"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector, 0},
{"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector, 0},
{"inversesqrt", 1, GLSLstd450InverseSqrt,
GlslDataType::kFloatScalarOrVector},
{"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector},
{"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector},
{"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector},
{"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector},
{"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector},
{"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector},
{"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector},
{"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector},
{"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector},
{"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector},
{"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector},
{"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector},
{"sabs", 1, GLSLstd450SAbs, GlslDataType::kIntScalarOrVector},
{"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector},
{"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector},
{"smax", 2, GLSLstd450SMax, GlslDataType::kIntScalarOrVector},
{"smin", 2, GLSLstd450SMin, GlslDataType::kIntScalarOrVector},
{"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector},
{"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector},
{"ssign", 1, GLSLstd450SSign, GlslDataType::kIntScalarOrVector},
{"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector},
{"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector},
{"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector},
{"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector},
{"umax", 2, GLSLstd450UMax, GlslDataType::kIntScalarOrVector},
{"umin", 2, GLSLstd450UMin, GlslDataType::kIntScalarOrVector},
GlslDataType::kFloatScalarOrVector, 0},
{"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector, 0},
{"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector, 0},
{"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector, 0},
{"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector, 0},
{"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector, 0},
{"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector, 0},
{"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector,
0},
{"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector, 0},
{"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector, 0},
{"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector, 0},
{"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector, 0},
{"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector,
0},
{"sabs", 1, GLSLstd450SAbs, GlslDataType::kIntScalarOrVector, 0},
{"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector, 0},
{"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector, 0},
{"smax", 2, GLSLstd450SMax, GlslDataType::kIntScalarOrVector, 0},
{"smin", 2, GLSLstd450SMin, GlslDataType::kIntScalarOrVector, 0},
{"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector,
0},
{"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector, 0},
{"ssign", 1, GLSLstd450SSign, GlslDataType::kIntScalarOrVector, 0},
{"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector, 0},
{"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector, 0},
{"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector, 0},
{"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector, 0},
{"umax", 2, GLSLstd450UMax, GlslDataType::kIntScalarOrVector, 0},
{"umin", 2, GLSLstd450UMin, GlslDataType::kIntScalarOrVector, 0},
};
constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData);
@ -742,6 +751,20 @@ ast::type::Type* TypeDeterminer::GetImportData(
return nullptr;
}
break;
case GlslDataType::kFloatVector:
if (!result_types.back()->is_float_vector()) {
set_error(source, "incorrect type for " + name + ". " +
"Requires float vector values");
return nullptr;
}
if (data->vector_count > 0 &&
result_types.back()->AsVector()->size() != data->vector_count) {
set_error(source,
"incorrect vector size for " + name + ". " + "Requires " +
std::to_string(data->vector_count) + " elements");
return nullptr;
}
break;
}
}

View File

@ -2348,8 +2348,186 @@ TEST_F(TypeDeterminerTest, ImportData_Distance_Error_TooManyParams) {
"incorrect number of parameters for distance. Expected 2 got 3");
}
using ImportData_ThreeParamTest = TypeDeterminerTestWithParam<GLSLData>;
TEST_F(TypeDeterminerTest, ImportData_Cross) {
ast::type::F32Type f32;
ast::type::VectorType vec(&f32, 3);
ast::ExpressionList vals_1;
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
ast::ExpressionList vals_2;
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_1)));
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_2)));
ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
uint32_t id = 0;
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
ASSERT_NE(type, nullptr);
EXPECT_TRUE(type->is_float_vector());
EXPECT_EQ(type->AsVector()->size(), 3u);
EXPECT_EQ(id, GLSLstd450Cross);
}
TEST_F(TypeDeterminerTest, ImportData_Cross_Error_Scalar) {
ast::type::F32Type f32;
ast::ExpressionList params;
params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
uint32_t id = 0;
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
"incorrect type for cross. Requires float vector values");
}
TEST_F(TypeDeterminerTest, ImportData_Cross_Error_IntType) {
ast::type::I32Type i32;
ast::type::VectorType vec(&i32, 3);
ast::ExpressionList vals_1;
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 1)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 1)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 3)));
ast::ExpressionList vals_2;
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 1)));
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 1)));
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 3)));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_1)));
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_2)));
ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
uint32_t id = 0;
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
"incorrect type for cross. Requires float vector values");
}
TEST_F(TypeDeterminerTest, ImportData_Cross_Error_MissingParams) {
ast::type::F32Type f32;
ast::type::VectorType vec(&f32, 3);
ast::ExpressionList params;
ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
uint32_t id = 0;
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
"incorrect number of parameters for cross. Expected 2 got 0");
}
TEST_F(TypeDeterminerTest, ImportData_Cross_Error_TooFewParams) {
ast::type::F32Type f32;
ast::type::VectorType vec(&f32, 3);
ast::ExpressionList vals_1;
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_1)));
ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
uint32_t id = 0;
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
"incorrect number of parameters for cross. Expected 2 got 1");
}
TEST_F(TypeDeterminerTest, ImportData_Cross_Error_TooManyParams) {
ast::type::F32Type f32;
ast::type::VectorType vec(&f32, 3);
ast::ExpressionList vals_1;
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
ast::ExpressionList vals_2;
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
ast::ExpressionList vals_3;
vals_3.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_3.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
vals_3.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_1)));
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_2)));
params.push_back(std::make_unique<ast::TypeConstructorExpression>(
&vec, std::move(vals_3)));
ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
uint32_t id = 0;
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
"incorrect number of parameters for cross. Expected 2 got 3");
}
using ImportData_ThreeParamTest = TypeDeterminerTestWithParam<GLSLData>;
TEST_P(ImportData_ThreeParamTest, Scalar) {
auto param = GetParam();