Browse Source

Optimize specific function coalescing in lowering. (#5684)

- Update all the llvm::Function pointers after function replacement.
Some were previously left in an inconsistent state.
- Only do function replacement once, after converging on the canonical
specific to use.
Alina Sbirlea 10 months ago
parent
commit
b0be6619ef

+ 55 - 42
toolchain/lower/file_context.cpp

@@ -188,12 +188,17 @@ auto FileContext::ContainsPair(
 
 auto FileContext::CoalesceEquivalentSpecifics() -> void {
   for (auto& specifics : lowered_specifics_.values()) {
+    // Collect specifics to delete for each generic. Replace and remove each
+    // after processing all specifics for a generic. Note, we could also
+    // replace and remove all specifics after processing all generics.
+    llvm::SmallVector<SemIR::SpecificId> specifics_to_delete;
     // i cannot be unsigned due to the comparison with a negative number when
     // the specifics vector is empty.
     for (int i = 0; i < static_cast<int>(specifics.size()) - 1; ++i) {
       // This specific was already replaced, skip it.
       if (equivalent_specifics_.Get(specifics[i]).has_value() &&
           equivalent_specifics_.Get(specifics[i]) != specifics[i]) {
+        specifics_to_delete.push_back(specifics[i]);
         specifics[i] = specifics[specifics.size() - 1];
         specifics.pop_back();
         --i;
@@ -205,6 +210,7 @@ auto FileContext::CoalesceEquivalentSpecifics() -> void {
         // When the specific was already replaced, skip it.
         if (equivalent_specifics_.Get(specifics[j]).has_value() &&
             equivalent_specifics_.Get(specifics[j]) != specifics[j]) {
+          specifics_to_delete.push_back(specifics[j]);
           specifics[j] = specifics[specifics.size() - 1];
           specifics.pop_back();
           --j;
@@ -231,26 +237,18 @@ auto FileContext::CoalesceEquivalentSpecifics() -> void {
           // When processing equivalences, we may change the canonical specific
           // multiple times, so we don't delete replaced specifics until the
           // end.
-          llvm::SmallVector<SemIR::SpecificId> specifics_to_delete;
           visited_equivalent_specifics.ForEach(
               [&](std::pair<SemIR::SpecificId, SemIR::SpecificId>
                       equivalent_entry) {
                 CARBON_VLOG("Found equivalent specifics: {0}, {1}",
                             equivalent_entry.first, equivalent_entry.second);
-                ProcessSpecificEquivalence(equivalent_entry,
-                                           specifics_to_delete);
+                ProcessSpecificEquivalence(equivalent_entry);
               });
 
-          // Delete function bodies for already replaced functions.
-          for (auto specific_id : specifics_to_delete) {
-            specific_functions_.Get(specific_id)->eraseFromParent();
-            specific_functions_.Get(specific_id) =
-                specific_functions_.Get(equivalent_specifics_.Get(specific_id));
-          }
-
           // Removed the replaced specific from the list of emitted specifics.
           // Only the top level, since the others are somewhere else in the
           // vector, they will be found and removed during processing.
+          specifics_to_delete.push_back(specifics[j]);
           specifics[j] = specifics[specifics.size() - 1];
           specifics.pop_back();
           --j;
@@ -260,25 +258,27 @@ auto FileContext::CoalesceEquivalentSpecifics() -> void {
         }
       }
     }
+
+    // Once all equivalences are found for a generic, update and delete up
+    // equivalent specifics.
+    for (auto specific_id : specifics_to_delete) {
+      UpdateAndDeleteLLVMFunction(specific_id);
+    }
   }
 }
 
 auto FileContext::ProcessSpecificEquivalence(
-    std::pair<SemIR::SpecificId, SemIR::SpecificId> pair,
-    llvm::SmallVector<SemIR::SpecificId>& specifics_to_delete) -> void {
+    std::pair<SemIR::SpecificId, SemIR::SpecificId> pair) -> void {
   auto [specific_id1, specific_id2] = pair;
   CARBON_CHECK(specific_id1.has_value() && specific_id2.has_value(),
                "Expected values in equivalence check");
 
   auto get_canon = [&](SemIR::SpecificId specific_id) {
-    return equivalent_specifics_.Get(specific_id).has_value()
-               ? std::make_pair(
-                     equivalent_specifics_.Get(specific_id),
-                     (equivalent_specifics_.Get(specific_id) != specific_id))
-               : std::make_pair(specific_id, false);
+    auto equiv_id = equivalent_specifics_.Get(specific_id);
+    return equiv_id.has_value() ? equiv_id : specific_id;
   };
-  auto [canon_id1, replaced_before1] = get_canon(specific_id1);
-  auto [canon_id2, replaced_before2] = get_canon(specific_id2);
+  auto canon_id1 = get_canon(specific_id1);
+  auto canon_id2 = get_canon(specific_id2);
 
   if (canon_id1 == canon_id2) {
     // Already equivalent, there was a previous replacement.
@@ -288,7 +288,6 @@ auto FileContext::ProcessSpecificEquivalence(
   if (canon_id1.index >= canon_id2.index) {
     // Prefer the earlier index for canonical values.
     std::swap(canon_id1, canon_id2);
-    std::swap(replaced_before1, replaced_before2);
   }
 
   // Update equivalent_specifics_ for all. This is used as an indicator that
@@ -296,11 +295,41 @@ auto FileContext::ProcessSpecificEquivalence(
   // chains in `IsKnownEquivalence`.
   equivalent_specifics_.Set(specific_id1, canon_id1);
   equivalent_specifics_.Set(specific_id2, canon_id1);
-  specific_functions_.Get(canon_id2)->replaceAllUsesWith(
-      specific_functions_.Get(canon_id1));
-  if (!replaced_before2) {
-    specifics_to_delete.push_back(canon_id2);
+  equivalent_specifics_.Set(canon_id1, canon_id1);
+  equivalent_specifics_.Set(canon_id2, canon_id1);
+}
+
+auto FileContext::UpdateEquivalentSpecific(SemIR::SpecificId specific_id)
+    -> void {
+  if (!equivalent_specifics_.Get(specific_id).has_value()) {
+    return;
+  }
+
+  llvm::SmallVector<SemIR::SpecificId> stack;
+  SemIR::SpecificId specific_to_update = specific_id;
+  SemIR::SpecificId equivalent = equivalent_specifics_.Get(specific_to_update);
+  SemIR::SpecificId equivalent_next = equivalent_specifics_.Get(equivalent);
+  while (equivalent != equivalent_next) {
+    stack.push_back(specific_to_update);
+    specific_to_update = equivalent;
+    equivalent = equivalent_next;
+    equivalent_next = equivalent_specifics_.Get(equivalent_next);
   }
+
+  for (auto specific : stack) {
+    equivalent_specifics_.Set(specific, equivalent);
+  }
+}
+
+auto FileContext::UpdateAndDeleteLLVMFunction(SemIR::SpecificId specific_id)
+    -> void {
+  UpdateEquivalentSpecific(specific_id);
+  auto* old_function = specific_functions_.Get(specific_id);
+  auto* new_function =
+      specific_functions_.Get(equivalent_specifics_.Get(specific_id));
+  old_function->replaceAllUsesWith(new_function);
+  old_function->eraseFromParent();
+  specific_functions_.Set(specific_id, new_function);
 }
 
 auto FileContext::IsKnownEquivalence(SemIR::SpecificId specific_id1,
@@ -310,24 +339,8 @@ auto FileContext::IsKnownEquivalence(SemIR::SpecificId specific_id1,
     return false;
   }
 
-  auto update_equivalent_specific = [&](SemIR::SpecificId specific_id) {
-    llvm::SmallVector<SemIR::SpecificId> stack;
-    SemIR::SpecificId specific_to_update = specific_id;
-    while (equivalent_specifics_.Get(
-               equivalent_specifics_.Get(specific_to_update)) !=
-           equivalent_specifics_.Get(specific_to_update)) {
-      stack.push_back(specific_to_update);
-      specific_to_update = equivalent_specifics_.Get(specific_to_update);
-    }
-    for (auto specific : llvm::reverse(stack)) {
-      equivalent_specifics_.Set(
-          specific,
-          equivalent_specifics_.Get(equivalent_specifics_.Get(specific)));
-    }
-  };
-
-  update_equivalent_specific(specific_id1);
-  update_equivalent_specific(specific_id2);
+  UpdateEquivalentSpecific(specific_id1);
+  UpdateEquivalentSpecific(specific_id2);
 
   return equivalent_specifics_.Get(specific_id1) ==
          equivalent_specifics_.Get(specific_id2);

+ 13 - 5
toolchain/lower/file_context.h

@@ -235,12 +235,9 @@ class FileContext {
           visited_equivalent_specifics) -> bool;
 
   // Given an equivalent pair of specifics, updates the canonical specific to
-  // use for each of the two Specifics found to be equivalent, replaces all
-  // uses of one specific with the canonical one, and adds the non-canonical
-  // specific to specifics_to_delete.
+  // use for each of the two Specifics found to be equivalent.
   auto ProcessSpecificEquivalence(
-      std::pair<SemIR::SpecificId, SemIR::SpecificId> pair,
-      llvm::SmallVector<SemIR::SpecificId>& specifics_to_delete) -> void;
+      std::pair<SemIR::SpecificId, SemIR::SpecificId> pair) -> void;
 
   // Checks if two specific_ids are equivalent and also reduces the equivalence
   // chains/paths. This update ensures the canonical specific is always "one
@@ -248,6 +245,17 @@ class FileContext {
   auto IsKnownEquivalence(SemIR::SpecificId specific_id1,
                           SemIR::SpecificId specific_id2) -> bool;
 
+  // Update the tracked equivalent specific for the `SpecificId`. This may
+  // occur a replacement was performed and a chain of such replacements needs
+  // to be followed to discover the canonical specific for the given argument.
+  auto UpdateEquivalentSpecific(SemIR::SpecificId specific_id) -> void;
+
+  // Update the LLVM function to use for a `SpecificId` that has been found to
+  // have another equivalent LLVM function. Replace all uses of the original
+  // LLVM function with the equivalent one found, and delete the previous LLVM
+  // function body.
+  auto UpdateAndDeleteLLVMFunction(SemIR::SpecificId specific_id) -> void;
+
   // Inserts a pair into a set of pairs in canonical form. Also implicitly
   // checks entry already existed if it cannot be inserted.
   auto InsertPair(

+ 1 - 1
toolchain/lower/testdata/function/generic/call_recursive_reorder_more.carbon

@@ -227,7 +227,7 @@ fn M() {
 // CHECK:STDOUT:
 // CHECK:STDOUT: ; uselistorder directives
 // CHECK:STDOUT: uselistorder ptr @llvm.lifetime.start.p0, { 8, 7, 6, 5, 4, 3, 2, 1, 0 }
-// CHECK:STDOUT: uselistorder ptr @_CF.Main.7776e910959584d9, { 5, 7, 6, 4, 3, 2, 1, 0 }
+// CHECK:STDOUT: uselistorder ptr @_CF.Main.7776e910959584d9, { 5, 7, 6, 4, 0, 1, 2, 3 }
 // CHECK:STDOUT:
 // CHECK:STDOUT: attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
 // CHECK:STDOUT: