
Avoid UAF in impl lookup when deduce imports an impl from Core (#5126)

Deduction can do conversion, and conversion can import impls from the
Core package. If the ImplStore is exactly at capacity at that moment,
adding the imported impl reallocates its storage, and any pointer into
context.impls() is invalidated.

In particular, in impl lookup, we currently loop over context.impls() and
do deduction on each impl, so a reallocation can break the for loop.
Additionally, we pass around a reference to the currently-being-looked-at
Impl, which becomes invalidated.
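
A minimal sketch of the underlying hazard, using std::vector as a stand-in for the SmallVector behind the ImplStore. `FakeImpl` and `AppendMovesStorage` are hypothetical names for illustration, not part of the toolchain. The sketch only compares storage addresses; it never reads through a stale pointer.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an entry in the impl store.
struct FakeImpl {
  int interface_id;
};

// Returns true if appending one element to a full vector moved its storage.
// Any reference or iterator obtained before the append would then dangle.
inline bool AppendMovesStorage(std::vector<FakeImpl>& impls) {
  // Fill the vector exactly to capacity so the next push_back must grow it.
  while (impls.empty() || impls.size() < impls.capacity()) {
    impls.push_back({0});
  }
  assert(impls.size() == impls.capacity());

  // Record the address as an integer; the old pointer is never dereferenced.
  auto before = reinterpret_cast<std::uintptr_t>(impls.data());
  impls.push_back({1});  // Simulates an impl being imported mid-lookup.
  auto after = reinterpret_cast<std::uintptr_t>(impls.data());
  return before != after;
}
```

When this returns true, a `const FakeImpl&` taken before the append refers to freed memory, which matches the access pattern ASAN reported.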

This is very challenging to test in any reliable way as you need a
specific number of impls in your ImplStore. I hit it when making changes
to a test in the middle of a bunch of file splits. Putting the same test
in its own file did not trigger the issue. It was caught by ASAN, which
showed:
- The memory was allocated by SmallVector in handle_impl when making the
Impl.
- The memory was freed by SmallVector reallocating in import_ref.cpp
- The memory was accessed when reading through the `impl` reference in
FindWitnessInImpls(). I was able to reproduce by printing the
`impl.interface.interface_id` after the call to GetWitnessIdForImpl()
which does the deduction.

I didn't save the ASAN stack and now I can't find the exact permutation
of the test file that caused it to occur in order to reproduce. :(

To avoid the UAF we stop passing around the Impl reference, and pass
around either the ImplId, or values from the Impl. To avoid copying the
entirety of the impl ids in context.impls() into a separate container in
order to iterate safely, we move the early outs from
GetWitnessIdForImpl() up to the caller where it can use them to reduce
the number of impl ids that we iterate over. Type structures will be able
to further reduce the size of this set.
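
The shape of the fix can be sketched outside the toolchain as a two-phase lookup: filter first while nothing mutates the store, then iterate over ids and fetch each element again after any call that may grow the store. `FakeImpl`, `FakeImplStore`, `FakeDeduce`, and `FindWitness` are hypothetical stand-ins, and the "succeeds when the witness is even" rule is an arbitrary placeholder for real deduction.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the impl and its store.
struct FakeImpl {
  int interface_id;
  int witness;
};

struct FakeImplStore {
  std::vector<FakeImpl> impls;
  const FakeImpl& Get(std::size_t id) const {
    assert(id < impls.size());
    return impls[id];
  }
};

// Placeholder for deduction: it always appends to the store (simulating an
// import from Core) and arbitrarily succeeds when `witness` is even.
inline bool FakeDeduce(FakeImplStore& store, int witness) {
  store.impls.push_back({-1, 0});
  return witness % 2 == 0;
}

// Returns the witness of the first matching impl, or -1 if none matches.
inline int FindWitness(FakeImplStore& store, int interface_id) {
  // Phase 1: cheap early-outs. Nothing here mutates the store, so reading
  // elements directly is safe. Only ids are kept.
  std::vector<std::size_t> candidate_ids;
  for (std::size_t id = 0; id < store.impls.size(); ++id) {
    if (store.impls[id].interface_id == interface_id) {
      candidate_ids.push_back(id);
    }
  }

  // Phase 2: the expensive step may grow the store, so pass values in and
  // look the impl up again by id afterwards instead of holding a reference.
  for (std::size_t id : candidate_ids) {
    int witness = store.Get(id).witness;  // Copy the value before mutation.
    if (!FakeDeduce(store, witness)) {
      continue;
    }
    const FakeImpl& impl = store.Get(id);  // Fresh reference after mutation.
    return impl.witness;
  }
  return -1;
}
```

Ids stay valid across reallocation because they index into the store rather than pointing at its memory, and filtering first keeps the copied candidate list small.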
Dana Jansens, 1 year ago
commit
ce08e4d9a1
3 files changed, 75 additions, 43 deletions
  1. toolchain/check/deduce.cpp (+3 -3)
  2. toolchain/check/deduce.h (+13 -2)
  3. toolchain/check/impl_lookup.cpp (+59 -38)

+ 3 - 3
toolchain/check/deduce.cpp

@@ -641,8 +641,8 @@ auto DeduceGenericCallArguments(
   return deduction.MakeSpecific();
 }
 
-auto DeduceImplArguments(Context& context, SemIR::LocId loc_id,
-                         const SemIR::Impl& impl, SemIR::ConstantId self_id,
+auto DeduceImplArguments(Context& context, SemIR::LocId loc_id, DeduceImpl impl,
+                         SemIR::ConstantId self_id,
                          SemIR::SpecificId constraint_specific_id)
     -> SemIR::SpecificId {
   DeductionContext deduction(context, loc_id, impl.generic_id,
@@ -653,7 +653,7 @@ auto DeduceImplArguments(Context& context, SemIR::LocId loc_id,
   // Prepare to perform deduction of the type and interface.
   deduction.Add(impl.self_id, context.constant_values().GetInstId(self_id),
                 /*needs_substitution=*/false);
-  deduction.Add(impl.interface.specific_id, constraint_specific_id,
+  deduction.Add(impl.specific_id, constraint_specific_id,
                 /*needs_substitution=*/false);
 
   if (!deduction.Deduce() || !deduction.CheckDeductionIsComplete()) {

+ 13 - 2
toolchain/check/deduce.h

@@ -18,10 +18,21 @@ auto DeduceGenericCallArguments(
     SemIR::InstId self_id, llvm::ArrayRef<SemIR::InstId> arg_ids)
     -> SemIR::SpecificId;
 
+// Data from the `Impl` that is used by deduce.
+//
+// We don't use a reference to an `Impl` as deduction can invalidate the
+// reference by causing impl declarations to be imported from `Core` during
+// conversion.
+struct DeduceImpl {
+  SemIR::InstId self_id;
+  SemIR::GenericId generic_id;
+  SemIR::SpecificId specific_id;
+};
+
 // Deduces the impl arguments to use in a use of a parameterized impl. Returns
 // `None` if deduction fails.
-auto DeduceImplArguments(Context& context, SemIR::LocId loc_id,
-                         const SemIR::Impl& impl, SemIR::ConstantId self_id,
+auto DeduceImplArguments(Context& context, SemIR::LocId loc_id, DeduceImpl impl,
+                         SemIR::ConstantId self_id,
                          SemIR::SpecificId constraint_specific_id)
     -> SemIR::SpecificId;
 

+ 59 - 38
toolchain/check/impl_lookup.cpp

@@ -181,48 +181,31 @@ static auto GetInterfacesFromConstantId(
 static auto GetWitnessIdForImpl(
     Context& context, SemIR::LocId loc_id, SemIR::ConstantId type_const_id,
     const SemIR::CompleteFacetType::RequiredInterface& interface,
-    const SemIR::Impl& impl) -> SemIR::InstId {
-  // If the impl's interface_id differs from the query, then this impl can not
-  // possibly provide the queried interface, and we don't need to proceed.
-  if (impl.interface.interface_id != interface.interface_id) {
-    return SemIR::InstId::None;
-  }
-
-  // When the impl's interface_id matches, but the interface is generic, the
-  // impl may or may not match based on restrictions in the generic parameters
-  // of the impl.
-  //
-  // As a shortcut, if the impl's constraint is not symbolic (does not depend on
-  // any generic parameters), then we can determine that we match if the
-  // specific ids match exactly.
-  auto impl_interface_const_id =
-      context.constant_values().Get(impl.constraint_id);
-  if (!impl_interface_const_id.is_symbolic()) {
-    if (impl.interface.specific_id != interface.specific_id) {
-      return SemIR::InstId::None;
-    }
-  }
-
-  // This check comes first to avoid deduction with an invalid impl. We use an
-  // error value to indicate an error during creation of the impl, such as a
-  // recursive impl which will cause deduction to recurse infinitely.
-  if (impl.witness_id == SemIR::ErrorInst::SingletonInstId) {
-    return SemIR::InstId::None;
-  }
-  CARBON_CHECK(impl.witness_id.has_value());
-
+    SemIR::ImplId impl_id) -> SemIR::InstId {
   // The impl may have generic arguments, in which case we need to deduce them
   // to find what they are given the specific type and interface query. We use
   // that specific to map values in the impl to the deduced values.
   auto specific_id = SemIR::SpecificId::None;
-  if (impl.generic_id.has_value()) {
-    specific_id = DeduceImplArguments(context, loc_id, impl, type_const_id,
-                                      interface.specific_id);
-    if (!specific_id.has_value()) {
-      return SemIR::InstId::None;
+  {
+    // DeduceImplArguments can import new impls which can invalidate any
+    // pointers into `context.impls()`.
+    const SemIR::Impl& impl = context.impls().Get(impl_id);
+    if (impl.generic_id.has_value()) {
+      specific_id =
+          DeduceImplArguments(context, loc_id,
+                              {.self_id = impl.self_id,
+                               .generic_id = impl.generic_id,
+                               .specific_id = impl.interface.specific_id},
+                              type_const_id, interface.specific_id);
+      if (!specific_id.has_value()) {
+        return SemIR::InstId::None;
+      }
     }
   }
 
+  // Get a pointer again after DeduceImplArguments() is complete.
+  const SemIR::Impl& impl = context.impls().Get(impl_id);
+
   // The self type of the impl must match the type in the query, or this is an
   // `impl T as ...` for some other type `T` and should not be considered.
   auto deduced_self_const_id = SemIR::GetConstantValueInSpecific(
@@ -321,10 +304,48 @@ static auto FindWitnessInImpls(
     Context& context, SemIR::LocId loc_id, SemIR::ConstantId type_const_id,
     const SemIR::SpecificInterface& specific_interface) -> SemIR::InstId {
   auto& stack = context.impl_lookup_stack();
-  for (const auto& impl : context.impls().array_ref()) {
-    stack.back().impl_loc = impl.definition_id;
+  // TODO: Build this candidate list by matching against type structures to
+  // narrow it down.
+  llvm::SmallVector<std::pair<SemIR::ImplId, SemIR::InstId>> candidate_impl_ids;
+  for (auto [id, impl] : context.impls().enumerate()) {
+    // If the impl's interface_id differs from the query, then this impl can not
+    // possibly provide the queried interface.
+    if (impl.interface.interface_id != specific_interface.interface_id) {
+      continue;
+    }
+
+    // When the impl's interface_id matches, but the interface is generic, the
+    // impl may or may not match based on restrictions in the generic parameters
+    // of the impl.
+    //
+    // As a shortcut, if the impl's constraint is not symbolic (does not depend
+    // on any generic parameters), then we can determine that we match if the
+    // specific ids match exactly.
+    auto impl_interface_const_id =
+        context.constant_values().Get(impl.constraint_id);
+    if (!impl_interface_const_id.is_symbolic()) {
+      if (impl.interface.specific_id != specific_interface.specific_id) {
+        continue;
+      }
+    }
+
+    // This check comes first to avoid deduction with an invalid impl. We use an
+    // error value to indicate an error during creation of the impl, such as a
+    // recursive impl which will cause deduction to recurse infinitely.
+    if (impl.witness_id == SemIR::ErrorInst::SingletonInstId) {
+      continue;
+    }
+    CARBON_CHECK(impl.witness_id.has_value());
+
+    candidate_impl_ids.push_back({id, impl.definition_id});
+  }
+
+  for (auto [impl_id, loc_inst_id] : candidate_impl_ids) {
+    stack.back().impl_loc = loc_inst_id;
+    // NOTE: GetWitnessIdForImpl() does deduction, which can cause new impls to
+    // be imported, invalidating any pointer into `context.impls()`.
     auto result_witness_id = GetWitnessIdForImpl(context, loc_id, type_const_id,
-                                                 specific_interface, impl);
+                                                 specific_interface, impl_id);
     if (result_witness_id.has_value()) {
       // We found a matching impl; don't keep looking for this interface.
       return result_witness_id;