Răsfoiți Sursa

Remove parameter-constant arrays from import_ref (#4360)

This makes the code more resilient to changes in the structure of
parameter insts, and could help avoid bugs by making the
ImportRefResolver's data structures the single source of truth.
Geoff Romer 1 an în urmă
părinte
comite
e617d64939
1 a modificat fișierele cu 50 adăugiri și 55 ștergeri
  1. 50 55
      toolchain/check/import_ref.cpp

+ 50 - 55
toolchain/check/import_ref.cpp

@@ -328,12 +328,6 @@ class ImportRefResolver {
     llvm::SmallVector<SemIR::ImportIRInst> indirect_insts = {};
   };
 
-  // Local information associated with an imported parameter.
-  struct ParamData {
-    SemIR::ConstantId type_const_id;
-    SemIR::ConstantId bind_const_id;
-  };
-
   // Local information associated with an imported generic.
   struct GenericData {
     llvm::SmallVector<SemIR::InstId> bindings;
@@ -469,6 +463,13 @@ class ImportRefResolver {
     return GetLocalConstantId(import_ir_.types().GetConstantId(type_id));
   }
 
+  template <typename Id>
+  auto GetLocalConstantIdChecked(Id id) {
+    auto result = GetLocalConstantId(id);
+    CARBON_CHECK(result.is_valid());
+    return result;
+  }
+
   // Gets the local constant values corresponding to an imported inst block.
   auto GetLocalInstBlockContents(SemIR::InstBlockId import_block_id)
       -> llvm::SmallVector<SemIR::InstId> {
@@ -643,21 +644,16 @@ class ImportRefResolver {
     return specific_id;
   }
 
-  // Returns the ConstantId for each parameter's type. Adds unresolved constants
-  // to work_stack_.
-  auto GetLocalParamConstantIds(SemIR::InstBlockId param_refs_id)
-      -> llvm::SmallVector<ParamData> {
-    llvm::SmallVector<ParamData> param_data;
+  // Adds unresolved constants for each parameter's type to work_stack_.
+  auto LoadLocalParamConstantIds(SemIR::InstBlockId param_refs_id) -> void {
     if (!param_refs_id.is_valid() ||
         param_refs_id == SemIR::InstBlockId::Empty) {
-      return param_data;
+      return;
     }
 
     const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id);
-    param_data.reserve(param_refs.size());
     for (auto inst_id : param_refs) {
-      auto type_const_id =
-          GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id());
+      GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id());
 
       // If the parameter is a symbolic binding, build the BindSymbolicName
       // constant.
@@ -667,19 +663,25 @@ class ImportRefResolver {
         bind_id = addr->inner_id;
         bind_inst = import_ir_.insts().Get(bind_id);
       }
-      auto bind_const_id = bind_inst.Is<SemIR::BindSymbolicName>()
-                               ? GetLocalConstantId(bind_id)
-                               : SemIR::ConstantId::Invalid;
-      param_data.push_back(
-          {.type_const_id = type_const_id, .bind_const_id = bind_const_id});
+      if (bind_inst.Is<SemIR::BindSymbolicName>()) {
+        GetLocalConstantId(bind_id);
+      }
     }
-    return param_data;
   }
 
-  // Given a param_refs_id and const_ids from GetLocalParamConstantIds, returns
-  // a version of param_refs_id localized to the current IR.
-  auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id,
-                           const llvm::SmallVector<ParamData>& params_data)
+  // Returns a version of param_refs_id localized to the current IR.
+  //
+  // Must only be called after a call to GetLocalParamConstantIds(param_refs_id)
+  // has completed without adding any new work to work_stack_.
+  //
+  // TODO: This is inconsistent with the rest of this class, which expects
+  // the relevant constants to be explicitly passed in. That makes it
+  // easier to statically detect when an input isn't loaded, but makes it
+  // harder to support importing more complex inst structures. We should
+  // take a holistic look at how to balance those concerns. For example,
+  // could the same function be used to load the constants and use them, with
+  // a parameter to select between the two?
+  auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id)
       -> SemIR::InstBlockId {
     if (!param_refs_id.is_valid() ||
         param_refs_id == SemIR::InstBlockId::Empty) {
@@ -687,7 +689,7 @@ class ImportRefResolver {
     }
     const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id);
     llvm::SmallVector<SemIR::InstId> new_param_refs;
-    for (auto [ref_id, param_data] : llvm::zip(param_refs, params_data)) {
+    for (auto ref_id : param_refs) {
       // Figure out the param structure. This echoes
       // Function::GetParamFromParamRefId.
       // TODO: Consider a different parameter handling to simplify import logic.
@@ -712,8 +714,8 @@ class ImportRefResolver {
 
       // Rebuild the param instruction.
       auto name_id = GetLocalNameId(param_inst.name_id);
-      auto type_id =
-          context_.GetTypeIdForTypeConstant(param_data.type_const_id);
+      auto type_id = context_.GetTypeIdForTypeConstant(
+          GetLocalConstantIdChecked(param_inst.type_id));
 
       auto new_param_id = context_.AddInstInNoBlock<SemIR::Param>(
           AddImportIRInst(param_id),
@@ -740,12 +742,12 @@ class ImportRefResolver {
             auto new_bind_inst =
                 context_.insts().GetAs<SemIR::BindSymbolicName>(
                     context_.constant_values().GetInstId(
-                        param_data.bind_const_id));
+                        GetLocalConstantIdChecked(bind_id)));
             new_bind_inst.value_id = new_param_id;
             new_param_id = context_.AddInstInNoBlock(AddImportIRInst(bind_id),
                                                      new_bind_inst);
             context_.constant_values().Set(new_param_id,
-                                           param_data.bind_const_id);
+                                           GetLocalConstantIdChecked(bind_id));
             break;
           }
           default: {
@@ -1322,9 +1324,8 @@ class ImportRefResolver {
 
     // Load constants for the definition.
     auto parent_scope_id = GetLocalNameScopeId(import_class.parent_scope_id);
-    auto implicit_param_const_ids =
-        GetLocalParamConstantIds(import_class.implicit_param_refs_id);
-    auto param_const_ids = GetLocalParamConstantIds(import_class.param_refs_id);
+    LoadLocalParamConstantIds(import_class.implicit_param_refs_id);
+    LoadLocalParamConstantIds(import_class.param_refs_id);
     auto generic_data = GetLocalGenericData(import_class.generic_id);
     auto self_const_id = GetLocalConstantId(import_class.self_type_id);
     auto complete_type_witness_id =
@@ -1341,10 +1342,9 @@ class ImportRefResolver {
 
     auto& new_class = context_.classes().Get(class_id);
     new_class.parent_scope_id = parent_scope_id;
-    new_class.implicit_param_refs_id = GetLocalParamRefsId(
-        import_class.implicit_param_refs_id, implicit_param_const_ids);
-    new_class.param_refs_id =
-        GetLocalParamRefsId(import_class.param_refs_id, param_const_ids);
+    new_class.implicit_param_refs_id =
+        GetLocalParamRefsId(import_class.implicit_param_refs_id);
+    new_class.param_refs_id = GetLocalParamRefsId(import_class.param_refs_id);
     SetGenericData(import_class.generic_id, new_class.generic_id, generic_data);
     new_class.self_type_id = context_.GetTypeIdForTypeConstant(self_const_id);
 
@@ -1495,10 +1495,8 @@ class ImportRefResolver {
           import_ir_.insts().Get(import_function.return_storage_id).type_id());
     }
     auto parent_scope_id = GetLocalNameScopeId(import_function.parent_scope_id);
-    auto implicit_param_const_ids =
-        GetLocalParamConstantIds(import_function.implicit_param_refs_id);
-    auto param_const_ids =
-        GetLocalParamConstantIds(import_function.param_refs_id);
+    LoadLocalParamConstantIds(import_function.implicit_param_refs_id);
+    LoadLocalParamConstantIds(import_function.param_refs_id);
     auto generic_data = GetLocalGenericData(import_function.generic_id);
 
     if (HasNewWork()) {
@@ -1508,10 +1506,10 @@ class ImportRefResolver {
     // Add the function declaration.
     auto& new_function = context_.functions().Get(function_id);
     new_function.parent_scope_id = parent_scope_id;
-    new_function.implicit_param_refs_id = GetLocalParamRefsId(
-        import_function.implicit_param_refs_id, implicit_param_const_ids);
+    new_function.implicit_param_refs_id =
+        GetLocalParamRefsId(import_function.implicit_param_refs_id);
     new_function.param_refs_id =
-        GetLocalParamRefsId(import_function.param_refs_id, param_const_ids);
+        GetLocalParamRefsId(import_function.param_refs_id);
     SetGenericData(import_function.generic_id, new_function.generic_id,
                    generic_data);
 
@@ -1654,8 +1652,7 @@ class ImportRefResolver {
 
     // Load constants for the definition.
     auto parent_scope_id = GetLocalNameScopeId(import_impl.parent_scope_id);
-    auto implicit_param_const_ids =
-        GetLocalParamConstantIds(import_impl.implicit_param_refs_id);
+    LoadLocalParamConstantIds(import_impl.implicit_param_refs_id);
     auto generic_data = GetLocalGenericData(import_impl.generic_id);
     auto self_const_id = GetLocalConstantId(import_impl.self_id);
     auto constraint_const_id = GetLocalConstantId(import_impl.constraint_id);
@@ -1666,8 +1663,8 @@ class ImportRefResolver {
 
     auto& new_impl = context_.impls().Get(impl_id);
     new_impl.parent_scope_id = parent_scope_id;
-    new_impl.implicit_param_refs_id = GetLocalParamRefsId(
-        import_impl.implicit_param_refs_id, implicit_param_const_ids);
+    new_impl.implicit_param_refs_id =
+        GetLocalParamRefsId(import_impl.implicit_param_refs_id);
     CARBON_CHECK(!import_impl.param_refs_id.is_valid() &&
                  !new_impl.param_refs_id.is_valid());
     SetGenericData(import_impl.generic_id, new_impl.generic_id, generic_data);
@@ -1811,10 +1808,8 @@ class ImportRefResolver {
 
     auto parent_scope_id =
         GetLocalNameScopeId(import_interface.parent_scope_id);
-    auto implicit_param_const_ids =
-        GetLocalParamConstantIds(import_interface.implicit_param_refs_id);
-    auto param_const_ids =
-        GetLocalParamConstantIds(import_interface.param_refs_id);
+    LoadLocalParamConstantIds(import_interface.implicit_param_refs_id);
+    LoadLocalParamConstantIds(import_interface.param_refs_id);
     auto generic_data = GetLocalGenericData(import_interface.generic_id);
 
     std::optional<SemIR::InstId> self_param_id;
@@ -1828,10 +1823,10 @@ class ImportRefResolver {
 
     auto& new_interface = context_.interfaces().Get(interface_id);
     new_interface.parent_scope_id = parent_scope_id;
-    new_interface.implicit_param_refs_id = GetLocalParamRefsId(
-        import_interface.implicit_param_refs_id, implicit_param_const_ids);
+    new_interface.implicit_param_refs_id =
+        GetLocalParamRefsId(import_interface.implicit_param_refs_id);
     new_interface.param_refs_id =
-        GetLocalParamRefsId(import_interface.param_refs_id, param_const_ids);
+        GetLocalParamRefsId(import_interface.param_refs_id);
     SetGenericData(import_interface.generic_id, new_interface.generic_id,
                    generic_data);