浏览代码

Cache multi-IR info, particularly include_in_dumps (#5408)

Right now we construct `tree_and_subtrees_getters` a couple different
ways, it's just not obvious because one's abstracted in `check`. But
also, when formatting IR, we'll repeatedly do the `IncludeInDumps`
string check, which felt odd to me since it only needs to be calculated
once per IR.

This also shifts `CheckIRId` selection a little earlier, and in doing so
makes `CheckParseTrees` accept a sparse `units` argument. I actually
think this is a positive: it makes `CheckIRId` a little more stable
across possible command lines, when file loading fails (which is the
only time that a file will have a `CompilationUnit` but not a
`Check::Unit`).

Trying to build on the shared issue between these, I'm adding a
`MultiUnitCache` to store the calculated arrays. For the subtree
getters, this is very minor and avoids at most one incremental array
construction (moving logic out of `CompileSubcommand::Run` might be the
bigger benefit). For `include_in_dumps`, when dumping SemIR, this is
changing a calculation run once per entity (in each IR) to be calculated
once per IR (globally), i.e. O(M*N) -> O(N).

Note this seems to be marginal for performance of file_test:

- Before: Stats over 10 runs: max = 5.3s, min = 4.7s, avg = 4.9s, dev =
0.2s
- After: Stats over 10 runs: max = 4.9s, min = 4.7s, avg = 4.8s, dev =
0.1s

I was mainly thinking about this in the context of dumping SemIR ranges.
There, the impact may actually decrease because a range won't do any
cross-IR printing. But, I'm expecting to add another layer for whether
we're printing IR for a file, and that made the `should_format_entity`
callback stick out for me.
Jon Ross-Perkins 1 年之前
父节点
当前提交
0683742f19

+ 10 - 11
toolchain/check/check.cpp

@@ -322,19 +322,18 @@ static auto BuildApiMapAndDiagnosePackaging(
   return api_map;
 }
 
-auto CheckParseTrees(llvm::MutableArrayRef<Unit> units, bool prelude_import,
-                     llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
-                     llvm::raw_ostream* vlog_stream, bool fuzzing) -> void {
+auto CheckParseTrees(
+    llvm::MutableArrayRef<Unit> units,
+    llvm::ArrayRef<Parse::GetTreeAndSubtreesFn> tree_and_subtrees_getters,
+    bool prelude_import, llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
+    llvm::raw_ostream* vlog_stream, bool fuzzing) -> void {
   // UnitAndImports is big due to its SmallVectors, so we default to 0 on the
   // stack.
-  llvm::SmallVector<UnitAndImports, 0> unit_infos;
-  llvm::SmallVector<Parse::GetTreeAndSubtreesFn> tree_and_subtrees_getters;
-  unit_infos.reserve(units.size());
-  tree_and_subtrees_getters.reserve(units.size());
-  for (auto [i, unit] : llvm::enumerate(units)) {
-    unit_infos.emplace_back(SemIR::CheckIRId(i), unit);
-    tree_and_subtrees_getters.push_back(unit.tree_and_subtrees_getter);
-  }
+  llvm::SmallVector<UnitAndImports, 0> unit_infos(
+      llvm::map_range(units, [&](Unit& unit) {
+        return UnitAndImports(
+            &unit, tree_and_subtrees_getters[unit.sem_ir->check_ir_id().index]);
+      }));
 
   Map<ImportKey, UnitAndImports*> api_map =
       BuildApiMapAndDiagnosePackaging(unit_infos);

+ 5 - 6
toolchain/check/check.h

@@ -22,9 +22,6 @@ struct Unit {
   // The `timings` may be null if nothing is to be recorded.
   Timings* timings;
 
-  // Returns a lazily constructed TreeAndSubtrees.
-  Parse::GetTreeAndSubtreesFn tree_and_subtrees_getter;
-
   // The unit's SemIR, provided as empty and filled in by CheckParseTrees.
   SemIR::File* sem_ir;
 
@@ -34,9 +31,11 @@ struct Unit {
 
 // Checks a group of parse trees. This will use imports to decide the order of
 // checking.
-auto CheckParseTrees(llvm::MutableArrayRef<Unit> units, bool prelude_import,
-                     llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
-                     llvm::raw_ostream* vlog_stream, bool fuzzing) -> void;
+auto CheckParseTrees(
+    llvm::MutableArrayRef<Unit> units,
+    llvm::ArrayRef<Parse::GetTreeAndSubtreesFn> tree_and_subtrees_getters,
+    bool prelude_import, llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
+    llvm::raw_ostream* vlog_stream, bool fuzzing) -> void;
 
 }  // namespace Carbon::Check
 

+ 7 - 5
toolchain/check/check_unit.cpp

@@ -54,15 +54,17 @@ CheckUnit::CheckUnit(
     llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
     llvm::raw_ostream* vlog_stream)
     : unit_and_imports_(unit_and_imports),
+      tree_and_subtrees_getter_(
+          tree_and_subtrees_getters
+              [unit_and_imports->unit->sem_ir->check_ir_id().index]),
       total_ir_count_(tree_and_subtrees_getters.size()),
       fs_(std::move(fs)),
       vlog_stream_(vlog_stream),
       emitter_(&unit_and_imports_->err_tracker, tree_and_subtrees_getters,
                unit_and_imports_->unit->sem_ir),
-      context_(&emitter_, unit_and_imports_->unit->tree_and_subtrees_getter,
-               unit_and_imports_->unit->sem_ir,
-               GetImportedIRCount(unit_and_imports),
-               tree_and_subtrees_getters.size(), vlog_stream) {}
+      context_(
+          &emitter_, tree_and_subtrees_getter_, unit_and_imports_->unit->sem_ir,
+          GetImportedIRCount(unit_and_imports), total_ir_count_, vlog_stream) {}
 
 auto CheckUnit::Run() -> void {
   Timings::ScopedTiming timing(unit_and_imports_->unit->timings, "check");
@@ -366,7 +368,7 @@ auto CheckUnit::ProcessNodeIds() -> bool {
 
   // On crash, report which token we were handling.
   PrettyStackTraceFunction node_dumper([&](llvm::raw_ostream& output) {
-    const auto& tree = unit_and_imports_->unit->tree_and_subtrees_getter();
+    const auto& tree = tree_and_subtrees_getter_();
     auto converted = tree.NodeToDiagnosticLoc(node_id, /*token_only=*/false);
     converted.loc.FormatLocation(output);
     output << "checking " << context_.parse_tree().node_kind(node_id) << "\n";

+ 6 - 6
toolchain/check/check_unit.h

@@ -68,18 +68,17 @@ struct UnitAndImports {
     Parse::GetTreeAndSubtreesFn tree_and_subtrees_getter_;
   };
 
-  explicit UnitAndImports(SemIR::CheckIRId check_ir_id, Unit& unit)
-      : check_ir_id(check_ir_id),
-        unit(&unit),
-        err_tracker(*unit.consumer),
-        emitter(&err_tracker, unit.tree_and_subtrees_getter) {}
+  explicit UnitAndImports(Unit* unit,
+                          Parse::GetTreeAndSubtreesFn tree_and_subtrees_getter)
+      : unit(unit),
+        err_tracker(*unit->consumer),
+        emitter(&err_tracker, tree_and_subtrees_getter) {}
 
   auto parse_tree() -> const Parse::Tree& { return unit->sem_ir->parse_tree(); }
   auto source() -> const SourceBuffer& {
     return parse_tree().tokens().source();
   }
 
-  SemIR::CheckIRId check_ir_id;
   Unit* unit;
 
   // Emitter information.
@@ -180,6 +179,7 @@ class CheckUnit {
   auto ProcessNodeIds() -> bool;
 
   UnitAndImports* unit_and_imports_;
+  Parse::GetTreeAndSubtreesFn tree_and_subtrees_getter_;
   // The number of IRs being checked in total.
   int total_ir_count_;
   llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs_;

+ 147 - 93
toolchain/driver/compile_subcommand.cpp

@@ -345,13 +345,20 @@ auto CompileSubcommand::ValidateOptions(
 }
 
 namespace {
+
+class MultiUnitCache;
+
 // Ties together information for a file being compiled.
 class CompilationUnit {
  public:
-  explicit CompilationUnit(DriverEnv& driver_env, const CompileOptions& options,
+  explicit CompilationUnit(int unit_index, DriverEnv* driver_env,
+                           const CompileOptions* options,
                            Diagnostics::Consumer* consumer,
                            llvm::StringRef input_filename);
 
+  // Sets the multi-unit cache and initializes dependent member state.
+  auto SetMultiUnitCache(MultiUnitCache* cache) -> void;
+
   // Loads source and lexes it. Returns true on success.
   auto RunLex() -> void;
 
@@ -359,14 +366,13 @@ class CompilationUnit {
   auto RunParse() -> void;
 
   // Returns information needed to check this unit.
-  auto GetCheckUnit(SemIR::CheckIRId check_ir_id) -> Check::Unit;
+  auto GetCheckUnit() -> Check::Unit;
 
   // Runs post-check logic. Returns true if checking succeeded for the IR.
   auto PostCheck() -> void;
 
   // Lower SemIR to LLVM IR.
-  auto RunLower(std::optional<llvm::ArrayRef<Parse::GetTreeAndSubtreesFn>>
-                    tree_and_subtrees_getters_for_debug_info) -> void;
+  auto RunLower() -> void;
 
   auto RunCodeGen() -> void;
 
@@ -400,15 +406,17 @@ class CompilationUnit {
                llvm::StringLiteral timing_label,
                llvm::function_ref<auto()->void> fn) -> void;
 
-  // Returns true if the current input file can be dumped.
-  auto IncludeInDumps() const -> bool;
+  // Returns true if the current file should be included in debug dumps.
+  auto IncludeInDumps() -> bool;
 
-  // Returns true if the specified input file can be dumped.
-  auto IncludeInDumps(llvm::StringRef filename) const -> bool;
+  // The index of the unit amongst all units. Equivalent to a `CheckIRId`.
+  int unit_index_;
 
   DriverEnv* driver_env_;
+  const CompileOptions* options_;
+
   SharedValueStores value_stores_;
-  const CompileOptions& options_;
+
   // The input filename from the command line. For most diagnostics, we
   // typically use `source_->filename()`, which includes a `-` -> `<stdin>`
   // translation. However, logging and some diagnostics use the command line
@@ -424,6 +432,8 @@ class CompilationUnit {
 
   bool success_ = true;
 
+  // Initialized by `SetMultiUnitCache`.
+  MultiUnitCache* cache_ = nullptr;
   // Tracks memory usage of the compile.
   std::optional<MemUsage> mem_usage_;
   // Tracks timings of the compile.
@@ -442,29 +452,105 @@ class CompilationUnit {
   std::unique_ptr<llvm::Module> module_;
 };
 
-CompilationUnit::CompilationUnit(DriverEnv& driver_env,
-                                 const CompileOptions& options,
+// Caches lists that are shared cross-unit. Accessors do lazy caching because
+// they may not be used.
+class MultiUnitCache {
+ public:
+  // This relies on construction after `units` are all initialized, which is
+  // reflected by the `ArrayRef` here.
+  explicit MultiUnitCache(
+      const CompileOptions* options,
+      const llvm::ArrayRef<std::unique_ptr<CompilationUnit>> units)
+      : options_(options), units_(units) {}
+
+  auto include_in_dumps() -> llvm::ArrayRef<bool> {
+    CARBON_CHECK(!units_.empty());
+    if (include_in_dumps_.empty()) {
+      BuildIncludeInDumps();
+    }
+    return include_in_dumps_;
+  }
+
+  auto tree_and_subtrees_getters()
+      -> llvm::ArrayRef<Parse::GetTreeAndSubtreesFn> {
+    CARBON_CHECK(!units_.empty());
+    if (tree_and_subtrees_getters_.empty()) {
+      BuildTreeAndSubtreesGetters();
+    }
+    return tree_and_subtrees_getters_;
+  }
+
+ private:
+  auto BuildIncludeInDumps() -> void {
+    CARBON_CHECK(include_in_dumps_.empty());
+    llvm::append_range(include_in_dumps_,
+                       llvm::map_range(units_, [&](const auto& unit) {
+                         return options_->exclude_dump_file_prefix.empty() ||
+                                !unit->input_filename().starts_with(
+                                    options_->exclude_dump_file_prefix);
+                       }));
+  }
+
+  auto BuildTreeAndSubtreesGetters() -> void {
+    CARBON_CHECK(tree_and_subtrees_getters_.empty());
+    llvm::append_range(
+        tree_and_subtrees_getters_,
+        llvm::map_range(units_, [&](const auto& unit) {
+          return unit->has_source() ? unit->get_trees_and_subtrees() : nullptr;
+        }));
+  }
+
+  const CompileOptions* options_;
+
+  // The units being compiled.
+  const llvm::ArrayRef<std::unique_ptr<CompilationUnit>> units_;
+
+  // For each unit, whether it's included in dumps. Used cross-phase.
+  llvm::SmallVector<bool> include_in_dumps_;
+
+  // For each unit, the `TreeAndSubtrees` getter. Used by lowering.
+  llvm::SmallVector<Parse::GetTreeAndSubtreesFn> tree_and_subtrees_getters_;
+};
+
+}  // namespace
+
+CompilationUnit::CompilationUnit(int unit_index, DriverEnv* driver_env,
+                                 const CompileOptions* options,
                                  Diagnostics::Consumer* consumer,
                                  llvm::StringRef input_filename)
-    : driver_env_(&driver_env),
+    : unit_index_(unit_index),
+      driver_env_(driver_env),
       options_(options),
       input_filename_(input_filename),
       vlog_stream_(driver_env_->vlog_stream) {
-  if (vlog_stream_ != nullptr || options_.stream_errors) {
+  if (vlog_stream_ != nullptr || options_->stream_errors) {
     consumer_ = consumer;
   } else {
     sorting_consumer_ = Diagnostics::SortingConsumer(*consumer);
     consumer_ = &*sorting_consumer_;
   }
-  if (options_.dump_mem_usage && IncludeInDumps()) {
+}
+
+auto CompilationUnit::IncludeInDumps() -> bool {
+  return cache_->include_in_dumps()[unit_index_];
+}
+
+auto CompilationUnit::SetMultiUnitCache(MultiUnitCache* cache) -> void {
+  CARBON_CHECK(!cache_, "Called SetMultiUnitCache twice");
+  cache_ = cache;
+
+  if (options_->dump_mem_usage && IncludeInDumps()) {
+    CARBON_CHECK(!mem_usage_);
     mem_usage_ = MemUsage();
   }
-  if (options_.dump_timings && IncludeInDumps()) {
+  if (options_->dump_timings && IncludeInDumps()) {
+    CARBON_CHECK(!timings_);
     timings_ = Timings();
   }
 }
 
 auto CompilationUnit::RunLex() -> void {
+  CARBON_CHECK(cache_, "Must call SetMultiUnitCache first");
   CARBON_CHECK(!tokens_, "Called RunLex twice");
 
   LogCall("SourceBuffer::MakeFromFileOrStdin", "source", [&] {
@@ -485,10 +571,10 @@ auto CompilationUnit::RunLex() -> void {
 
   LogCall("Lex::Lex", "lex",
           [&] { tokens_ = Lex::Lex(value_stores_, *source_, *consumer_); });
-  if (options_.dump_tokens && IncludeInDumps()) {
+  if (options_->dump_tokens && IncludeInDumps()) {
     consumer_->Flush();
     tokens_->Print(*driver_env_->output_stream,
-                   options_.omit_file_boundary_tokens);
+                   options_->omit_file_boundary_tokens);
   }
   if (mem_usage_) {
     mem_usage_->Collect("tokens_", *tokens_);
@@ -503,10 +589,10 @@ auto CompilationUnit::RunParse() -> void {
   LogCall("Parse::Parse", "parse", [&] {
     parse_tree_ = Parse::Parse(*tokens_, *consumer_, vlog_stream_);
   });
-  if (options_.dump_parse_tree && IncludeInDumps()) {
+  if (options_->dump_parse_tree && IncludeInDumps()) {
     consumer_->Flush();
     const auto& tree_and_subtrees = GetParseTreeAndSubtrees();
-    if (options_.preorder_parse_tree) {
+    if (options_->preorder_parse_tree) {
       tree_and_subtrees.PrintPreorder(*driver_env_->output_stream);
     } else {
       tree_and_subtrees.Print(*driver_env_->output_stream);
@@ -521,20 +607,19 @@ auto CompilationUnit::RunParse() -> void {
   }
 }
 
-auto CompilationUnit::GetCheckUnit(SemIR::CheckIRId check_ir_id)
-    -> Check::Unit {
+auto CompilationUnit::GetCheckUnit() -> Check::Unit {
   CARBON_CHECK(parse_tree_, "Must call RunParse first");
   CARBON_CHECK(!sem_ir_, "Called GetCheckUnit twice");
 
   tree_and_subtrees_getter_ = [this]() -> const Parse::TreeAndSubtrees& {
     return this->GetParseTreeAndSubtrees();
   };
-  sem_ir_.emplace(&*parse_tree_, check_ir_id, parse_tree_->packaging_decl(),
-                  value_stores_, input_filename_);
+  sem_ir_.emplace(&*parse_tree_, SemIR::CheckIRId(unit_index_),
+                  parse_tree_->packaging_decl(), value_stores_,
+                  input_filename_);
   return {.consumer = consumer_,
           .value_stores = &value_stores_,
           .timings = timings_ ? &*timings_ : nullptr,
-          .tree_and_subtrees_getter = *tree_and_subtrees_getter_,
           .sem_ir = &*sem_ir_,
           .cpp_ast = &cpp_ast_};
 }
@@ -551,25 +636,18 @@ auto CompilationUnit::PostCheck() -> void {
     mem_usage_->Collect("sem_ir_", *sem_ir_);
   }
 
-  if (options_.dump_raw_sem_ir && IncludeInDumps()) {
+  if (options_->dump_raw_sem_ir && IncludeInDumps()) {
     CARBON_VLOG("*** Raw SemIR::File ***\n{0}\n", *sem_ir_);
-    sem_ir_->Print(*driver_env_->output_stream, options_.builtin_sem_ir);
-    if (options_.dump_sem_ir) {
+    sem_ir_->Print(*driver_env_->output_stream, options_->builtin_sem_ir);
+    if (options_->dump_sem_ir) {
       *driver_env_->output_stream << "\n";
     }
   }
 
-  bool print = options_.dump_sem_ir && IncludeInDumps();
+  bool print = options_->dump_sem_ir && IncludeInDumps();
   if (vlog_stream_ || print) {
-    // Omit entities imported from files that we are not dumping.
-    auto should_format_entity = [&](SemIR::InstId entity_inst_id) -> bool {
-      auto [import_ir, _] =
-          SemIR::GetCanonicalFileAndInstId(&*sem_ir_, entity_inst_id);
-      return IncludeInDumps(import_ir->filename());
-    };
-
-    SemIR::Formatter formatter(&*sem_ir_, should_format_entity,
-                               *tree_and_subtrees_getter_);
+    SemIR::Formatter formatter(&*sem_ir_, *tree_and_subtrees_getter_,
+                               cache_->include_in_dumps());
     formatter.Format();
     if (vlog_stream_) {
       CARBON_VLOG("*** SemIR::File ***\n");
@@ -584,18 +662,19 @@ auto CompilationUnit::PostCheck() -> void {
   }
 }
 
-auto CompilationUnit::RunLower(
-    std::optional<llvm::ArrayRef<Parse::GetTreeAndSubtreesFn>>
-        tree_and_subtrees_getters_for_debug_info) -> void {
+auto CompilationUnit::RunLower() -> void {
   LogCall("Lower::LowerToLLVM", "lower", [&] {
     llvm_context_ = std::make_unique<llvm::LLVMContext>();
     // TODO: Consider disabling instruction naming by default if we're not
     // producing textual LLVM IR.
     SemIR::InstNamer inst_namer(&*sem_ir_);
-    module_ = Lower::LowerToLLVM(*llvm_context_,
-                                 tree_and_subtrees_getters_for_debug_info,
-                                 input_filename_, *sem_ir_, sem_ir_->cpp_ast(),
-                                 &inst_namer, vlog_stream_);
+    std::optional<llvm::ArrayRef<Parse::GetTreeAndSubtreesFn>> subtrees;
+    if (options_->include_debug_info) {
+      subtrees = cache_->tree_and_subtrees_getters();
+    }
+    module_ =
+        Lower::LowerToLLVM(*llvm_context_, subtrees, input_filename_, *sem_ir_,
+                           sem_ir_->cpp_ast(), &inst_namer, vlog_stream_);
   });
   if (vlog_stream_) {
     CARBON_VLOG("*** llvm::Module ***\n");
@@ -603,7 +682,7 @@ auto CompilationUnit::RunLower(
                    /*ShouldPreserveUseListOrder=*/false,
                    /*IsForDebug=*/true);
   }
-  if (options_.dump_llvm_ir && IncludeInDumps()) {
+  if (options_->dump_llvm_ir && IncludeInDumps()) {
     module_->print(*driver_env_->output_stream, /*AAW=*/nullptr,
                    /*ShouldPreserveUseListOrder=*/true);
   }
@@ -615,7 +694,7 @@ auto CompilationUnit::RunCodeGen() -> void {
 }
 
 auto CompilationUnit::PostCompile() -> void {
-  if (options_.dump_shared_values && IncludeInDumps()) {
+  if (options_->dump_shared_values && IncludeInDumps()) {
     Yaml::Print(*driver_env_->output_stream,
                 value_stores_.OutputYaml(input_filename_));
   }
@@ -636,7 +715,7 @@ auto CompilationUnit::PostCompile() -> void {
 
 auto CompilationUnit::RunCodeGenHelper() -> bool {
   std::optional<CodeGen> codegen =
-      CodeGen::Make(module_.get(), options_.codegen_options.target,
+      CodeGen::Make(module_.get(), options_->codegen_options.target,
                     driver_env_->error_stream);
   if (!codegen) {
     return false;
@@ -646,11 +725,11 @@ auto CompilationUnit::RunCodeGenHelper() -> bool {
     codegen->EmitAssembly(*vlog_stream_);
   }
 
-  if (options_.output_filename == "-") {
+  if (options_->output_filename == "-") {
     // TODO: The output file name, forcing object output, and requesting
     // textual assembly output are all somewhat linked flags. We should add
     // some validation that they are used correctly.
-    if (options_.force_obj_output) {
+    if (options_->force_obj_output) {
       if (!codegen->EmitObject(*driver_env_->output_stream)) {
         return false;
       }
@@ -660,7 +739,7 @@ auto CompilationUnit::RunCodeGenHelper() -> bool {
       }
     }
   } else {
-    llvm::SmallString<256> output_filename = options_.output_filename;
+    llvm::SmallString<256> output_filename = options_->output_filename;
     if (output_filename.empty()) {
       if (!source_->is_regular_file()) {
         // Don't invent file names like `-.o` or `/dev/stdin.o`.
@@ -675,7 +754,7 @@ auto CompilationUnit::RunCodeGenHelper() -> bool {
       }
       output_filename = input_filename_;
       llvm::sys::path::replace_extension(output_filename,
-                                         options_.asm_output ? ".s" : ".o");
+                                         options_->asm_output ? ".s" : ".o");
     } else {
       // TODO: Handle the case where multiple input files were specified
       // along with an output file name. That should either be an error or
@@ -698,7 +777,7 @@ auto CompilationUnit::RunCodeGenHelper() -> bool {
                                 output_filename.str().str(), ec.message());
       return false;
     }
-    if (options_.asm_output) {
+    if (options_->asm_output) {
       if (!codegen->EmitAssembly(output_file)) {
         return false;
       }
@@ -732,17 +811,6 @@ auto CompilationUnit::LogCall(llvm::StringLiteral logging_label,
   CARBON_VLOG("*** {0} done ***\n", logging_label);
 }
 
-auto CompilationUnit::IncludeInDumps() const -> bool {
-  return IncludeInDumps(input_filename_);
-}
-
-auto CompilationUnit::IncludeInDumps(llvm::StringRef filename) const -> bool {
-  return options_.exclude_dump_file_prefix.empty() ||
-         !filename.starts_with(options_.exclude_dump_file_prefix);
-}
-
-}  // namespace
-
 auto CompileSubcommand::Run(DriverEnv& driver_env) -> DriverResult {
   if (!ValidateOptions(driver_env.emitter)) {
     return {.success = false};
@@ -767,18 +835,19 @@ auto CompileSubcommand::Run(DriverEnv& driver_env) -> DriverResult {
 
   // Prepare CompilationUnits before building scope exit handlers.
   llvm::SmallVector<std::unique_ptr<CompilationUnit>> units;
-  units.reserve(prelude.size() + options_.input_filenames.size());
-
-  // Add the prelude files.
-  for (const auto& input_filename : prelude) {
-    units.push_back(std::make_unique<CompilationUnit>(
-        driver_env, options_, &driver_env.consumer, input_filename));
-  }
+  int unit_index = -1;
+  auto unit_builder = [&](llvm::StringRef filename) {
+    return std::make_unique<CompilationUnit>(
+        ++unit_index, &driver_env, &options_, &driver_env.consumer, filename);
+  };
+  llvm::append_range(units, llvm::map_range(prelude, unit_builder));
+  llvm::append_range(units,
+                     llvm::map_range(options_.input_filenames, unit_builder));
 
-  // Add the input source files.
-  for (const auto& input_filename : options_.input_filenames) {
-    units.push_back(std::make_unique<CompilationUnit>(
-        driver_env, options_, &driver_env.consumer, input_filename));
+  // Add the cache to all units. This must be done after all units are created.
+  MultiUnitCache cache(&options_, units);
+  for (auto& unit : units) {
+    unit->SetMultiUnitCache(&cache);
   }
 
   auto on_exit = llvm::make_scope_exit([&]() {
@@ -847,14 +916,14 @@ auto CompileSubcommand::Run(DriverEnv& driver_env) -> DriverResult {
   check_units.reserve(units.size());
   for (auto& unit : units) {
     if (unit->has_source()) {
-      SemIR::CheckIRId check_ir_id(check_units.size());
-      check_units.push_back(unit->GetCheckUnit(check_ir_id));
+      check_units.push_back(unit->GetCheckUnit());
     }
   }
 
   // Execute the actual checking.
   CARBON_VLOG_TO(driver_env.vlog_stream, "*** Check::CheckParseTrees ***\n");
-  Check::CheckParseTrees(check_units, options_.prelude_import, driver_env.fs,
+  Check::CheckParseTrees(check_units, cache.tree_and_subtrees_getters(),
+                         options_.prelude_import, driver_env.fs,
                          driver_env.vlog_stream, driver_env.fuzzing);
   CARBON_VLOG_TO(driver_env.vlog_stream,
                  "*** Check::CheckParseTrees done ***\n");
@@ -875,23 +944,8 @@ auto CompileSubcommand::Run(DriverEnv& driver_env) -> DriverResult {
   }
 
   // Lower.
-  llvm::SmallVector<Parse::GetTreeAndSubtreesFn> tree_and_subtrees_getters;
-  std::optional<llvm::ArrayRef<Parse::GetTreeAndSubtreesFn>>
-      tree_and_subtrees_getters_for_debug_info;
-  if (options_.include_debug_info) {
-    // This size may not match due to units that are missing source, but that's
-    // an error case and not worth extra work.
-    tree_and_subtrees_getters.reserve(units.size());
-    for (auto& unit : units) {
-      if (unit->has_source()) {
-        tree_and_subtrees_getters.push_back(unit->get_trees_and_subtrees());
-      }
-    }
-    tree_and_subtrees_getters_for_debug_info = {};
-    tree_and_subtrees_getters_for_debug_info = tree_and_subtrees_getters;
-  }
   for (const auto& unit : units) {
-    unit->RunLower(tree_and_subtrees_getters_for_debug_info);
+    unit->RunLower();
   }
   if (options_.phase == CompileOptions::Phase::Lower) {
     return make_result();

+ 7 - 7
toolchain/language_server/context.cpp

@@ -143,16 +143,16 @@ auto Context::File::SetText(Context& context, std::optional<int64_t> version,
     return *tree_and_subtrees_;
   };
   // TODO: Support cross-file checking when multiple files have edits.
-  llvm::SmallVector<Check::Unit> units = {{.consumer = &consumer,
-                                           .value_stores = value_stores_.get(),
-                                           .timings = nullptr,
-                                           .tree_and_subtrees_getter = getter,
-                                           .sem_ir = &sem_ir}};
+  llvm::SmallVector<Check::Unit> units = {{{.consumer = &consumer,
+                                            .value_stores = value_stores_.get(),
+                                            .timings = nullptr,
+                                            .sem_ir = &sem_ir}}};
   llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> fs =
       new llvm::vfs::InMemoryFileSystem;
   // TODO: Include the prelude.
-  Check::CheckParseTrees(units, /*prelude_import=*/false, fs,
-                         context.vlog_stream(), /*fuzzing=*/false);
+  Check::CheckParseTrees(
+      units, llvm::ArrayRef<Parse::GetTreeAndSubtreesFn>(getter),
+      /*prelude_import=*/false, fs, context.vlog_stream(), /*fuzzing=*/false);
 
   // Note we need to publish diagnostics even when empty.
   // TODO: Consider caching previously published diagnostics and only publishing

+ 11 - 6
toolchain/sem_ir/formatter.cpp

@@ -33,12 +33,12 @@
 namespace Carbon::SemIR {
 
 Formatter::Formatter(const File* sem_ir,
-                     ShouldFormatEntityFn should_format_entity,
-                     Parse::GetTreeAndSubtreesFn get_tree_and_subtrees)
+                     Parse::GetTreeAndSubtreesFn get_tree_and_subtrees,
+                     llvm::ArrayRef<bool> include_ir_in_dumps)
     : sem_ir_(sem_ir),
       inst_namer_(sem_ir_),
-      should_format_entity_(should_format_entity),
-      get_tree_and_subtrees_(get_tree_and_subtrees) {
+      get_tree_and_subtrees_(get_tree_and_subtrees),
+      include_ir_in_dumps_(include_ir_in_dumps) {
   // Create a placeholder visible chunk and assign it to all instructions that
   // don't have a chunk of their own.
   auto first_chunk = AddChunkNoFlush(true);
@@ -170,12 +170,17 @@ auto Formatter::IncludeChunkInOutput(size_t chunk) -> void {
   }
 }
 
+auto Formatter::ShouldIncludeInstByIR(InstId inst_id) -> bool {
+  const auto* import_ir = GetCanonicalFileAndInstId(sem_ir_, inst_id).first;
+  return include_ir_in_dumps_[import_ir->check_ir_id().index];
+}
+
 auto Formatter::ShouldFormatEntity(InstId decl_id, bool is_definition_start)
     -> bool {
   if (!decl_id.has_value()) {
     return true;
   }
-  if (!should_format_entity_(decl_id)) {
+  if (!ShouldIncludeInstByIR(decl_id)) {
     return false;
   }
 
@@ -522,7 +527,7 @@ auto Formatter::FormatSpecificRegion(const Generic& generic,
 auto Formatter::FormatSpecific(SpecificId id) -> void {
   const auto& specific = sem_ir_->specifics().Get(id);
   const auto& generic = sem_ir_->generics().Get(specific.generic_id);
-  if (!should_format_entity_(generic.decl_id)) {
+  if (!ShouldIncludeInstByIR(generic.decl_id)) {
     // Omit specifics if we also omitted the generic.
     return;
   }

+ 9 - 8
toolchain/sem_ir/formatter.h

@@ -17,14 +17,9 @@ namespace Carbon::SemIR {
 // Formatter for printing textual Semantics IR.
 class Formatter {
  public:
-  // A callback that indicates whether a specific entity, identified by its
-  // declaration, should be included in the output.
-  using ShouldFormatEntityFn =
-      llvm::function_ref<auto(InstId decl_inst_id)->bool>;
-
   explicit Formatter(const File* sem_ir,
-                     ShouldFormatEntityFn should_format_entity,
-                     Parse::GetTreeAndSubtreesFn get_tree_and_subtrees);
+                     Parse::GetTreeAndSubtreesFn get_tree_and_subtrees,
+                     llvm::ArrayRef<bool> include_ir_in_dumps);
 
   // Prints the SemIR into an internal buffer.
   //
@@ -89,6 +84,10 @@ class Formatter {
   // is.
   auto IncludeChunkInOutput(size_t chunk) -> void;
 
+  // Returns true if the instruction should be included according to its
+  // originating IR. Typically `ShouldFormatEntity` should be used instead.
+  auto ShouldIncludeInstByIR(InstId inst_id) -> bool;
+
   // Determines whether the specified entity should be included in the formatted
   // output. `is_definition_start` should indicate whether, if `decl_id`'s
   // `LocId` is a `NodeId`, it is expected to be a `DefinitionStart` kind.
@@ -323,9 +322,11 @@ class Formatter {
 
   const File* sem_ir_;
   InstNamer inst_namer_;
-  ShouldFormatEntityFn should_format_entity_;
   Parse::GetTreeAndSubtreesFn get_tree_and_subtrees_;
 
+  // For each CheckIRId, whether entities from it should be formatted.
+  llvm::ArrayRef<bool> include_ir_in_dumps_;
+
   // The output stream buffer.
   std::string buffer_;