Просмотр исходного кода

Refactor check_ir_map to encapsulate it. (#3968)

This is mostly to reduce code complexity at call sites. I was
considering pulling out a wrapper type, but it seems a bit small for
that right now.
Jon Ross-Perkins 1 год назад
Родитель
Сommit
7effc1abd7
3 измененных файлов с 20 добавлено и 14 удалено
  1. 5 8
      toolchain/check/check.cpp
  2. 13 4
      toolchain/check/context.h
  3. 2 2
      toolchain/check/import_ref.cpp

+ 5 - 8
toolchain/check/check.cpp

@@ -262,8 +262,7 @@ static auto ImportCurrentPackage(Context& context, UnitInfo& unit_info,
             // The indirect IR was previously indirectly imported, but it's
             // found through `export import`. We need to mark it for re-export.
             context.import_irs()
-                .Get(context.check_ir_map()[indirect_ir.sem_ir->check_ir_id()
-                                                .index])
+                .Get(context.GetImportIRId(*indirect_ir.sem_ir))
                 .is_export = true;
           }
         }
@@ -271,15 +270,13 @@ static auto ImportCurrentPackage(Context& context, UnitInfo& unit_info,
     } else if (import.names.is_export) {
       // The IR was previously indirectly imported, but it's `export import`.
       // We need to mark it -- and transitive `export import`s -- for re-export.
-      context.import_irs()
-          .Get(context.check_ir_map()[import_sem_ir.check_ir_id().index])
-          .is_export = true;
+      context.import_irs().Get(context.GetImportIRId(import_sem_ir)).is_export =
+          true;
 
       for (const auto& indirect_ir : import_sem_ir.import_irs().array_ref()) {
         if (indirect_ir.is_export) {
           context.import_irs()
-              .Get(context
-                       .check_ir_map()[indirect_ir.sem_ir->check_ir_id().index])
+              .Get(context.GetImportIRId(*indirect_ir.sem_ir))
               .is_export = true;
         }
       }
@@ -308,7 +305,7 @@ static auto InitPackageScopeAndImports(Context& context, UnitInfo& unit_info,
   context.import_irs().Reserve(num_irs);
   context.import_ir_constant_values().reserve(num_irs);
 
-  context.check_ir_map().resize(total_ir_count, SemIR::ImportIRId::Invalid);
+  context.SetTotalIRCount(total_ir_count);
 
   // Importing makes many namespaces, so only canonicalize the type once.
   auto namespace_type_id =

+ 13 - 4
toolchain/check/context.h

@@ -275,6 +275,19 @@ class Context {
   // Finalizes the initialization function (__global_init).
   auto FinalizeGlobalInit() -> void;
 
+  // Sets the total number of IRs which exist. This is used to prepare a map
+  // from IR to imported IR.
+  auto SetTotalIRCount(int num_irs) -> void {
+    CARBON_CHECK(check_ir_map_.empty())
+        << "SetTotalIRCount is only called once";
+    check_ir_map_.resize(num_irs, SemIR::ImportIRId::Invalid);
+  }
+
+  // Returns the imported IR ID for an IR, or invalid if not imported.
+  auto GetImportIRId(const SemIR::File& sem_ir) -> SemIR::ImportIRId& {
+    return check_ir_map_[sem_ir.check_ir_id().index];
+  }
+
   // Prints information for a stack dump.
   auto PrintForStackDump(llvm::raw_ostream& output) const -> void;
 
@@ -318,10 +331,6 @@ class Context {
     return scope_stack().break_continue_stack();
   }
 
-  auto check_ir_map() -> llvm::SmallVector<SemIR::ImportIRId>& {
-    return check_ir_map_;
-  }
-
   auto import_ir_constant_values()
       -> llvm::SmallVector<SemIR::ConstantValueStore, 0>& {
     return import_ir_constant_values_;

+ 2 - 2
toolchain/check/import_ref.cpp

@@ -40,7 +40,7 @@ auto SetApiImportIR(Context& context, SemIR::ImportIR import_ir) -> void {
 
 auto AddImportIR(Context& context, SemIR::ImportIR import_ir)
     -> SemIR::ImportIRId {
-  auto& ir_id = context.check_ir_map()[import_ir.sem_ir->check_ir_id().index];
+  auto& ir_id = context.GetImportIRId(*import_ir.sem_ir);
   if (!ir_id.is_valid()) {
     // Note this updates check_ir_map.
     ir_id = InternalAddImportIR(context, import_ir);
@@ -249,7 +249,7 @@ class ImportRefResolver {
       auto prev_inst_id = cursor_inst_id;
 
       cursor_ir = cursor_ir->import_irs().Get(ir_inst.ir_id).sem_ir;
-      cursor_ir_id = context_.check_ir_map()[cursor_ir->check_ir_id().index];
+      cursor_ir_id = context_.GetImportIRId(*cursor_ir);
       if (!cursor_ir_id.is_valid()) {
         // TODO: Should we figure out a location to assign here?
         cursor_ir_id = AddImportIR(context_, {.node_id = Parse::NodeId::Invalid,