Browse Source

Add support to rewriter for free-function declarations and definitions. (#2120)

Andy Soffer 3 years ago
parent
commit
1b4544ffed
3 changed files with 153 additions and 4 deletions
  1. 94 3
      migrate_cpp/rewriter.cpp
  2. 6 1
      migrate_cpp/rewriter.h
  3. 53 0
      migrate_cpp/rewriter_test.cpp

+ 94 - 3
migrate_cpp/rewriter.cpp

@@ -173,6 +173,9 @@ auto RewriteBuilder::VisitBuiltinTypeLoc(clang::BuiltinTypeLoc type_loc)
     case clang::BuiltinType::Double:
       content = "f64";
       break;
+    case clang::BuiltinType::Void:
+      content = "()";
+      break;
     default:
       // In this case we do not know what the output should be so we do not
       // write any.
@@ -203,6 +206,12 @@ auto RewriteBuilder::VisitDeclStmt(clang::DeclStmt* stmt) -> bool {
   return true;
 }
 
+auto RewriteBuilder::VisitImplicitCastExpr(clang::ImplicitCastExpr* expr)
+    -> bool {
+  SetReplacement(expr, OutputSegment(expr->getSubExpr()));
+  return true;
+}
+
 auto RewriteBuilder::VisitIntegerLiteral(clang::IntegerLiteral* expr) -> bool {
   // TODO: Replace suffixes.
   std::string text(TextForTokenAt(expr->getBeginLoc()));
@@ -217,6 +226,22 @@ auto RewriteBuilder::VisitIntegerLiteral(clang::IntegerLiteral* expr) -> bool {
   return true;
 }
 
+auto RewriteBuilder::VisitParmVarDecl(clang::ParmVarDecl* decl) -> bool {
+  llvm::StringRef name = decl->getName();
+  std::vector<OutputSegment> segments = {
+      OutputSegment(llvm::formatv("{0}: ", name.empty() ? "_" : name.str())),
+      OutputSegment(decl->getTypeSourceInfo()->getTypeLoc()),
+  };
+
+  if (clang::Expr* init = decl->getInit()) {
+    segments.push_back(OutputSegment(" = "));
+    segments.push_back(OutputSegment(init));
+  }
+
+  SetReplacement(decl, std::move(segments));
+  return true;
+}
+
 auto RewriteBuilder::VisitPointerTypeLoc(clang::PointerTypeLoc type_loc)
     -> bool {
   SetReplacement(type_loc,
@@ -224,6 +249,12 @@ auto RewriteBuilder::VisitPointerTypeLoc(clang::PointerTypeLoc type_loc)
   return true;
 }
 
+auto RewriteBuilder::VisitReturnStmt(clang::ReturnStmt* stmt) -> bool {
+  SetReplacement(
+      stmt, {OutputSegment("return "), OutputSegment(stmt->getRetValue())});
+  return true;
+}
+
 auto RewriteBuilder::VisitTranslationUnitDecl(clang::TranslationUnitDecl* decl)
     -> bool {
   std::vector<OutputSegment> segments;
@@ -241,7 +272,11 @@ auto RewriteBuilder::VisitTranslationUnitDecl(clang::TranslationUnitDecl* decl)
   for (; iter != decl->decls_end(); ++iter) {
     clang::Decl* d = *iter;
     segments.push_back(OutputSegment(d));
-    segments.push_back(OutputSegment(";\n"));
+
+    // Function definitions do not need semicolons.
+    bool needs_semicolon = !(llvm::isa<clang::FunctionDecl>(d) &&
+                             llvm::cast<clang::FunctionDecl>(d)->hasBody());
+    segments.push_back(OutputSegment(needs_semicolon ? ";\n" : "\n"));
   }
 
   SetReplacement(decl, std::move(segments));
@@ -262,10 +297,62 @@ auto RewriteBuilder::VisitUnaryOperator(clang::UnaryOperator* expr) -> bool {
   return true;
 }
 
-auto RewriteBuilder::VisitVarDecl(clang::VarDecl* decl) -> bool {
+auto RewriteBuilder::TraverseFunctionDecl(clang::FunctionDecl* decl) -> bool {
+  clang::TypeLoc return_type_loc = decl->getFunctionTypeLoc().getReturnLoc();
+  if (!TraverseTypeLoc(return_type_loc)) {
+    return false;
+  }
+
+  std::vector<OutputSegment> segments;
+  segments.push_back(
+      OutputSegment(llvm::formatv("fn {0}(", decl->getNameAsString())));
+
+  size_t i = 0;
+  for (; i + 1 < decl->getNumParams(); ++i) {
+    clang::ParmVarDecl* param = decl->getParamDecl(i);
+    if (!TraverseDecl(param)) {
+      return false;
+    }
+    segments.push_back(OutputSegment(param));
+    segments.push_back(OutputSegment(", "));
+  }
+
+  if (i + 1 == decl->getNumParams()) {
+    clang::ParmVarDecl* param = decl->getParamDecl(i);
+    if (!TraverseDecl(param)) {
+      return false;
+    }
+    segments.push_back(OutputSegment(param));
+  }
+
+  segments.push_back(OutputSegment(") -> "));
+  segments.push_back(OutputSegment(return_type_loc));
+
+  if (decl->hasBody()) {
+    segments.push_back(OutputSegment(" {\n"));
+    auto* stmts = llvm::dyn_cast<clang::CompoundStmt>(decl->getBody());
+    for (clang::Stmt* stmt : stmts->body()) {
+      if (!TraverseStmt(stmt)) {
+        return false;
+      }
+      segments.push_back(OutputSegment(stmt));
+      segments.push_back(OutputSegment(";\n"));
+    }
+    segments.push_back(OutputSegment("}"));
+  }
+
+  SetReplacement(decl, std::move(segments));
+  return true;
+}
+
+auto RewriteBuilder::TraverseVarDecl(clang::VarDecl* decl) -> bool {
+  clang::TypeLoc loc = decl->getTypeSourceInfo()->getTypeLoc();
+  if (!TraverseTypeLoc(loc)) {
+    return false;
+  }
+
   // TODO: Check storage class. Determine what happens for static local
   // variables.
-
   bool is_const = decl->getType().isConstQualified();
   std::vector<OutputSegment> segments = {
       OutputSegment(llvm::formatv("{0} {1}: ", is_const ? "let" : "var",
@@ -274,6 +361,10 @@ auto RewriteBuilder::VisitVarDecl(clang::VarDecl* decl) -> bool {
   };
 
   if (clang::Expr* init = decl->getInit()) {
+    if (!TraverseStmt(init)) {
+      return false;
+    }
+
     segments.push_back(OutputSegment(" = "));
     segments.push_back(OutputSegment(init));
   }

+ 6 - 1
migrate_cpp/rewriter.h

@@ -104,11 +104,16 @@ class RewriteBuilder : public clang::RecursiveASTVisitor<RewriteBuilder> {
   auto VisitCXXBoolLiteralExpr(clang::CXXBoolLiteralExpr* expr) -> bool;
   auto VisitDeclRefExpr(clang::DeclRefExpr* expr) -> bool;
   auto VisitDeclStmt(clang::DeclStmt* stmt) -> bool;
+  auto VisitImplicitCastExpr(clang::ImplicitCastExpr* expr) -> bool;
   auto VisitIntegerLiteral(clang::IntegerLiteral* expr) -> bool;
+  auto VisitParmVarDecl(clang::ParmVarDecl* decl) -> bool;
   auto VisitPointerTypeLoc(clang::PointerTypeLoc type_loc) -> bool;
+  auto VisitReturnStmt(clang::ReturnStmt* stmt) -> bool;
   auto VisitTranslationUnitDecl(clang::TranslationUnitDecl* decl) -> bool;
   auto VisitUnaryOperator(clang::UnaryOperator* expr) -> bool;
-  auto VisitVarDecl(clang::VarDecl* decl) -> bool;
+
+  auto TraverseFunctionDecl(clang::FunctionDecl* decl) -> bool;
+  auto TraverseVarDecl(clang::VarDecl* decl) -> bool;
 
   auto segments() const -> const SegmentMapType& { return segments_; }
   auto segments() -> SegmentMapType& { return segments_; }

+ 53 - 0
migrate_cpp/rewriter_test.cpp

@@ -131,5 +131,58 @@ TEST(Rewriter, DeclarationComma) {
             "let y: i32 = 5678;\n");
 }
 
+TEST(Rewriter, FunctionDeclaration) {
+  // Function declarations and definitions returning void.
+  EXPECT_EQ(RewriteText("void f();"), "fn f() -> ();\n");
+  EXPECT_EQ(RewriteText("void f() {}"),
+            "fn f() -> () {\n"
+            "}\n");
+
+  // Function declarations and definitions returning int.
+  EXPECT_EQ(RewriteText("int f();"), "fn f() -> i32;\n");
+  EXPECT_EQ(RewriteText("int f() { return 0; }"),
+            "fn f() -> i32 {\n"
+            "return 0;\n"
+            "}\n");
+
+  // Function declarations and definitions with a single parameter.
+  EXPECT_EQ(RewriteText("int f(bool);"), "fn f(_: bool) -> i32;\n");
+  EXPECT_EQ(RewriteText("int f(bool b);"), "fn f(b: bool) -> i32;\n");
+  EXPECT_EQ(RewriteText("int f(bool) { return 0; }"),
+            "fn f(_: bool) -> i32 {\n"
+            "return 0;\n"
+            "}\n");
+  EXPECT_EQ(RewriteText("int f(bool b) { return 0; }"),
+            "fn f(b: bool) -> i32 {\n"
+            "return 0;\n"
+            "}\n");
+
+  // Function declarations and definitions with a multiple parameters.
+  EXPECT_EQ(RewriteText("int f(bool, int);"),
+            "fn f(_: bool, _: i32) -> i32;\n");
+  EXPECT_EQ(RewriteText("int f(bool b, int n);"),
+            "fn f(b: bool, n: i32) -> i32;\n");
+  EXPECT_EQ(RewriteText("int f(bool, int n) { return 0; }"),
+            "fn f(_: bool, n: i32) -> i32 {\n"
+            "return 0;\n"
+            "}\n");
+  EXPECT_EQ(RewriteText("int f(bool b, int n) { return 0; }"),
+            "fn f(b: bool, n: i32) -> i32 {\n"
+            "return 0;\n"
+            "}\n");
+  EXPECT_EQ(RewriteText("int f(bool b, int n = 3) { return n; }"),
+            "fn f(b: bool, n: i32 = 3) -> i32 {\n"
+            "return n;\n"
+            "}\n");
+
+  // Function declarations with trailing-return syntax.
+  EXPECT_EQ(RewriteText("auto f(bool b, int n = 3) -> int;"),
+            "fn f(b: bool, n: i32 = 3) -> i32;\n");
+  EXPECT_EQ(RewriteText("auto f(bool b, int n = 3) -> int { return n; }"),
+            "fn f(b: bool, n: i32 = 3) -> i32 {\n"
+            "return n;\n"
+            "}\n");
+}
+
 }  // namespace
 }  // namespace Carbon::Testing