ソースを参照

Perform an extra pass to import a generic for a symbolic constant less often. (#4182)

Instead of always forcing an extra pass when we need to import a generic
ID for a generic that isn't already imported, attempt to import the
rest of the instruction in the same pass. There are then three
possibilities:

- The instruction needs a retry anyway to form its constant value, and
  we avoid an extra pass.
- The instruction produces its constant value on the first pass but
  still needs a retry. In this case, the handler for that instruction
  is expected to retry itself, before building its constant value. The
  third pass in this case can't be avoided.
- The instruction succeeds on its first pass. We still need an extra
  pass; track the constant produced by resolution separately.

---------

Co-authored-by: Jon Ross-Perkins <jperkins@google.com>
Richard Smith 1 年間 前
コミット
1705347375
2 ファイル変更234 行追加208 行削除
  1. 226 205
      toolchain/check/import_ref.cpp
  2. 8 3
      toolchain/check/sem_ir_diagnostic_converter.cpp

+ 226 - 205
toolchain/check/import_ref.cpp

@@ -125,44 +125,61 @@ auto VerifySameCanonicalImportIRInst(Context& context, SemIR::InstId prev_id,
 // Calling Resolve on an instruction operates in an iterative manner, tracking
 // Work items on work_stack_. At a high level, the loop is:
 //
-// 1. If a constant value is already known for the work item, and we're
-//    processing it for the first time, it's considered resolved.
+// 1. If a constant value is already known for the work item and was not set by
+//    this work item, it's considered resolved.
 //    - The constant check avoids performance costs of deduplication on add.
-//    - If `retry` is set, we process it again, because it didn't complete last
-//      time, even though we have a constant value already.
-// 2. Resolve the instruction: (TryResolveInst/TryResolveTypedInst)
-//    - For a symbolic constant within a generic, find the generic itself. If it
-//      needs to be imported, return Retry() to import the generic before we
-//      import the constant.
-//    - For instructions that can be forward declared, if we don't already have
-//      a constant value from a previous attempt at resolution, start by making
-//      a forward declared constant value to address circular references.
-//    - Gather all input constants.
-//      - Gathering constants directly adds unresolved values to work_stack_.
-//    - If any need to be resolved (HasNewWork), return Retry(): this
-//      instruction needs another call to complete.
-//      - If the constant value is already known because we have made a forward
-//        declaration, pass it to Retry(). It will be passed to future attempts
-//        to resolve this instruction so the earlier work can be found, and will
-//        be made available for other instructions to use.
-//      - The subsequent attempt to resolve this instruction must produce the
-//        same constant, because the value may have already been used by
-//        resolved instructions.
-//    - Build any necessary IR structures, and return the output constant.
-// 3. If resolve didn't return Retry(), pop the work. Otherwise, it needs to
-//    remain, and may no longer be at the top of the stack; set `retry` on it so
-//    we'll make sure to run it again later.
+//    - If we've processed this work item before, then we now process it again.
+//      It didn't complete last time, even though we have a constant value
+//      already.
+//
+// 2. Resolve the instruction (TryResolveInst/TryResolveTypedInst). This is done
+//    in three phases. The first and second phases can add work to the worklist
+//    and end in a retry, in which case those phases will be rerun once the
+//    added work is done. The rerun cannot also end in a retry, so this results
+//    in at most three calls, but in practice one or two calls is almost always
+//    sufficient. Due to the chance of a second or third call to TryResolveInst,
+//    it's important to only perform expensive work once, even when the same
+//    phase is rerun.
 //
-// TryResolveInst can complete in one call for a given instruction, but should
-// always complete within three calls:
+//    - First phase:
+//      - Gather all input constants necessary to form the constant value of the
+//        instruction. Gathering constants directly adds unresolved values to
+//        work_stack_.
+//      - If HasNewWork() reports that any work was added, then return Retry():
+//        this instruction needs another call to complete. Gather the
+//        now-resolved constants and continue to the next step once the retry
+//        happens.
 //
-// - TryResolveInst can retry once if the generic is not yet loaded.
-// - TryResolveTypedInst can retry once if its inputs are not yet loaded.
-// - The third call should succeed.
+//    - Second phase:
+//      - Build the constant value of the instruction.
+//      - Gather all input constants necessary to finish importing the
+//        instruction. This is only necessary for instructions like classes that
+//        can be forward-declared. For these instructions, we first import the
+//        constant value and then later import the rest of the declaration in
+//        order to break cycles.
+//      - If HasNewWork() reports that any work was added, then return
+//        Retry(constant_value): this instruction needs another call to
+//        complete.  Gather the now-resolved constants and continue to the next
+//        step once the retry happens.
 //
-// Due to the chance of a second call to TryResolveTypedInst, it's important to
-// reserve all expensive logic until it's been established that input constants
-// are available.
+//    - Third phase:
+//      - After the second phase, the constant value for the instruction is
+//        already set, and will be passed back into TryResolve*Inst on retry. It
+//        should not be created again.
+//      - Fill in any remaining information to complete the import of the
+//        instruction. For example, when importing a class declaration, build
+//        the class scope and information about the definition.
+//      - Return ResolveAs/ResolveAsConstant to finish the resolution process.
+//        This will cause the Resolve loop to set a constant value if we didn't
+//        retry at the end of the second phase.
+//
+// 3. If resolve didn't return Retry(), pop the work. Otherwise, it needs to
+//    remain, and may no longer be at the top of the stack; update the state of
+//    the work item to track what work still needs to be done.
+//
+// The same instruction can be enqueued for resolution multiple times. However,
+// we will only reach the second phase once: once a constant value is set, only
+// the resolution step that set it will retry.
 //
 // TODO: Fix class `extern` handling and merging, rewrite tests.
 // - check/testdata/class/cross_package_import.carbon
@@ -188,29 +205,29 @@ class ImportRefResolver {
 
       // Step 1: check for a constant value.
       auto existing = FindResolvedConstId(work.inst_id);
-      if (existing.const_id.is_valid() && !work.retry) {
+      if (existing.const_id.is_valid() && !work.retry_with_constant_value) {
         work_stack_.pop_back();
         continue;
       }
 
       // Step 2: resolve the instruction.
-      auto initial_work = work_stack_.size();
-      auto [new_const_id, finished] =
+      initial_work_ = work_stack_.size();
+      auto [new_const_id, retry] =
           TryResolveInst(work.inst_id, existing.const_id);
-      CARBON_CHECK(finished == !HasNewWork(initial_work));
 
       CARBON_CHECK(!existing.const_id.is_valid() ||
                    existing.const_id == new_const_id)
-          << "Constant value changed in second pass.";
+          << "Constant value changed in third phase.";
       if (!existing.const_id.is_valid()) {
         SetResolvedConstId(work.inst_id, existing.indirect_insts, new_const_id);
       }
 
       // Step 3: pop or retry.
-      if (finished) {
-        work_stack_.pop_back();
+      if (retry) {
+        work_stack_[initial_work_ - 1].retry_with_constant_value =
+            new_const_id.is_valid();
       } else {
-        work_stack_[initial_work - 1].retry = true;
+        work_stack_.pop_back();
       }
     }
     auto constant_id = import_ir_constant_values().Get(inst_id);
@@ -268,32 +285,21 @@ class ImportRefResolver {
   }
 
  private:
-  // A step in work_stack_.
-  struct Work {
-    // The instruction to work on.
-    SemIR::InstId inst_id;
-
-    // True if another pass was requested last time this was run.
-    bool retry = false;
-  };
-
   // The result of attempting to resolve an imported instruction to a constant.
   struct ResolveResult {
-    // Try resolving this function again. If `const_id` is specified, it will be
-    // passed to the next resolution attempt.
-    static auto Retry(SemIR::ConstantId const_id = SemIR::ConstantId::Invalid)
-        -> ResolveResult {
-      return {.const_id = const_id, .finished = false};
-    }
-
     // The new constant value, if known.
     SemIR::ConstantId const_id;
-    // Whether resolution has finished. If false, `TryResolveInst` will be
-    // called again. Note that this is not strictly necessary, and we can get
-    // the same information by checking whether new work was added to the stack.
-    // However, we use this for consistency checks between resolve actions and
-    // the work stack.
-    bool finished = true;
+    // Whether resolution has been attempted once and needs to be retried.
+    bool retry = false;
+  };
+
+  // A step in work_stack_.
+  struct Work {
+    // The instruction to work on.
+    SemIR::InstId inst_id;
+    // Whether this work item set the constant value for the instruction and
+    // requested a retry.
+    bool retry_with_constant_value = false;
   };
 
   // The constant found by FindResolvedConstId.
@@ -390,16 +396,12 @@ class ImportRefResolver {
     }
   }
 
-  // Returns true if new unresolved constants were found.
-  //
-  // At the start of a function, do:
-  //   auto initial_work = work_stack_.size();
-  // Then when determining:
-  //   if (HasNewWork(initial_work)) { ... }
-  auto HasNewWork(size_t initial_work) -> bool {
-    CARBON_CHECK(initial_work <= work_stack_.size())
+  // Returns true if new unresolved constants were found as part of this
+  // `Resolve` step.
+  auto HasNewWork() -> bool {
+    CARBON_CHECK(initial_work_ <= work_stack_.size())
         << "Work shouldn't decrease";
-    return initial_work < work_stack_.size();
+    return initial_work_ < work_stack_.size();
   }
 
   auto AddImportIRInst(SemIR::InstId inst_id) -> SemIR::ImportIRInstId {
@@ -478,7 +480,7 @@ class ImportRefResolver {
   }
 
   // Gets an incomplete local version of an imported generic. Most fields are
-  // set in the second pass.
+  // set in the third phase.
   auto MakeIncompleteGeneric(SemIR::InstId decl_id, SemIR::GenericId generic_id)
       -> SemIR::GenericId {
     if (!generic_id.is_valid()) {
@@ -819,7 +821,7 @@ class ImportRefResolver {
 
   // Given an imported entity base, returns an incomplete, local version of it.
   //
-  // Most fields are set in the second pass once they're imported. Import enough
+  // Most fields are set in the third phase once they're imported. Import enough
   // of the parameter lists that we know whether this interface is a generic
   // interface and can build the right constant value for it.
   //
@@ -888,29 +890,47 @@ class ImportRefResolver {
   auto TryResolveInst(SemIR::InstId inst_id, SemIR::ConstantId const_id)
       -> ResolveResult {
     auto inst_const_id = import_ir_.constant_values().Get(inst_id);
-    if (!inst_const_id.is_valid() || !inst_const_id.is_symbolic() ||
-        const_id.is_valid()) {
+    if (!inst_const_id.is_valid() || !inst_const_id.is_symbolic()) {
       return TryResolveInstCanonical(inst_id, const_id);
     }
 
-    // Try to import the generic, and retry if it's not ready yet. Note that if
-    // this retries, we can require three passes to import an instruction.
+    // Try to import the generic. This might add new work.
     const auto& symbolic_const =
         import_ir_.constant_values().GetSymbolicConstant(inst_const_id);
-    auto initial_work = work_stack_.size();
     auto generic_const_id = GetLocalConstantId(symbolic_const.generic_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+
+    auto inner_const_id = SemIR::ConstantId::Invalid;
+    if (const_id.is_valid()) {
+      // For the third phase, extract the constant value that
+      // TryResolveInstCanonical produced previously.
+      inner_const_id = context_.constant_values().Get(
+          context_.constant_values().GetSymbolicConstant(const_id).inst_id);
     }
 
     // Import the constant and rebuild the symbolic constant data.
-    auto result = TryResolveInstCanonical(inst_id, const_id);
-    if (result.const_id.is_valid()) {
-      result.const_id = context_.constant_values().AddSymbolicConstant(
-          {.inst_id = context_.constant_values().GetInstId(result.const_id),
-           .generic_id = GetLocalGenericId(generic_const_id),
-           .index = symbolic_const.index});
+    auto result = TryResolveInstCanonical(inst_id, inner_const_id);
+    if (!result.const_id.is_valid()) {
+      // First phase: TryResolveInstCanoncial needs a retry.
+      return result;
     }
+
+    if (!const_id.is_valid()) {
+      // Second phase: we have created an abstract constant. Create a
+      // corresponding generic constant.
+      if (symbolic_const.generic_id.is_valid()) {
+        result.const_id = context_.constant_values().AddSymbolicConstant(
+            {.inst_id = context_.constant_values().GetInstId(result.const_id),
+             .generic_id = GetLocalGenericId(generic_const_id),
+             .index = symbolic_const.index});
+      }
+    } else {
+      // Third phase: perform a consistency check and produce the constant we
+      // created in the second phase.
+      CARBON_CHECK(result.const_id == inner_const_id)
+          << "Constant value changed in third phase.";
+      result.const_id = const_id;
+    }
+
     return result;
   }
 
@@ -923,7 +943,7 @@ class ImportRefResolver {
     if (inst_id.is_builtin()) {
       CARBON_CHECK(!const_id.is_valid());
       // Constants for builtins can be directly copied.
-      return {.const_id = context_.constant_values().Get(inst_id)};
+      return ResolveAsConstant(context_.constant_values().Get(inst_id));
     }
 
     auto untyped_inst = import_ir_.insts().Get(inst_id);
@@ -940,9 +960,9 @@ class ImportRefResolver {
       case CARBON_KIND(SemIR::BindAlias inst): {
         return TryResolveTypedInst(inst);
       }
-      case CARBON_KIND(SemIR::BindName inst): {
-        // TODO: This always returns `ConstantId::NotConstant`.
-        return {.const_id = TryEvalInst(context_, inst_id, inst)};
+      case SemIR::BindName::Kind: {
+        // TODO: Should we be resolving BindNames at all?
+        return ResolveAsConstant(SemIR::ConstantId::NotConstant);
       }
       case CARBON_KIND(SemIR::BindSymbolicName inst): {
         return TryResolveTypedInst(inst);
@@ -989,10 +1009,6 @@ class ImportRefResolver {
       case CARBON_KIND(SemIR::IntLiteral inst): {
         return TryResolveTypedInst(inst);
       }
-      case CARBON_KIND(SemIR::Namespace inst): {
-        CARBON_FATAL() << "Namespaces shouldn't need resolution this way: "
-                       << inst;
-      }
       case CARBON_KIND(SemIR::PointerType inst): {
         return TryResolveTypedInst(inst);
       }
@@ -1019,12 +1035,39 @@ class ImportRefResolver {
     }
   }
 
-  // Produce a resolve result for the given instruction that describes a
+  // Produces a resolve result that tries resolving this instruction again. If
+  // `const_id` is specified, then this is the end of the second phase, and the
+  // constant value will be passed to the next resolution attempt. Otherwise,
+  // this is the end of the first phase.
+  auto Retry(SemIR::ConstantId const_id = SemIR::ConstantId::Invalid)
+      -> ResolveResult {
+    CARBON_CHECK(HasNewWork());
+    return {.const_id = const_id, .retry = true};
+  }
+
+  // Produces a resolve result that provides the given constant value. Requires
+  // that there is no new work.
+  auto ResolveAsConstant(SemIR::ConstantId const_id) -> ResolveResult {
+    CARBON_CHECK(!HasNewWork());
+    return {.const_id = const_id};
+  }
+
+  // Produces a resolve result that provides the given constant value. Retries
+  // instead if work has been added.
+  auto RetryOrResolveAsConstant(SemIR::ConstantId const_id) -> ResolveResult {
+    if (HasNewWork()) {
+      return Retry();
+    }
+    return ResolveAsConstant(const_id);
+  }
+
+  // Produces a resolve result for the given instruction that describes a
   // constant value. This should only be used for instructions that describe
   // constants, and not for instructions that represent declarations. For a
   // declaration, we need an associated location, so AddInstInNoBlock should be
-  // used instead.
+  // used instead. Requires that there is no new work.
   auto ResolveAsUntyped(SemIR::Inst inst) -> ResolveResult {
+    CARBON_CHECK(!HasNewWork());
     auto result = TryEvalInst(context_, SemIR::InstId::Invalid, inst);
     CARBON_CHECK(result.is_constant()) << inst << " is not constant";
     return {.const_id = result};
@@ -1037,10 +1080,9 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::AssociatedEntity inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto type_const_id = GetLocalConstantId(inst.type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     // Add a lazy reference to the target declaration.
@@ -1057,11 +1099,10 @@ class ImportRefResolver {
   auto TryResolveTypedInst(SemIR::AssociatedEntityType inst) -> ResolveResult {
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
 
-    auto initial_work = work_stack_.size();
     auto entity_type_const_id = GetLocalConstantId(inst.entity_type_id);
     auto interface_inst_id = GetLocalConstantId(inst.interface_type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     return ResolveAs<SemIR::AssociatedEntityType>(
@@ -1074,11 +1115,10 @@ class ImportRefResolver {
 
   auto TryResolveTypedInst(SemIR::BaseDecl inst, SemIR::InstId import_inst_id)
       -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto type_const_id = GetLocalConstantId(inst.type_id);
     auto base_type_const_id = GetLocalConstantId(inst.base_type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     // Import the instruction in order to update contained base_type_id and
@@ -1088,23 +1128,18 @@ class ImportRefResolver {
         {.type_id = context_.GetTypeIdForTypeConstant(type_const_id),
          .base_type_id = context_.GetTypeIdForTypeConstant(base_type_const_id),
          .index = inst.index});
-    return {.const_id = context_.constant_values().Get(inst_id)};
+    return ResolveAsConstant(context_.constant_values().Get(inst_id));
   }
 
   auto TryResolveTypedInst(SemIR::BindAlias inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto value_id = GetLocalConstantId(inst.value_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
-    }
-    return {.const_id = value_id};
+    return RetryOrResolveAsConstant(value_id);
   }
 
   auto TryResolveTypedInst(SemIR::BindSymbolicName inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto type_id = GetLocalConstantId(inst.type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     const auto& import_entity_name =
@@ -1192,12 +1227,17 @@ class ImportRefResolver {
 
     SemIR::ClassId class_id = SemIR::ClassId::Invalid;
     if (!class_const_id.is_valid()) {
-      // On the first pass, create a forward declaration of the class for any
+      if (HasNewWork()) {
+        // This is the end of the first phase. Don't make a new class yet if we
+        // already have new work.
+        return Retry();
+      }
+      // On the second phase, create a forward declaration of the class for any
       // recursive references.
       std::tie(class_id, class_const_id) = MakeIncompleteClass(import_class);
     } else {
-      // On the second pass, compute the class ID from the constant value of the
-      // declaration.
+      // On the third phase, compute the class ID from the constant
+      // value of the declaration.
       auto class_const_inst = context_.insts().Get(
           context_.constant_values().GetInstId(class_const_id));
       if (auto class_type = class_const_inst.TryAs<SemIR::ClassType>()) {
@@ -1211,8 +1251,6 @@ class ImportRefResolver {
     }
 
     // Load constants for the definition.
-    auto initial_work = work_stack_.size();
-
     auto parent_scope_id = GetLocalNameScopeId(import_class.parent_scope_id);
     llvm::SmallVector<SemIR::ConstantId> implicit_param_const_ids =
         GetLocalParamConstantIds(import_class.implicit_param_refs_id);
@@ -1228,8 +1266,8 @@ class ImportRefResolver {
                        ? GetLocalConstantInstId(import_class.base_id)
                        : SemIR::InstId::Invalid;
 
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry(class_const_id);
+    if (HasNewWork()) {
+      return Retry(class_const_id);
     }
 
     auto& new_class = context_.classes().Get(class_id);
@@ -1246,17 +1284,16 @@ class ImportRefResolver {
                          base_id);
     }
 
-    return {.const_id = class_const_id};
+    return ResolveAsConstant(class_const_id);
   }
 
   auto TryResolveTypedInst(SemIR::ClassType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto class_const_id =
         GetLocalConstantId(import_ir_.classes().Get(inst.class_id).decl_id);
     auto specific_data = GetLocalSpecificData(inst.specific_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     // Find the corresponding class type. For a non-generic class, this is the
@@ -1265,7 +1302,7 @@ class ImportRefResolver {
     auto class_const_inst = context_.insts().Get(
         context_.constant_values().GetInstId(class_const_id));
     if (class_const_inst.Is<SemIR::ClassType>()) {
-      return {.const_id = class_const_id};
+      return ResolveAsConstant(class_const_id);
     } else {
       auto generic_class_type = context_.types().GetAs<SemIR::GenericClassType>(
           class_const_inst.type_id());
@@ -1278,11 +1315,10 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::ConstType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto inner_const_id = GetLocalConstantId(inst.inner_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
     auto inner_type_id = context_.GetTypeIdForTypeConstant(inner_const_id);
     return ResolveAs<SemIR::ConstType>(
@@ -1290,20 +1326,15 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::ExportDecl inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto value_id = GetLocalConstantId(inst.value_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
-    }
-    return {.const_id = value_id};
+    return RetryOrResolveAsConstant(value_id);
   }
 
   auto TryResolveTypedInst(SemIR::FieldDecl inst, SemIR::InstId import_inst_id)
       -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto const_id = GetLocalConstantId(inst.type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
     auto inst_id = context_.AddInstInNoBlock<SemIR::FieldDecl>(
         AddImportIRInst(import_inst_id),
@@ -1350,11 +1381,16 @@ class ImportRefResolver {
 
     SemIR::FunctionId function_id = SemIR::FunctionId::Invalid;
     if (!function_const_id.is_valid()) {
-      // On the first pass, create a forward declaration of the interface.
+      if (HasNewWork()) {
+        // This is the end of the first phase. Don't make a new function yet if
+        // we already have new work.
+        return Retry();
+      }
+      // On the second phase, create a forward declaration of the interface.
       std::tie(function_id, function_const_id) =
           MakeFunctionDecl(import_function);
     } else {
-      // On the second pass, compute the function ID from the constant value of
+      // On the third phase, compute the function ID from the constant value of
       // the declaration.
       auto function_const_inst = context_.insts().Get(
           context_.constant_values().GetInstId(function_const_id));
@@ -1363,8 +1399,6 @@ class ImportRefResolver {
       function_id = function_type.function_id;
     }
 
-    auto initial_work = work_stack_.size();
-
     auto return_type_const_id = SemIR::ConstantId::Invalid;
     if (import_function.return_storage_id.is_valid()) {
       return_type_const_id =
@@ -1377,8 +1411,8 @@ class ImportRefResolver {
         GetLocalParamConstantIds(import_function.param_refs_id);
     auto generic_data = GetLocalGenericData(import_function.generic_id);
 
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry(function_const_id);
+    if (HasNewWork()) {
+      return Retry(function_const_id);
     }
 
     // Add the function declaration.
@@ -1407,17 +1441,16 @@ class ImportRefResolver {
       new_function.definition_id = new_function.decl_id;
     }
 
-    return {.const_id = function_const_id};
+    return ResolveAsConstant(function_const_id);
   }
 
   auto TryResolveTypedInst(SemIR::FunctionType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto fn_val_id = GetLocalConstantInstId(
         import_ir_.functions().Get(inst.function_id).decl_id);
     auto specific_data = GetLocalSpecificData(inst.specific_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
     auto fn_type_id = context_.insts().Get(fn_val_id).type_id();
     return ResolveAs<SemIR::FunctionType>(
@@ -1430,55 +1463,49 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::GenericClassType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto class_val_id =
         GetLocalConstantInstId(import_ir_.classes().Get(inst.class_id).decl_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
     auto class_val = context_.insts().Get(class_val_id);
     CARBON_CHECK(
         context_.types().Is<SemIR::GenericClassType>(class_val.type_id()));
-    return {.const_id = context_.types().GetConstantId(class_val.type_id())};
+    return ResolveAsConstant(
+        context_.types().GetConstantId(class_val.type_id()));
   }
 
   auto TryResolveTypedInst(SemIR::GenericInterfaceType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto interface_val_id = GetLocalConstantInstId(
         import_ir_.interfaces().Get(inst.interface_id).decl_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
     auto interface_val = context_.insts().Get(interface_val_id);
     CARBON_CHECK(context_.types().Is<SemIR::GenericInterfaceType>(
         interface_val.type_id()));
-    return {.const_id =
-                context_.types().GetConstantId(interface_val.type_id())};
+    return ResolveAsConstant(
+        context_.types().GetConstantId(interface_val.type_id()));
   }
 
   auto TryResolveTypedInst(SemIR::ImportRefLoaded /*inst*/,
                            SemIR::InstId inst_id) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     // Return the constant for the instruction of the imported constant.
     auto constant_id = import_ir_.constant_values().Get(inst_id);
     if (!constant_id.is_valid()) {
-      return {.const_id = SemIR::ConstantId::Error};
+      return ResolveAsConstant(SemIR::ConstantId::Error);
     }
     if (!constant_id.is_constant()) {
       context_.TODO(inst_id,
                     "Non-constant ImportRefLoaded (comes up with var)");
-      return {.const_id = SemIR::ConstantId::Error};
+      return ResolveAsConstant(SemIR::ConstantId::Error);
     }
 
     auto new_constant_id =
         GetLocalConstantId(import_ir_.constant_values().GetInstId(constant_id));
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
-    }
-
-    return {.const_id = new_constant_id};
+    return RetryOrResolveAsConstant(new_constant_id);
   }
 
   // Make a declaration of an interface. This is done as a separate step from
@@ -1541,11 +1568,16 @@ class ImportRefResolver {
 
     SemIR::InterfaceId interface_id = SemIR::InterfaceId::Invalid;
     if (!interface_const_id.is_valid()) {
-      // On the first pass, create a forward declaration of the interface.
+      if (HasNewWork()) {
+        // This is the end of the first phase. Don't make a new interface yet if
+        // we already have new work.
+        return Retry();
+      }
+      // On the second phase, create a forward declaration of the interface.
       std::tie(interface_id, interface_const_id) =
           MakeInterfaceDecl(import_interface);
     } else {
-      // On the second pass, compute the interface ID from the constant value of
+      // On the third phase, compute the interface ID from the constant value of
       // the declaration.
       auto interface_const_inst = context_.insts().Get(
           context_.constant_values().GetInstId(interface_const_id));
@@ -1560,8 +1592,6 @@ class ImportRefResolver {
       }
     }
 
-    auto initial_work = work_stack_.size();
-
     auto parent_scope_id =
         GetLocalNameScopeId(import_interface.parent_scope_id);
     llvm::SmallVector<SemIR::ConstantId> implicit_param_const_ids =
@@ -1575,8 +1605,8 @@ class ImportRefResolver {
       self_param_id = GetLocalConstantInstId(import_interface.self_param_id);
     }
 
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry(interface_const_id);
+    if (HasNewWork()) {
+      return Retry(interface_const_id);
     }
 
     auto& new_interface = context_.interfaces().Get(interface_id);
@@ -1592,17 +1622,16 @@ class ImportRefResolver {
       CARBON_CHECK(self_param_id);
       AddInterfaceDefinition(import_interface, new_interface, *self_param_id);
     }
-    return {.const_id = interface_const_id};
+    return ResolveAsConstant(interface_const_id);
   }
 
   auto TryResolveTypedInst(SemIR::InterfaceType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto interface_const_id = GetLocalConstantId(
         import_ir_.interfaces().Get(inst.interface_id).decl_id);
     auto specific_data = GetLocalSpecificData(inst.specific_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     // Find the corresponding interface type. For a non-generic interface, this
@@ -1612,7 +1641,7 @@ class ImportRefResolver {
     auto interface_const_inst = context_.insts().Get(
         context_.constant_values().GetInstId(interface_const_id));
     if (interface_const_inst.Is<SemIR::InterfaceType>()) {
-      return {.const_id = interface_const_id};
+      return ResolveAsConstant(interface_const_id);
     } else {
       auto generic_interface_type =
           context_.types().GetAs<SemIR::GenericInterfaceType>(
@@ -1626,10 +1655,9 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::InterfaceWitness inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto elements = GetLocalInstBlockContents(inst.elements_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     auto elements_id = GetLocalCanonicalInstBlockId(inst.elements_id, elements);
@@ -1640,10 +1668,9 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::IntLiteral inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto type_id = GetLocalConstantId(inst.type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     return ResolveAs<SemIR::IntLiteral>(
@@ -1652,11 +1679,10 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::PointerType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto pointee_const_id = GetLocalConstantId(inst.pointee_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     auto pointee_type_id = context_.GetTypeIdForTypeConstant(pointee_const_id);
@@ -1666,8 +1692,6 @@ class ImportRefResolver {
 
   auto TryResolveTypedInst(SemIR::StructType inst, SemIR::InstId import_inst_id)
       -> ResolveResult {
-    // Collect all constants first, locating unresolved ones in a single pass.
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto orig_fields = import_ir_.inst_blocks().Get(inst.fields_id);
     llvm::SmallVector<SemIR::ConstantId> field_const_ids;
@@ -1676,8 +1700,8 @@ class ImportRefResolver {
       auto field = import_ir_.insts().GetAs<SemIR::StructTypeField>(field_id);
       field_const_ids.push_back(GetLocalConstantId(field.field_type_id));
     }
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     // Prepare a vector of fields for GetStructType.
@@ -1701,11 +1725,10 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::StructValue inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto type_id = GetLocalConstantId(inst.type_id);
     auto elems = GetLocalInstBlockContents(inst.elements_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     return ResolveAs<SemIR::StructValue>(
@@ -1716,16 +1739,14 @@ class ImportRefResolver {
   auto TryResolveTypedInst(SemIR::TupleType inst) -> ResolveResult {
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
 
-    // Collect all constants first, locating unresolved ones in a single pass.
-    auto initial_work = work_stack_.size();
     auto orig_elem_type_ids = import_ir_.type_blocks().Get(inst.elements_id);
     llvm::SmallVector<SemIR::ConstantId> elem_const_ids;
     elem_const_ids.reserve(orig_elem_type_ids.size());
     for (auto elem_type_id : orig_elem_type_ids) {
       elem_const_ids.push_back(GetLocalConstantId(elem_type_id));
     }
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     // Prepare a vector of the tuple types for GetTupleType.
@@ -1735,16 +1756,15 @@ class ImportRefResolver {
       elem_type_ids.push_back(context_.GetTypeIdForTypeConstant(elem_const_id));
     }
 
-    return {.const_id = context_.types().GetConstantId(
-                context_.GetTupleType(elem_type_ids))};
+    return ResolveAsConstant(
+        context_.types().GetConstantId(context_.GetTupleType(elem_type_ids)));
   }
 
   auto TryResolveTypedInst(SemIR::TupleValue inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     auto type_id = GetLocalConstantId(inst.type_id);
     auto elems = GetLocalInstBlockContents(inst.elements_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     return ResolveAs<SemIR::TupleValue>(
@@ -1753,12 +1773,11 @@ class ImportRefResolver {
   }
 
   auto TryResolveTypedInst(SemIR::UnboundElementType inst) -> ResolveResult {
-    auto initial_work = work_stack_.size();
     CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
     auto class_const_id = GetLocalConstantId(inst.class_type_id);
     auto elem_const_id = GetLocalConstantId(inst.element_type_id);
-    if (HasNewWork(initial_work)) {
-      return ResolveResult::Retry();
+    if (HasNewWork()) {
+      return Retry();
     }
 
     return ResolveAs<SemIR::UnboundElementType>(
@@ -1775,6 +1794,8 @@ class ImportRefResolver {
   SemIR::ImportIRId import_ir_id_;
   const SemIR::File& import_ir_;
   llvm::SmallVector<Work> work_stack_;
+  // The size of work_stack_ at the start of resolving the current instruction.
+  size_t initial_work_ = 0;
 };
 
 // Returns a list of ImportIRInsts equivalent to the ImportRef currently being

+ 8 - 3
toolchain/check/sem_ir_diagnostic_converter.cpp

@@ -28,7 +28,7 @@ auto SemIRDiagnosticConverter::ConvertLoc(SemIRLoc loc,
       // For imports in the current file, the location is simple.
       in_import_loc = ConvertLocInFile(cursor_ir, import_loc_id.node_id(),
                                        loc.token_only, context_fn);
-    } else {
+    } else if (import_loc_id.is_import_ir_inst_id()) {
       // For implicit imports, we need to unravel the location a little
       // further.
       auto implicit_import_ir_inst =
@@ -43,8 +43,13 @@ auto SemIRDiagnosticConverter::ConvertLoc(SemIRLoc loc,
           ConvertLocInFile(implicit_ir.sem_ir, implicit_loc_id.node_id(),
                            loc.token_only, context_fn);
     }
-    CARBON_DIAGNOSTIC(InImport, Note, "In import.");
-    context_fn(in_import_loc, InImport);
+
+    // TODO: Add an "In implicit import of prelude." note for the case where we
+    // don't have a location.
+    if (import_loc_id.is_valid()) {
+      CARBON_DIAGNOSTIC(InImport, Note, "In import.");
+      context_fn(in_import_loc, InImport);
+    }
 
     cursor_ir = import_ir.sem_ir;
     cursor_inst_id = import_ir_inst.inst_id;