Support for generating DNA on explicit class specializations

This commit is contained in:
Jack Andersen 2020-04-08 19:19:07 -10:00
parent d0eef5eab7
commit 6f55ae4d26
1 changed files with 48 additions and 14 deletions

View File

@ -222,7 +222,7 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
.append(", s)"); .append(", s)");
} }
static void RecurseNestedTypeName(const clang::DeclContext* decl, std::string& templateStmt, std::string& qualType) { void RecurseNestedTypeName(const clang::DeclContext* decl, std::string& templateStmt, std::string& qualType) {
if (!decl) if (!decl)
return; return;
RecurseNestedTypeName(decl->getParent(), templateStmt, qualType); RecurseNestedTypeName(decl->getParent(), templateStmt, qualType);
@ -230,7 +230,20 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
if (!qualType.empty()) if (!qualType.empty())
qualType += "::"; qualType += "::";
qualType += rec->getName(); qualType += rec->getName();
if (const clang::ClassTemplateDecl* ct = rec->getDescribedClassTemplate()) { const clang::ClassTemplateSpecializationDecl* cts =
clang::dyn_cast_or_null<clang::ClassTemplateSpecializationDecl>(rec);
if (cts && cts->isExplicitSpecialization()) {
qualType += '<';
bool needsComma = false;
for (auto& arg : cts->getTemplateArgs().asArray()) {
if (needsComma)
qualType += ", ";
llvm::raw_string_ostream OS(qualType);
arg.print(context.getPrintingPolicy(), OS);
needsComma = true;
}
qualType += '>';
} else if (const clang::ClassTemplateDecl* ct = rec->getDescribedClassTemplate()) {
templateStmt += "template <"; templateStmt += "template <";
qualType += '<'; qualType += '<';
bool needsComma = false; bool needsComma = false;
@ -264,13 +277,13 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
} }
} }
static void GetNestedTypeName(const clang::DeclContext* decl, std::string& templateStmt, std::string& qualType) { void GetNestedTypeName(const clang::DeclContext* decl, std::string& templateStmt, std::string& qualType) {
templateStmt.clear(); templateStmt.clear();
qualType.clear(); qualType.clear();
RecurseNestedTypeName(decl, templateStmt, qualType); RecurseNestedTypeName(decl, templateStmt, qualType);
} }
static void RecurseNestedTypeSpecializations(const clang::DeclContext* decl, void RecurseNestedTypeSpecializations(const clang::DeclContext* decl,
std::vector<std::pair<std::string, int>>& specializations) { std::vector<std::pair<std::string, int>>& specializations) {
if (!decl) { if (!decl) {
specializations.emplace_back(); specializations.emplace_back();
@ -281,7 +294,28 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
RecurseNestedTypeSpecializations(decl->getParent(), parentSpecializations); RecurseNestedTypeSpecializations(decl->getParent(), parentSpecializations);
bool foundSpecializations = false; bool foundSpecializations = false;
if (const clang::CXXRecordDecl* rec = clang::dyn_cast_or_null<clang::CXXRecordDecl>(decl)) { if (const clang::CXXRecordDecl* rec = clang::dyn_cast_or_null<clang::CXXRecordDecl>(decl)) {
if (const clang::ClassTemplateDecl* ct = rec->getDescribedClassTemplate()) { const clang::ClassTemplateSpecializationDecl* cts =
clang::dyn_cast_or_null<clang::ClassTemplateSpecializationDecl>(rec);
if (cts && cts->isExplicitSpecialization()) {
for (const auto& parent : parentSpecializations) {
if (parent.first.empty()) {
specializations.emplace_back(std::string(rec->getName().str()).append(1, '<'), 0);
} else {
auto specialization = std::string(parent.first).append("::").append(rec->getName().str()).append(1, '<');
specializations.emplace_back(std::move(specialization), parent.second);
}
bool needsComma = false;
for (auto& arg : cts->getTemplateArgs().asArray()) {
if (needsComma)
specializations.back().first += ", ";
llvm::raw_string_ostream OS(specializations.back().first);
arg.print(context.getPrintingPolicy(), OS);
needsComma = true;
}
specializations.back().first += '>';
}
foundSpecializations = true;
} else if (const clang::ClassTemplateDecl* ct = rec->getDescribedClassTemplate()) {
int numParms = 0; int numParms = 0;
for (const clang::NamedDecl* parm : *ct->getTemplateParameters()) for (const clang::NamedDecl* parm : *ct->getTemplateParameters())
if (clang::dyn_cast_or_null<clang::TemplateTypeParmDecl>(parm) || if (clang::dyn_cast_or_null<clang::TemplateTypeParmDecl>(parm) ||
@ -340,7 +374,7 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
} }
} }
static std::vector<std::pair<std::string, int>> GetNestedTypeSpecializations(const clang::DeclContext* decl) { std::vector<std::pair<std::string, int>> GetNestedTypeSpecializations(const clang::DeclContext* decl) {
std::vector<std::pair<std::string, int>> ret; std::vector<std::pair<std::string, int>> ret;
RecurseNestedTypeSpecializations(decl, ret); RecurseNestedTypeSpecializations(decl, ret);
return ret; return ret;
@ -515,7 +549,8 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
std::string sizeExprStr; std::string sizeExprStr;
for (const clang::TemplateArgument& arg : *tsType) { for (const clang::TemplateArgument& arg : *tsType) {
if (arg.getKind() == clang::TemplateArgument::Expression) { if (arg.getKind() == clang::TemplateArgument::Expression) {
const auto* uExpr = static_cast<const clang::UnaryExprOrTypeTraitExpr*>(arg.getAsExpr()->IgnoreImpCasts()); const auto* uExpr =
static_cast<const clang::UnaryExprOrTypeTraitExpr*>(arg.getAsExpr()->IgnoreImpCasts());
if (uExpr->getStmtClass() == clang::Stmt::UnaryExprOrTypeTraitExprClass && if (uExpr->getStmtClass() == clang::Stmt::UnaryExprOrTypeTraitExprClass &&
uExpr->getKind() == clang::UETT_SizeOf) { uExpr->getKind() == clang::UETT_SizeOf) {
const clang::Expr* argExpr = uExpr->getArgumentExpr(); const clang::Expr* argExpr = uExpr->getArgumentExpr();
@ -670,8 +705,7 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
int64_t directionVal = direction.getSExtValue(); int64_t directionVal = direction.getSExtValue();
if (directionVal < 0 || directionVal > 2) { if (directionVal < 0 || directionVal > 2) {
if (directionExpr) { if (directionExpr) {
clang::DiagnosticBuilder diag = clang::DiagnosticBuilder diag = context.getDiagnostics().Report(directionExpr->getExprLoc(), AthenaError);
context.getDiagnostics().Report(directionExpr->getExprLoc(), AthenaError);
diag.AddString("Direction parameter must be 'Begin', 'Current', or 'End'"); diag.AddString("Direction parameter must be 'Begin', 'Current', or 'End'");
diag.AddSourceRange(clang::CharSourceRange(directionExpr->getSourceRange(), true)); diag.AddSourceRange(clang::CharSourceRange(directionExpr->getSourceRange(), true));
} else { } else {
@ -906,7 +940,8 @@ class ATDNAEmitVisitor : public clang::RecursiveASTVisitor<ATDNAEmitVisitor> {
std::string sizeExprStr; std::string sizeExprStr;
for (const clang::TemplateArgument& arg : *tsType) { for (const clang::TemplateArgument& arg : *tsType) {
if (arg.getKind() == clang::TemplateArgument::Expression) { if (arg.getKind() == clang::TemplateArgument::Expression) {
const auto* uExpr = static_cast<const clang::UnaryExprOrTypeTraitExpr*>(arg.getAsExpr()->IgnoreImpCasts()); const auto* uExpr =
static_cast<const clang::UnaryExprOrTypeTraitExpr*>(arg.getAsExpr()->IgnoreImpCasts());
if (uExpr->getStmtClass() == clang::Stmt::UnaryExprOrTypeTraitExprClass && if (uExpr->getStmtClass() == clang::Stmt::UnaryExprOrTypeTraitExprClass &&
uExpr->getKind() == clang::UETT_SizeOf) { uExpr->getKind() == clang::UETT_SizeOf) {
const clang::Expr* argExpr = uExpr->getArgumentExpr(); const clang::Expr* argExpr = uExpr->getArgumentExpr();
@ -1067,8 +1102,7 @@ public:
/* Determine if is is a YAML DNA */ /* Determine if is is a YAML DNA */
bool isYamlDNA = false; bool isYamlDNA = false;
for (const clang::CXXMethodDecl* method : decl->methods()) for (const clang::CXXMethodDecl* method : decl->methods())
if (method->getDeclName().isIdentifier() && if (method->getDeclName().isIdentifier() && (method->getName() == "read" || method->getName() == "write") &&
(method->getName() == "read" || method->getName() == "write") &&
method->getNumParams() == 1 && method->getNumParams() == 1 &&
method->getParamDecl(0)->getType().getAsString() == "athena::io::YAMLDocReader &") { method->getParamDecl(0)->getType().getAsString() == "athena::io::YAMLDocReader &") {
isYamlDNA = true; isYamlDNA = true;