Handle MSL and HLSL builtins more consistently

* Fixes missing namespace for metal builtins
    * Consolidates handling of most builtins
    * Implements ldexp for msl and hlsl
    * Many more tests

Change-Id: I43a4876785d488921421ab64c2999aa036d831a8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/39940
Commit-Queue: Alan Baker <alanbaker@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Alan Baker 2021-02-01 21:30:54 +00:00 committed by Commit Bot service account
parent 7ecf92a53c
commit db67a287b8
6 changed files with 528 additions and 186 deletions

View File

@ -528,57 +528,6 @@ bool GeneratorImpl::EmitBreak(std::ostream& out, ast::BreakStatement*) {
return true; return true;
} }
std::string GeneratorImpl::generate_intrinsic_name(ast::Intrinsic intrinsic) {
if (intrinsic == ast::Intrinsic::kAny) {
return "any";
}
if (intrinsic == ast::Intrinsic::kAll) {
return "all";
}
if (intrinsic == ast::Intrinsic::kCountOneBits) {
return "countbits";
}
if (intrinsic == ast::Intrinsic::kDot) {
return "dot";
}
if (intrinsic == ast::Intrinsic::kDpdy) {
return "ddy";
}
if (intrinsic == ast::Intrinsic::kDpdyFine) {
return "ddy_fine";
}
if (intrinsic == ast::Intrinsic::kDpdyCoarse) {
return "ddy_coarse";
}
if (intrinsic == ast::Intrinsic::kDpdx) {
return "ddx";
}
if (intrinsic == ast::Intrinsic::kDpdxFine) {
return "ddx_fine";
}
if (intrinsic == ast::Intrinsic::kDpdxCoarse) {
return "ddx_coarse";
}
if (intrinsic == ast::Intrinsic::kFwidth ||
intrinsic == ast::Intrinsic::kFwidthFine ||
intrinsic == ast::Intrinsic::kFwidthCoarse) {
return "fwidth";
}
if (intrinsic == ast::Intrinsic::kIsFinite) {
return "isfinite";
}
if (intrinsic == ast::Intrinsic::kIsInf) {
return "isinf";
}
if (intrinsic == ast::Intrinsic::kIsNan) {
return "isnan";
}
if (intrinsic == ast::Intrinsic::kReverseBits) {
return "reversebits";
}
return "";
}
bool GeneratorImpl::EmitCall(std::ostream& pre, bool GeneratorImpl::EmitCall(std::ostream& pre,
std::ostream& out, std::ostream& out,
ast::CallExpression* expr) { ast::CallExpression* expr) {
@ -597,16 +546,13 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
error_ = "is_normal not supported in HLSL backend yet"; error_ = "is_normal not supported in HLSL backend yet";
return false; return false;
} else { } else {
auto name = generate_intrinsic_name(ident->intrinsic());
if (name.empty()) {
if (ast::intrinsic::IsTextureIntrinsic(ident->intrinsic())) { if (ast::intrinsic::IsTextureIntrinsic(ident->intrinsic())) {
return EmitTextureCall(pre, out, expr); return EmitTextureCall(pre, out, expr);
} }
name = generate_builtin_name(expr); auto name = generate_builtin_name(ident);
if (name.empty()) { if (name.empty()) {
return false; return false;
} }
}
make_indent(out); make_indent(out);
out << name << "("; out << name << "(";
@ -935,11 +881,13 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& pre,
return true; return true;
} // namespace hlsl } // namespace hlsl
std::string GeneratorImpl::generate_builtin_name(ast::CallExpression* expr) { std::string GeneratorImpl::generate_builtin_name(
ast::IdentifierExpression* ident) {
std::string out; std::string out;
auto* ident = expr->func()->As<ast::IdentifierExpression>();
switch (ident->intrinsic()) { switch (ident->intrinsic()) {
case ast::Intrinsic::kAcos: case ast::Intrinsic::kAcos:
case ast::Intrinsic::kAny:
case ast::Intrinsic::kAll:
case ast::Intrinsic::kAsin: case ast::Intrinsic::kAsin:
case ast::Intrinsic::kAtan: case ast::Intrinsic::kAtan:
case ast::Intrinsic::kAtan2: case ast::Intrinsic::kAtan2:
@ -949,10 +897,12 @@ std::string GeneratorImpl::generate_builtin_name(ast::CallExpression* expr) {
case ast::Intrinsic::kCross: case ast::Intrinsic::kCross:
case ast::Intrinsic::kDeterminant: case ast::Intrinsic::kDeterminant:
case ast::Intrinsic::kDistance: case ast::Intrinsic::kDistance:
case ast::Intrinsic::kDot:
case ast::Intrinsic::kExp: case ast::Intrinsic::kExp:
case ast::Intrinsic::kExp2: case ast::Intrinsic::kExp2:
case ast::Intrinsic::kFloor: case ast::Intrinsic::kFloor:
case ast::Intrinsic::kFma: case ast::Intrinsic::kFma:
case ast::Intrinsic::kLdexp:
case ast::Intrinsic::kLength: case ast::Intrinsic::kLength:
case ast::Intrinsic::kLog: case ast::Intrinsic::kLog:
case ast::Intrinsic::kLog2: case ast::Intrinsic::kLog2:
@ -975,15 +925,53 @@ std::string GeneratorImpl::generate_builtin_name(ast::CallExpression* expr) {
case ast::Intrinsic::kClamp: case ast::Intrinsic::kClamp:
out = builder_.Symbols().NameFor(ident->symbol()); out = builder_.Symbols().NameFor(ident->symbol());
break; break;
case ast::Intrinsic::kCountOneBits:
out = "countbits";
break;
case ast::Intrinsic::kDpdx:
out = "ddx";
break;
case ast::Intrinsic::kDpdxCoarse:
out = "ddx_coarse";
break;
case ast::Intrinsic::kDpdxFine:
out = "ddx_fine";
break;
case ast::Intrinsic::kDpdy:
out = "ddy";
break;
case ast::Intrinsic::kDpdyCoarse:
out = "ddy_coarse";
break;
case ast::Intrinsic::kDpdyFine:
out = "ddy_fine";
break;
case ast::Intrinsic::kFaceForward: case ast::Intrinsic::kFaceForward:
out = "faceforward"; out = "faceforward";
break; break;
case ast::Intrinsic::kFract: case ast::Intrinsic::kFract:
out = "frac"; out = "frac";
break; break;
case ast::Intrinsic::kFwidth:
case ast::Intrinsic::kFwidthCoarse:
case ast::Intrinsic::kFwidthFine:
out = "fwidth";
break;
case ast::Intrinsic::kInverseSqrt: case ast::Intrinsic::kInverseSqrt:
out = "rsqrt"; out = "rsqrt";
break; break;
case ast::Intrinsic::kIsFinite:
out = "isfinite";
break;
case ast::Intrinsic::kIsInf:
out = "isinf";
break;
case ast::Intrinsic::kIsNan:
out = "isnan";
break;
case ast::Intrinsic::kReverseBits:
out = "reversebits";
break;
case ast::Intrinsic::kSmoothStep: case ast::Intrinsic::kSmoothStep:
out = "smoothstep"; out = "smoothstep";
break; break;

View File

@ -345,14 +345,10 @@ class GeneratorImpl {
/// @returns the index string, or blank if unable to generate /// @returns the index string, or blank if unable to generate
std::string generate_storage_buffer_index_expression(std::ostream& pre, std::string generate_storage_buffer_index_expression(std::ostream& pre,
ast::Expression* expr); ast::Expression* expr);
/// Generates an intrinsic name from the given name
/// @param intrinsic the intrinsic to convert to a name
/// @returns the intrinsic name or blank on error
std::string generate_intrinsic_name(ast::Intrinsic intrinsic);
/// Handles generating a builtin method name /// Handles generating a builtin method name
/// @param expr the expression /// @param expr the expression
/// @returns the name or "" if not valid /// @returns the name or "" if not valid
std::string generate_builtin_name(ast::CallExpression* expr); std::string generate_builtin_name(ast::IdentifierExpression* expr);
/// Converts a builtin to an attribute name /// Converts a builtin to an attribute name
/// @param builtin the builtin to convert /// @param builtin the builtin to convert
/// @returns the string name of the builtin or blank on error /// @returns the string name of the builtin or blank on error

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <sstream>
#include "src/ast/call_expression.h" #include "src/ast/call_expression.h"
#include "src/ast/identifier_expression.h" #include "src/ast/identifier_expression.h"
#include "src/program.h" #include "src/program.h"
@ -27,43 +29,227 @@ namespace {
using HlslGeneratorImplTest_Intrinsic = TestHelper; using HlslGeneratorImplTest_Intrinsic = TestHelper;
enum class ParamType {
kF32,
kU32,
kBool,
};
struct IntrinsicData { struct IntrinsicData {
ast::Intrinsic intrinsic; ast::Intrinsic intrinsic;
ParamType type;
const char* hlsl_name; const char* hlsl_name;
}; };
inline std::ostream& operator<<(std::ostream& out, IntrinsicData data) { inline std::ostream& operator<<(std::ostream& out, IntrinsicData data) {
out << data.hlsl_name; out << data.hlsl_name;
switch (data.type) {
case ParamType::kF32:
out << "f32";
break;
case ParamType::kU32:
out << "u32";
break;
case ParamType::kBool:
out << "bool";
break;
}
out << ">";
return out; return out;
} }
ast::CallExpression* GenerateCall(ast::Intrinsic intrinsic,
ParamType type,
ProgramBuilder* builder) {
std::string name;
std::ostringstream str(name);
str << intrinsic;
switch (intrinsic) {
case ast::Intrinsic::kAcos:
case ast::Intrinsic::kAsin:
case ast::Intrinsic::kAtan:
case ast::Intrinsic::kCeil:
case ast::Intrinsic::kCos:
case ast::Intrinsic::kCosh:
case ast::Intrinsic::kDpdx:
case ast::Intrinsic::kDpdxCoarse:
case ast::Intrinsic::kDpdxFine:
case ast::Intrinsic::kDpdy:
case ast::Intrinsic::kDpdyCoarse:
case ast::Intrinsic::kDpdyFine:
case ast::Intrinsic::kExp:
case ast::Intrinsic::kExp2:
case ast::Intrinsic::kFloor:
case ast::Intrinsic::kFract:
case ast::Intrinsic::kFwidth:
case ast::Intrinsic::kFwidthCoarse:
case ast::Intrinsic::kFwidthFine:
case ast::Intrinsic::kInverseSqrt:
case ast::Intrinsic::kIsFinite:
case ast::Intrinsic::kIsInf:
case ast::Intrinsic::kIsNan:
case ast::Intrinsic::kIsNormal:
case ast::Intrinsic::kLdexp:
case ast::Intrinsic::kLength:
case ast::Intrinsic::kLog:
case ast::Intrinsic::kLog2:
case ast::Intrinsic::kNormalize:
case ast::Intrinsic::kReflect:
case ast::Intrinsic::kRound:
case ast::Intrinsic::kSin:
case ast::Intrinsic::kSinh:
case ast::Intrinsic::kSqrt:
case ast::Intrinsic::kTan:
case ast::Intrinsic::kTanh:
case ast::Intrinsic::kTrunc:
case ast::Intrinsic::kSign:
return builder->Call(str.str(), "f1");
break;
case ast::Intrinsic::kAtan2:
case ast::Intrinsic::kCross:
case ast::Intrinsic::kDot:
case ast::Intrinsic::kDistance:
case ast::Intrinsic::kPow:
case ast::Intrinsic::kStep:
return builder->Call(str.str(), "f1", "f2");
case ast::Intrinsic::kFma:
case ast::Intrinsic::kMix:
case ast::Intrinsic::kFaceForward:
case ast::Intrinsic::kSmoothStep:
return builder->Call(str.str(), "f1", "f2", "f3");
case ast::Intrinsic::kAll:
case ast::Intrinsic::kAny:
return builder->Call(str.str(), "b1");
case ast::Intrinsic::kAbs:
if (type == ParamType::kF32) {
return builder->Call(str.str(), "f1");
} else {
return builder->Call(str.str(), "u1");
}
case ast::Intrinsic::kCountOneBits:
case ast::Intrinsic::kReverseBits:
return builder->Call(str.str(), "u1");
case ast::Intrinsic::kMax:
case ast::Intrinsic::kMin:
if (type == ParamType::kF32) {
return builder->Call(str.str(), "f1", "f2");
} else {
return builder->Call(str.str(), "u1", "u2");
}
case ast::Intrinsic::kClamp:
if (type == ParamType::kF32) {
return builder->Call(str.str(), "f1", "f2", "f3");
} else {
return builder->Call(str.str(), "u1", "u2", "u3");
}
case ast::Intrinsic::kSelect:
return builder->Call(str.str(), "f1", "f2", "b1");
case ast::Intrinsic::kDeterminant:
return builder->Call(str.str(), "m1");
default:
break;
}
return nullptr;
}
using HlslIntrinsicTest = TestParamHelper<IntrinsicData>; using HlslIntrinsicTest = TestParamHelper<IntrinsicData>;
TEST_P(HlslIntrinsicTest, Emit) { TEST_P(HlslIntrinsicTest, Emit) {
auto param = GetParam(); auto param = GetParam();
auto* call = GenerateCall(param.intrinsic, param.type, this);
ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
auto* f1 = Var("f1", ast::StorageClass::kFunction, ty.vec2<float>());
auto* f2 = Var("f2", ast::StorageClass::kFunction, ty.vec2<float>());
auto* f3 = Var("f3", ast::StorageClass::kFunction, ty.vec2<float>());
auto* u1 = Var("u1", ast::StorageClass::kFunction, ty.vec2<unsigned int>());
auto* u2 = Var("u2", ast::StorageClass::kFunction, ty.vec2<unsigned int>());
auto* u3 = Var("u3", ast::StorageClass::kFunction, ty.vec2<unsigned int>());
auto* b1 = Var("b1", ast::StorageClass::kFunction, ty.vec2<bool>());
auto* m1 = Var("m1", ast::StorageClass::kFunction, ty.mat2x2<float>());
td.RegisterVariableForTesting(f1);
td.RegisterVariableForTesting(f2);
td.RegisterVariableForTesting(f3);
td.RegisterVariableForTesting(u1);
td.RegisterVariableForTesting(u2);
td.RegisterVariableForTesting(u3);
td.RegisterVariableForTesting(b1);
td.RegisterVariableForTesting(m1);
ASSERT_TRUE(td.DetermineResultType(call)) << td.error();
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
EXPECT_EQ(gen.generate_intrinsic_name(param.intrinsic), param.hlsl_name); EXPECT_EQ(
gen.generate_builtin_name(call->func()->As<ast::IdentifierExpression>()),
param.hlsl_name);
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
HlslGeneratorImplTest_Intrinsic, HlslGeneratorImplTest_Intrinsic,
HlslIntrinsicTest, HlslIntrinsicTest,
testing::Values(IntrinsicData{ast::Intrinsic::kAny, "any"}, testing::Values(
IntrinsicData{ast::Intrinsic::kAll, "all"}, IntrinsicData{ast::Intrinsic::kAbs, ParamType::kF32, "abs"},
IntrinsicData{ast::Intrinsic::kCountOneBits, "countbits"}, IntrinsicData{ast::Intrinsic::kAbs, ParamType::kU32, "abs"},
IntrinsicData{ast::Intrinsic::kDot, "dot"}, IntrinsicData{ast::Intrinsic::kAcos, ParamType::kF32, "acos"},
IntrinsicData{ast::Intrinsic::kDpdx, "ddx"}, IntrinsicData{ast::Intrinsic::kAll, ParamType::kBool, "all"},
IntrinsicData{ast::Intrinsic::kDpdxCoarse, "ddx_coarse"}, IntrinsicData{ast::Intrinsic::kAny, ParamType::kBool, "any"},
IntrinsicData{ast::Intrinsic::kDpdxFine, "ddx_fine"}, IntrinsicData{ast::Intrinsic::kAsin, ParamType::kF32, "asin"},
IntrinsicData{ast::Intrinsic::kDpdy, "ddy"}, IntrinsicData{ast::Intrinsic::kAtan, ParamType::kF32, "atan"},
IntrinsicData{ast::Intrinsic::kDpdyCoarse, "ddy_coarse"}, IntrinsicData{ast::Intrinsic::kAtan2, ParamType::kF32, "atan2"},
IntrinsicData{ast::Intrinsic::kDpdyFine, "ddy_fine"}, IntrinsicData{ast::Intrinsic::kCeil, ParamType::kF32, "ceil"},
IntrinsicData{ast::Intrinsic::kFwidth, "fwidth"}, IntrinsicData{ast::Intrinsic::kClamp, ParamType::kF32, "clamp"},
IntrinsicData{ast::Intrinsic::kFwidthCoarse, "fwidth"}, IntrinsicData{ast::Intrinsic::kClamp, ParamType::kU32, "clamp"},
IntrinsicData{ast::Intrinsic::kFwidthFine, "fwidth"}, IntrinsicData{ast::Intrinsic::kCos, ParamType::kF32, "cos"},
IntrinsicData{ast::Intrinsic::kIsFinite, "isfinite"}, IntrinsicData{ast::Intrinsic::kCosh, ParamType::kF32, "cosh"},
IntrinsicData{ast::Intrinsic::kIsInf, "isinf"}, IntrinsicData{ast::Intrinsic::kCountOneBits, ParamType::kU32,
IntrinsicData{ast::Intrinsic::kIsNan, "isnan"}, "countbits"},
IntrinsicData{ast::Intrinsic::kReverseBits, IntrinsicData{ast::Intrinsic::kCross, ParamType::kF32, "cross"},
"reversebits"})); IntrinsicData{ast::Intrinsic::kDeterminant, ParamType::kF32,
"determinant"},
IntrinsicData{ast::Intrinsic::kDistance, ParamType::kF32, "distance"},
IntrinsicData{ast::Intrinsic::kDot, ParamType::kF32, "dot"},
IntrinsicData{ast::Intrinsic::kDpdx, ParamType::kF32, "ddx"},
IntrinsicData{ast::Intrinsic::kDpdxCoarse, ParamType::kF32,
"ddx_coarse"},
IntrinsicData{ast::Intrinsic::kDpdxFine, ParamType::kF32, "ddx_fine"},
IntrinsicData{ast::Intrinsic::kDpdy, ParamType::kF32, "ddy"},
IntrinsicData{ast::Intrinsic::kDpdyCoarse, ParamType::kF32,
"ddy_coarse"},
IntrinsicData{ast::Intrinsic::kDpdyFine, ParamType::kF32, "ddy_fine"},
IntrinsicData{ast::Intrinsic::kExp, ParamType::kF32, "exp"},
IntrinsicData{ast::Intrinsic::kExp2, ParamType::kF32, "exp2"},
IntrinsicData{ast::Intrinsic::kFaceForward, ParamType::kF32,
"faceforward"},
IntrinsicData{ast::Intrinsic::kFloor, ParamType::kF32, "floor"},
IntrinsicData{ast::Intrinsic::kFma, ParamType::kF32, "fma"},
IntrinsicData{ast::Intrinsic::kFract, ParamType::kF32, "frac"},
IntrinsicData{ast::Intrinsic::kFwidth, ParamType::kF32, "fwidth"},
IntrinsicData{ast::Intrinsic::kFwidthCoarse, ParamType::kF32, "fwidth"},
IntrinsicData{ast::Intrinsic::kFwidthFine, ParamType::kF32, "fwidth"},
IntrinsicData{ast::Intrinsic::kInverseSqrt, ParamType::kF32, "rsqrt"},
IntrinsicData{ast::Intrinsic::kIsFinite, ParamType::kF32, "isfinite"},
IntrinsicData{ast::Intrinsic::kIsInf, ParamType::kF32, "isinf"},
IntrinsicData{ast::Intrinsic::kIsNan, ParamType::kF32, "isnan"},
IntrinsicData{ast::Intrinsic::kLdexp, ParamType::kF32, "ldexp"},
IntrinsicData{ast::Intrinsic::kLength, ParamType::kF32, "length"},
IntrinsicData{ast::Intrinsic::kLog, ParamType::kF32, "log"},
IntrinsicData{ast::Intrinsic::kLog2, ParamType::kF32, "log2"},
IntrinsicData{ast::Intrinsic::kMax, ParamType::kF32, "max"},
IntrinsicData{ast::Intrinsic::kMax, ParamType::kU32, "max"},
IntrinsicData{ast::Intrinsic::kMin, ParamType::kF32, "min"},
IntrinsicData{ast::Intrinsic::kMin, ParamType::kU32, "min"},
IntrinsicData{ast::Intrinsic::kNormalize, ParamType::kF32, "normalize"},
IntrinsicData{ast::Intrinsic::kPow, ParamType::kF32, "pow"},
IntrinsicData{ast::Intrinsic::kReflect, ParamType::kF32, "reflect"},
IntrinsicData{ast::Intrinsic::kReverseBits, ParamType::kU32,
"reversebits"},
IntrinsicData{ast::Intrinsic::kRound, ParamType::kU32, "round"},
IntrinsicData{ast::Intrinsic::kSign, ParamType::kF32, "sign"},
IntrinsicData{ast::Intrinsic::kSin, ParamType::kF32, "sin"},
IntrinsicData{ast::Intrinsic::kSinh, ParamType::kF32, "sinh"},
IntrinsicData{ast::Intrinsic::kSmoothStep, ParamType::kF32,
"smoothstep"},
IntrinsicData{ast::Intrinsic::kSqrt, ParamType::kF32, "sqrt"},
IntrinsicData{ast::Intrinsic::kStep, ParamType::kF32, "step"},
IntrinsicData{ast::Intrinsic::kTan, ParamType::kF32, "tan"},
IntrinsicData{ast::Intrinsic::kTanh, ParamType::kF32, "tanh"},
IntrinsicData{ast::Intrinsic::kTrunc, ParamType::kF32, "trunc"}));
TEST_F(HlslGeneratorImplTest_Intrinsic, DISABLED_Intrinsic_IsNormal) { TEST_F(HlslGeneratorImplTest_Intrinsic, DISABLED_Intrinsic_IsNormal) {
FAIL(); FAIL();
@ -73,12 +259,6 @@ TEST_F(HlslGeneratorImplTest_Intrinsic, DISABLED_Intrinsic_Select) {
FAIL(); FAIL();
} }
TEST_F(HlslGeneratorImplTest_Intrinsic, Intrinsic_Bad_Name) {
GeneratorImpl& gen = Build();
EXPECT_EQ(gen.generate_intrinsic_name(ast::Intrinsic::kNone), "");
}
TEST_F(HlslGeneratorImplTest_Intrinsic, Intrinsic_Call) { TEST_F(HlslGeneratorImplTest_Intrinsic, Intrinsic_Call) {
auto* call = Call("dot", "param1", "param2"); auto* call = Call("dot", "param1", "param2");

View File

@ -431,55 +431,6 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) {
return name; return name;
} }
std::string GeneratorImpl::generate_intrinsic_name(ast::Intrinsic intrinsic) {
if (intrinsic == ast::Intrinsic::kAny) {
return "any";
}
if (intrinsic == ast::Intrinsic::kAll) {
return "all";
}
if (intrinsic == ast::Intrinsic::kCountOneBits) {
return "popcount";
}
if (intrinsic == ast::Intrinsic::kDot) {
return "dot";
}
if (intrinsic == ast::Intrinsic::kDpdy ||
intrinsic == ast::Intrinsic::kDpdyFine ||
intrinsic == ast::Intrinsic::kDpdyCoarse) {
return "dfdy";
}
if (intrinsic == ast::Intrinsic::kDpdx ||
intrinsic == ast::Intrinsic::kDpdxFine ||
intrinsic == ast::Intrinsic::kDpdxCoarse) {
return "dfdx";
}
if (intrinsic == ast::Intrinsic::kFwidth ||
intrinsic == ast::Intrinsic::kFwidthFine ||
intrinsic == ast::Intrinsic::kFwidthCoarse) {
return "fwidth";
}
if (intrinsic == ast::Intrinsic::kIsFinite) {
return "isfinite";
}
if (intrinsic == ast::Intrinsic::kIsInf) {
return "isinf";
}
if (intrinsic == ast::Intrinsic::kIsNan) {
return "isnan";
}
if (intrinsic == ast::Intrinsic::kIsNormal) {
return "isnormal";
}
if (intrinsic == ast::Intrinsic::kReverseBits) {
return "reverse_bits";
}
if (intrinsic == ast::Intrinsic::kSelect) {
return "select";
}
return "";
}
bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
auto* ident = expr->func()->As<ast::IdentifierExpression>(); auto* ident = expr->func()->As<ast::IdentifierExpression>();
@ -489,16 +440,13 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
} }
if (ident->IsIntrinsic()) { if (ident->IsIntrinsic()) {
auto name = generate_intrinsic_name(ident->intrinsic());
if (name.empty()) {
if (ast::intrinsic::IsTextureIntrinsic(ident->intrinsic())) { if (ast::intrinsic::IsTextureIntrinsic(ident->intrinsic())) {
return EmitTextureCall(expr); return EmitTextureCall(expr);
} }
name = generate_builtin_name(ident); auto name = generate_builtin_name(ident);
if (name.empty()) { if (name.empty()) {
return false; return false;
} }
}
make_indent(); make_indent();
out_ << name << "("; out_ << name << "(";
@ -819,6 +767,8 @@ std::string GeneratorImpl::generate_builtin_name(
std::string out = "metal::"; std::string out = "metal::";
switch (ident->intrinsic()) { switch (ident->intrinsic()) {
case ast::Intrinsic::kAcos: case ast::Intrinsic::kAcos:
case ast::Intrinsic::kAll:
case ast::Intrinsic::kAny:
case ast::Intrinsic::kAsin: case ast::Intrinsic::kAsin:
case ast::Intrinsic::kAtan: case ast::Intrinsic::kAtan:
case ast::Intrinsic::kAtan2: case ast::Intrinsic::kAtan2:
@ -828,12 +778,14 @@ std::string GeneratorImpl::generate_builtin_name(
case ast::Intrinsic::kCross: case ast::Intrinsic::kCross:
case ast::Intrinsic::kDeterminant: case ast::Intrinsic::kDeterminant:
case ast::Intrinsic::kDistance: case ast::Intrinsic::kDistance:
case ast::Intrinsic::kDot:
case ast::Intrinsic::kExp: case ast::Intrinsic::kExp:
case ast::Intrinsic::kExp2: case ast::Intrinsic::kExp2:
case ast::Intrinsic::kFloor: case ast::Intrinsic::kFloor:
case ast::Intrinsic::kFma: case ast::Intrinsic::kFma:
case ast::Intrinsic::kFract: case ast::Intrinsic::kFract:
case ast::Intrinsic::kLength: case ast::Intrinsic::kLength:
case ast::Intrinsic::kLdexp:
case ast::Intrinsic::kLog: case ast::Intrinsic::kLog:
case ast::Intrinsic::kLog2: case ast::Intrinsic::kLog2:
case ast::Intrinsic::kMix: case ast::Intrinsic::kMix:
@ -841,6 +793,7 @@ std::string GeneratorImpl::generate_builtin_name(
case ast::Intrinsic::kPow: case ast::Intrinsic::kPow:
case ast::Intrinsic::kReflect: case ast::Intrinsic::kReflect:
case ast::Intrinsic::kRound: case ast::Intrinsic::kRound:
case ast::Intrinsic::kSelect:
case ast::Intrinsic::kSin: case ast::Intrinsic::kSin:
case ast::Intrinsic::kSinh: case ast::Intrinsic::kSinh:
case ast::Intrinsic::kSqrt: case ast::Intrinsic::kSqrt:
@ -853,29 +806,62 @@ std::string GeneratorImpl::generate_builtin_name(
out += program_->Symbols().NameFor(ident->symbol()); out += program_->Symbols().NameFor(ident->symbol());
break; break;
case ast::Intrinsic::kAbs: case ast::Intrinsic::kAbs:
if (type->Is<type::F32>()) { if (type->is_float_scalar_or_vector()) {
out += "fabs"; out += "fabs";
} else if (type->Is<type::U32>() || type->Is<type::I32>()) { } else {
out += "abs"; out += "abs";
} }
break; break;
case ast::Intrinsic::kCountOneBits:
out += "popcount";
break;
case ast::Intrinsic::kDpdx:
case ast::Intrinsic::kDpdxCoarse:
case ast::Intrinsic::kDpdxFine:
out += "dfdx";
break;
case ast::Intrinsic::kDpdy:
case ast::Intrinsic::kDpdyCoarse:
case ast::Intrinsic::kDpdyFine:
out += "dfdy";
break;
case ast::Intrinsic::kFwidth:
case ast::Intrinsic::kFwidthCoarse:
case ast::Intrinsic::kFwidthFine:
out += "fwidth";
break;
case ast::Intrinsic::kIsFinite:
out += "isfinite";
break;
case ast::Intrinsic::kIsInf:
out += "isinf";
break;
case ast::Intrinsic::kIsNan:
out += "isnan";
break;
case ast::Intrinsic::kIsNormal:
out += "isnormal";
break;
case ast::Intrinsic::kMax: case ast::Intrinsic::kMax:
if (type->Is<type::F32>()) { if (type->is_float_scalar_or_vector()) {
out += "fmax"; out += "fmax";
} else if (type->Is<type::U32>() || type->Is<type::I32>()) { } else {
out += "max"; out += "max";
} }
break; break;
case ast::Intrinsic::kMin: case ast::Intrinsic::kMin:
if (type->Is<type::F32>()) { if (type->is_float_scalar_or_vector()) {
out += "fmin"; out += "fmin";
} else if (type->Is<type::U32>() || type->Is<type::I32>()) { } else {
out += "min"; out += "min";
} }
break; break;
case ast::Intrinsic::kFaceForward: case ast::Intrinsic::kFaceForward:
out += "faceforward"; out += "faceforward";
break; break;
case ast::Intrinsic::kReverseBits:
out += "reverse_bits";
break;
case ast::Intrinsic::kSmoothStep: case ast::Intrinsic::kSmoothStep:
out += "smoothstep"; out += "smoothstep";
break; break;

View File

@ -248,10 +248,6 @@ class GeneratorImpl : public TextGenerator {
/// @param prefix the prefix of the name to generate /// @param prefix the prefix of the name to generate
/// @returns the name /// @returns the name
std::string generate_name(const std::string& prefix); std::string generate_name(const std::string& prefix);
/// Generates an intrinsic name from the given name
/// @param intrinsic the intrinsic to convert to an method name
/// @returns the intrinsic name or blank on error
std::string generate_intrinsic_name(ast::Intrinsic intrinsic);
/// Handles generating a builtin name /// Handles generating a builtin name
/// @param ident the identifier to build the name from /// @param ident the identifier to build the name from
/// @returns the name or "" if not valid /// @returns the name or "" if not valid

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/ast/call_expression.h" #include "src/ast/call_expression.h"
#include "src/ast/identifier_expression.h" #include "src/ast/identifier_expression.h"
@ -29,50 +31,244 @@ namespace {
using MslGeneratorImplTest = TestHelper; using MslGeneratorImplTest = TestHelper;
enum class ParamType {
kF32,
kU32,
kBool,
};
struct IntrinsicData { struct IntrinsicData {
ast::Intrinsic intrinsic; ast::Intrinsic intrinsic;
ParamType type;
const char* msl_name; const char* msl_name;
}; };
inline std::ostream& operator<<(std::ostream& out, IntrinsicData data) { inline std::ostream& operator<<(std::ostream& out, IntrinsicData data) {
out << data.msl_name; out << data.msl_name << "<";
switch (data.type) {
case ParamType::kF32:
out << "f32";
break;
case ParamType::kU32:
out << "u32";
break;
case ParamType::kBool:
out << "bool";
break;
}
out << ">";
return out; return out;
} }
ast::CallExpression* GenerateCall(ast::Intrinsic intrinsic,
ParamType type,
ProgramBuilder* builder) {
std::string name;
std::ostringstream str(name);
str << intrinsic;
switch (intrinsic) {
case ast::Intrinsic::kAcos:
case ast::Intrinsic::kAsin:
case ast::Intrinsic::kAtan:
case ast::Intrinsic::kCeil:
case ast::Intrinsic::kCos:
case ast::Intrinsic::kCosh:
case ast::Intrinsic::kDpdx:
case ast::Intrinsic::kDpdxCoarse:
case ast::Intrinsic::kDpdxFine:
case ast::Intrinsic::kDpdy:
case ast::Intrinsic::kDpdyCoarse:
case ast::Intrinsic::kDpdyFine:
case ast::Intrinsic::kExp:
case ast::Intrinsic::kExp2:
case ast::Intrinsic::kFloor:
case ast::Intrinsic::kFract:
case ast::Intrinsic::kFwidth:
case ast::Intrinsic::kFwidthCoarse:
case ast::Intrinsic::kFwidthFine:
case ast::Intrinsic::kInverseSqrt:
case ast::Intrinsic::kIsFinite:
case ast::Intrinsic::kIsInf:
case ast::Intrinsic::kIsNan:
case ast::Intrinsic::kIsNormal:
case ast::Intrinsic::kLdexp:
case ast::Intrinsic::kLength:
case ast::Intrinsic::kLog:
case ast::Intrinsic::kLog2:
case ast::Intrinsic::kNormalize:
case ast::Intrinsic::kReflect:
case ast::Intrinsic::kRound:
case ast::Intrinsic::kSin:
case ast::Intrinsic::kSinh:
case ast::Intrinsic::kSqrt:
case ast::Intrinsic::kTan:
case ast::Intrinsic::kTanh:
case ast::Intrinsic::kTrunc:
case ast::Intrinsic::kSign:
return builder->Call(str.str(), "f1");
break;
case ast::Intrinsic::kAtan2:
case ast::Intrinsic::kCross:
case ast::Intrinsic::kDot:
case ast::Intrinsic::kDistance:
case ast::Intrinsic::kPow:
case ast::Intrinsic::kStep:
return builder->Call(str.str(), "f1", "f2");
case ast::Intrinsic::kFma:
case ast::Intrinsic::kMix:
case ast::Intrinsic::kFaceForward:
case ast::Intrinsic::kSmoothStep:
return builder->Call(str.str(), "f1", "f2", "f3");
case ast::Intrinsic::kAll:
case ast::Intrinsic::kAny:
return builder->Call(str.str(), "b1");
case ast::Intrinsic::kAbs:
if (type == ParamType::kF32) {
return builder->Call(str.str(), "f1");
} else {
return builder->Call(str.str(), "u1");
}
case ast::Intrinsic::kCountOneBits:
case ast::Intrinsic::kReverseBits:
return builder->Call(str.str(), "u1");
case ast::Intrinsic::kMax:
case ast::Intrinsic::kMin:
if (type == ParamType::kF32) {
return builder->Call(str.str(), "f1", "f2");
} else {
return builder->Call(str.str(), "u1", "u2");
}
case ast::Intrinsic::kClamp:
if (type == ParamType::kF32) {
return builder->Call(str.str(), "f1", "f2", "f3");
} else {
return builder->Call(str.str(), "u1", "u2", "u3");
}
case ast::Intrinsic::kSelect:
return builder->Call(str.str(), "f1", "f2", "b1");
case ast::Intrinsic::kDeterminant:
return builder->Call(str.str(), "m1");
default:
break;
}
return nullptr;
}
using MslIntrinsicTest = TestParamHelper<IntrinsicData>; using MslIntrinsicTest = TestParamHelper<IntrinsicData>;
TEST_P(MslIntrinsicTest, Emit) { TEST_P(MslIntrinsicTest, Emit) {
auto param = GetParam(); auto param = GetParam();
auto* call = GenerateCall(param.intrinsic, param.type, this);
ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
auto* f1 = Var("f1", ast::StorageClass::kFunction, ty.vec2<float>());
auto* f2 = Var("f2", ast::StorageClass::kFunction, ty.vec2<float>());
auto* f3 = Var("f3", ast::StorageClass::kFunction, ty.vec2<float>());
auto* u1 = Var("u1", ast::StorageClass::kFunction, ty.vec2<unsigned int>());
auto* u2 = Var("u2", ast::StorageClass::kFunction, ty.vec2<unsigned int>());
auto* u3 = Var("u3", ast::StorageClass::kFunction, ty.vec2<unsigned int>());
auto* b1 = Var("b1", ast::StorageClass::kFunction, ty.vec2<bool>());
auto* m1 = Var("m1", ast::StorageClass::kFunction, ty.mat2x2<float>());
td.RegisterVariableForTesting(f1);
td.RegisterVariableForTesting(f2);
td.RegisterVariableForTesting(f3);
td.RegisterVariableForTesting(u1);
td.RegisterVariableForTesting(u2);
td.RegisterVariableForTesting(u3);
td.RegisterVariableForTesting(b1);
td.RegisterVariableForTesting(m1);
ASSERT_TRUE(td.DetermineResultType(call)) << td.error();
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
EXPECT_EQ(gen.generate_intrinsic_name(param.intrinsic), param.msl_name); EXPECT_EQ(
gen.generate_builtin_name(call->func()->As<ast::IdentifierExpression>()),
param.msl_name);
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
MslGeneratorImplTest, MslGeneratorImplTest,
MslIntrinsicTest, MslIntrinsicTest,
testing::Values(IntrinsicData{ast::Intrinsic::kAny, "any"}, testing::Values(
IntrinsicData{ast::Intrinsic::kAll, "all"}, IntrinsicData{ast::Intrinsic::kAbs, ParamType::kF32, "metal::fabs"},
IntrinsicData{ast::Intrinsic::kCountOneBits, "popcount"}, IntrinsicData{ast::Intrinsic::kAbs, ParamType::kU32, "metal::abs"},
IntrinsicData{ast::Intrinsic::kDot, "dot"}, IntrinsicData{ast::Intrinsic::kAcos, ParamType::kF32, "metal::acos"},
IntrinsicData{ast::Intrinsic::kDpdx, "dfdx"}, IntrinsicData{ast::Intrinsic::kAll, ParamType::kBool, "metal::all"},
IntrinsicData{ast::Intrinsic::kDpdxCoarse, "dfdx"}, IntrinsicData{ast::Intrinsic::kAny, ParamType::kBool, "metal::any"},
IntrinsicData{ast::Intrinsic::kDpdxFine, "dfdx"}, IntrinsicData{ast::Intrinsic::kAsin, ParamType::kF32, "metal::asin"},
IntrinsicData{ast::Intrinsic::kDpdy, "dfdy"}, IntrinsicData{ast::Intrinsic::kAtan, ParamType::kF32, "metal::atan"},
IntrinsicData{ast::Intrinsic::kDpdyCoarse, "dfdy"}, IntrinsicData{ast::Intrinsic::kAtan2, ParamType::kF32, "metal::atan2"},
IntrinsicData{ast::Intrinsic::kDpdyFine, "dfdy"}, IntrinsicData{ast::Intrinsic::kCeil, ParamType::kF32, "metal::ceil"},
IntrinsicData{ast::Intrinsic::kFwidth, "fwidth"}, IntrinsicData{ast::Intrinsic::kClamp, ParamType::kF32, "metal::clamp"},
IntrinsicData{ast::Intrinsic::kFwidthCoarse, "fwidth"}, IntrinsicData{ast::Intrinsic::kClamp, ParamType::kU32, "metal::clamp"},
IntrinsicData{ast::Intrinsic::kFwidthFine, "fwidth"}, IntrinsicData{ast::Intrinsic::kCos, ParamType::kF32, "metal::cos"},
IntrinsicData{ast::Intrinsic::kIsFinite, "isfinite"}, IntrinsicData{ast::Intrinsic::kCosh, ParamType::kF32, "metal::cosh"},
IntrinsicData{ast::Intrinsic::kIsInf, "isinf"}, IntrinsicData{ast::Intrinsic::kCountOneBits, ParamType::kU32,
IntrinsicData{ast::Intrinsic::kIsNan, "isnan"}, "metal::popcount"},
IntrinsicData{ast::Intrinsic::kIsNormal, "isnormal"}, IntrinsicData{ast::Intrinsic::kCross, ParamType::kF32, "metal::cross"},
IntrinsicData{ast::Intrinsic::kReverseBits, "reverse_bits"}, IntrinsicData{ast::Intrinsic::kDeterminant, ParamType::kF32,
IntrinsicData{ast::Intrinsic::kSelect, "select"})); "metal::determinant"},
IntrinsicData{ast::Intrinsic::kDistance, ParamType::kF32,
TEST_F(MslGeneratorImplTest, Intrinsic_Bad_Name) { "metal::distance"},
GeneratorImpl& gen = Build(); IntrinsicData{ast::Intrinsic::kDot, ParamType::kF32, "metal::dot"},
IntrinsicData{ast::Intrinsic::kDpdx, ParamType::kF32, "metal::dfdx"},
EXPECT_EQ(gen.generate_intrinsic_name(ast::Intrinsic::kNone), ""); IntrinsicData{ast::Intrinsic::kDpdxCoarse, ParamType::kF32,
} "metal::dfdx"},
IntrinsicData{ast::Intrinsic::kDpdxFine, ParamType::kF32,
"metal::dfdx"},
IntrinsicData{ast::Intrinsic::kDpdy, ParamType::kF32, "metal::dfdy"},
IntrinsicData{ast::Intrinsic::kDpdyCoarse, ParamType::kF32,
"metal::dfdy"},
IntrinsicData{ast::Intrinsic::kDpdyFine, ParamType::kF32,
"metal::dfdy"},
IntrinsicData{ast::Intrinsic::kExp, ParamType::kF32, "metal::exp"},
IntrinsicData{ast::Intrinsic::kExp2, ParamType::kF32, "metal::exp2"},
IntrinsicData{ast::Intrinsic::kFaceForward, ParamType::kF32,
"metal::faceforward"},
IntrinsicData{ast::Intrinsic::kFloor, ParamType::kF32, "metal::floor"},
IntrinsicData{ast::Intrinsic::kFma, ParamType::kF32, "metal::fma"},
IntrinsicData{ast::Intrinsic::kFract, ParamType::kF32, "metal::fract"},
IntrinsicData{ast::Intrinsic::kFwidth, ParamType::kF32,
"metal::fwidth"},
IntrinsicData{ast::Intrinsic::kFwidthCoarse, ParamType::kF32,
"metal::fwidth"},
IntrinsicData{ast::Intrinsic::kFwidthFine, ParamType::kF32,
"metal::fwidth"},
IntrinsicData{ast::Intrinsic::kInverseSqrt, ParamType::kF32,
"metal::rsqrt"},
IntrinsicData{ast::Intrinsic::kIsFinite, ParamType::kF32,
"metal::isfinite"},
IntrinsicData{ast::Intrinsic::kIsInf, ParamType::kF32, "metal::isinf"},
IntrinsicData{ast::Intrinsic::kIsNan, ParamType::kF32, "metal::isnan"},
IntrinsicData{ast::Intrinsic::kIsNormal, ParamType::kF32,
"metal::isnormal"},
IntrinsicData{ast::Intrinsic::kLdexp, ParamType::kF32, "metal::ldexp"},
IntrinsicData{ast::Intrinsic::kLength, ParamType::kF32,
"metal::length"},
IntrinsicData{ast::Intrinsic::kLog, ParamType::kF32, "metal::log"},
IntrinsicData{ast::Intrinsic::kLog2, ParamType::kF32, "metal::log2"},
IntrinsicData{ast::Intrinsic::kMax, ParamType::kF32, "metal::fmax"},
IntrinsicData{ast::Intrinsic::kMax, ParamType::kU32, "metal::max"},
IntrinsicData{ast::Intrinsic::kMin, ParamType::kF32, "metal::fmin"},
IntrinsicData{ast::Intrinsic::kMin, ParamType::kU32, "metal::min"},
IntrinsicData{ast::Intrinsic::kNormalize, ParamType::kF32,
"metal::normalize"},
IntrinsicData{ast::Intrinsic::kPow, ParamType::kF32, "metal::pow"},
IntrinsicData{ast::Intrinsic::kReflect, ParamType::kF32,
"metal::reflect"},
IntrinsicData{ast::Intrinsic::kReverseBits, ParamType::kU32,
"metal::reverse_bits"},
IntrinsicData{ast::Intrinsic::kRound, ParamType::kU32, "metal::round"},
IntrinsicData{ast::Intrinsic::kSelect, ParamType::kF32,
"metal::select"},
IntrinsicData{ast::Intrinsic::kSign, ParamType::kF32, "metal::sign"},
IntrinsicData{ast::Intrinsic::kSin, ParamType::kF32, "metal::sin"},
IntrinsicData{ast::Intrinsic::kSinh, ParamType::kF32, "metal::sinh"},
IntrinsicData{ast::Intrinsic::kSmoothStep, ParamType::kF32,
"metal::smoothstep"},
IntrinsicData{ast::Intrinsic::kSqrt, ParamType::kF32, "metal::sqrt"},
IntrinsicData{ast::Intrinsic::kStep, ParamType::kF32, "metal::step"},
IntrinsicData{ast::Intrinsic::kTan, ParamType::kF32, "metal::tan"},
IntrinsicData{ast::Intrinsic::kTanh, ParamType::kF32, "metal::tanh"},
IntrinsicData{ast::Intrinsic::kTrunc, ParamType::kF32,
"metal::trunc"}));
TEST_F(MslGeneratorImplTest, Intrinsic_Call) { TEST_F(MslGeneratorImplTest, Intrinsic_Call) {
auto* call = Call("dot", "param1", "param2"); auto* call = Call("dot", "param1", "param2");
@ -89,7 +285,7 @@ TEST_F(MslGeneratorImplTest, Intrinsic_Call) {
gen.increment_indent(); gen.increment_indent();
ASSERT_TRUE(gen.EmitExpression(call)) << gen.error(); ASSERT_TRUE(gen.EmitExpression(call)) << gen.error();
EXPECT_EQ(gen.result(), " dot(param1, param2)"); EXPECT_EQ(gen.result(), " metal::dot(param1, param2)");
} }
} // namespace } // namespace