Просмотр исходного кода

Use canonical constant values as the keys for ImplWitnessAccess in AccessRewriteValues (#5912)

Instead of a bespoke structure based on the `EntityName` in the `.Self`
type of the `ImplWitnessAccess`, use the constant value of the
`ImplWitnessAccess` as the map key in `AccessRewriteValues`. This is
safe now that #5883 makes every `ImplWitnessAccess` to `.Self`
canonically the same, regardless of how it is constructed through nested
`where` expressions.

Introduce `KnownInstId` which tracks in the type system that an `InstId`
is known to refer to a specific typed inst structure. This avoids
writing CHECKs and comments and allows compiler enforcement.
Dana Jansens 8 месяцев назад
Родитель
Коммит
2e22733372
3 измененных файлов с 106 добавлено и 104 удалено
  1. 34 100
      toolchain/check/facet_type.cpp
  2. 30 4
      toolchain/sem_ir/ids.h
  3. 42 0
      toolchain/sem_ir/inst.h

+ 34 - 100
toolchain/check/facet_type.cpp

@@ -263,38 +263,6 @@ auto AllocateFacetTypeImplWitness(Context& context,
   context.inst_blocks().ReplacePlaceholder(witness_id, empty_table);
 }
 
-namespace {
-// TODO: This class should go away, and we should just use the constant value of
-// the ImplWitnessAccess as a key in AccessRewriteValues, but that requires
-// changing its API to work with InstId instead of ImplWitnessAccess.
-struct FacetTypeConstraintValue {
-  SemIR::EntityNameId entity_name_id;
-  SemIR::ElementIndex access_index;
-  SemIR::SpecificInterfaceId specific_interface_id;
-
-  friend auto operator==(const FacetTypeConstraintValue& lhs,
-                         const FacetTypeConstraintValue& rhs) -> bool = default;
-};
-}  // namespace
-
-static auto GetFacetTypeConstraintValue(Context& context,
-                                        SemIR::ImplWitnessAccess access)
-    -> std::optional<FacetTypeConstraintValue> {
-  auto lookup =
-      context.insts().TryGetAs<SemIR::LookupImplWitness>(access.witness_id);
-  if (lookup) {
-    auto self = context.insts().TryGetAs<SemIR::BindSymbolicName>(
-        context.constant_values().GetConstantInstId(
-            lookup->query_self_inst_id));
-    if (self) {
-      return {{.entity_name_id = self->entity_name_id,
-               .access_index = access.index,
-               .specific_interface_id = lookup->query_specific_interface_id}};
-    }
-  }
-  return std::nullopt;
-}
-
 // A mapping of each associated constant (represented as `ImplWitnessAccess`) to
 // its value (represented as an `InstId`). Used to track rewrite constraints,
 // with the LHS mapping to the resolved value of the RHS.
@@ -310,24 +278,24 @@ class AccessRewriteValues {
     SemIR::InstId inst_id;
   };
 
-  auto InsertNotRewritten(Context& context, SemIR::ImplWitnessAccess access,
-                          SemIR::InstId inst_id) -> void {
-    map_.insert({*GetKey(context, access), {NotRewritten, inst_id}});
+  auto InsertNotRewritten(
+      Context& context, SemIR::KnownInstId<SemIR::ImplWitnessAccess> access_id,
+      SemIR::InstId inst_id) -> void {
+    map_.Insert(context.constant_values().Get(access_id),
+                {NotRewritten, inst_id});
   }
 
   // Finds and returns a pointer into the cache for a given ImplWitnessAccess.
   // The pointer will be invalidated by mutating the cache. Returns `nullptr`
   // if `access` is not found.
-  auto FindRef(Context& context, SemIR::ImplWitnessAccess access) -> Value* {
-    auto key = GetKey(context, access);
-    if (!key) {
-      return nullptr;
-    }
-    auto it = map_.find(*key);
-    if (it == map_.end()) {
+  auto FindRef(Context& context,
+               SemIR::KnownInstId<SemIR::ImplWitnessAccess> access_id)
+      -> Value* {
+    auto result = map_.Lookup(context.constant_values().Get(access_id));
+    if (!result) {
       return nullptr;
     }
-    return &it->second;
+    return &result.value();
   }
 
   auto SetBeingRewritten(Value& value) -> void {
@@ -346,54 +314,13 @@ class AccessRewriteValues {
   }
 
  private:
-  using Key = FacetTypeConstraintValue;
-  struct KeyInfo {
-    static auto getEmptyKey() -> Key {
-      return {
-          .entity_name_id = SemIR::EntityNameId::None,
-          .access_index = SemIR::ElementIndex(-1),
-          .specific_interface_id = SemIR::SpecificInterfaceId::None,
-      };
-    }
-    static auto getTombstoneKey() -> Key {
-      return {
-          .entity_name_id = SemIR::EntityNameId::None,
-          .access_index = SemIR::ElementIndex(-2),
-          .specific_interface_id = SemIR::SpecificInterfaceId::None,
-      };
-    }
-    static auto getHashValue(Key key) -> unsigned {
-      // This hash produces the same value if two ImplWitnessAccess are
-      // pointing to the same associated constant, even if they are different
-      // instruction ids.
-      //
-      // TODO: This truncates the high bits of the hash code which does not
-      // make for a good hash function.
-      return static_cast<unsigned>(static_cast<uint64_t>(HashValue(key)));
-    }
-    static auto isEqual(Key lhs, Key rhs) -> bool {
-      // This comparison is true if the two ImplWitnessAccess are pointing to
-      // the same associated constant, even if they are different instruction
-      // ids.
-      return lhs == rhs;
-    }
-  };
-
-  // Returns a key for the `access` to an associated context if the access is
-  // through a facet value. If the access it through another `ImplWitnessAccess`
-  // then no key is able to be made.
-  auto GetKey(Context& context, SemIR::ImplWitnessAccess access)
-      -> std::optional<Key> {
-    return GetFacetTypeConstraintValue(context, access);
-  }
-
   // Try avoid heap allocations in the common case where there are a small
   // number of rewrite rules referring to each other by keeping up to 16 on
   // the stack.
   //
   // TODO: Revisit if 16 is an appropriate number when we can measure how deep
   // rewrite constraint chains go in practice.
-  llvm::SmallDenseMap<Key, Value, 16, KeyInfo> map_;
+  Map<SemIR::ConstantId, Value, 16> map_;
 };
 
 // To be used for substituting into the RHS of a rewrite constraint.
@@ -437,7 +364,7 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
 
   auto Subst(SemIR::InstId& rhs_inst_id) -> SubstResult override {
     auto rhs_access =
-        context().insts().TryGetAs<SemIR::ImplWitnessAccess>(rhs_inst_id);
+        context().insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(rhs_inst_id);
     if (!rhs_access) {
       // We only want to substitute ImplWitnessAccesses written directly on the
       // RHS of the rewrite constraint, not when they are nested inside facet
@@ -471,7 +398,7 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
     // access needs to be resolved to a facet value first. If it can't be
     // resolved then the outer one can not be either.
     if (auto lookup = context().insts().TryGetAs<SemIR::LookupImplWitness>(
-            rhs_access->witness_id)) {
+            rhs_access->inst.witness_id)) {
       if (context().insts().Is<SemIR::ImplWitnessAccess>(
               lookup->query_self_inst_id)) {
         substs_in_progress_.push_back(rhs_inst_id);
@@ -479,7 +406,8 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
       }
     }
 
-    auto* rewrite_value = rewrite_values_->FindRef(context(), *rhs_access);
+    auto* rewrite_value =
+        rewrite_values_->FindRef(context(), rhs_access->inst_id);
     if (!rewrite_value) {
       // The RHS refers to an associated constant for which there is no rewrite
       // rule.
@@ -521,9 +449,11 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
       -> SemIR::InstId override {
     auto inst_id = RebuildNewInst(loc_id_, new_inst);
     auto subst_inst_id = substs_in_progress_.pop_back_val();
-    if (auto access = context().insts().TryGetAs<SemIR::ImplWitnessAccess>(
-            subst_inst_id)) {
-      if (auto* rewrite_value = rewrite_values_->FindRef(context(), *access)) {
+    if (auto access =
+            context().insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
+                subst_inst_id)) {
+      if (auto* rewrite_value =
+              rewrite_values_->FindRef(context(), access->inst_id)) {
         rewrite_values_->SetFullyRewritten(context(), *rewrite_value, inst_id);
       }
     }
@@ -532,9 +462,11 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
 
   auto ReuseUnchanged(SemIR::InstId orig_inst_id) -> SemIR::InstId override {
     auto subst_inst_id = substs_in_progress_.pop_back_val();
-    if (auto access = context().insts().TryGetAs<SemIR::ImplWitnessAccess>(
-            subst_inst_id)) {
-      if (auto* rewrite_value = rewrite_values_->FindRef(context(), *access)) {
+    if (auto access =
+            context().insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
+                subst_inst_id)) {
+      if (auto* rewrite_value =
+              rewrite_values_->FindRef(context(), access->inst_id)) {
         rewrite_values_->SetFullyRewritten(context(), *rewrite_value,
                                            orig_inst_id);
       }
@@ -580,23 +512,25 @@ auto ResolveFacetTypeRewriteConstraints(
   AccessRewriteValues rewrite_values;
 
   for (auto& constraint : rewrites) {
-    auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(
+    auto lhs_access = context.insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
         GetImplWitnessAccessWithoutSubstitution(context, constraint.lhs_id));
     if (!lhs_access) {
       continue;
     }
 
-    rewrite_values.InsertNotRewritten(context, *lhs_access, constraint.rhs_id);
+    rewrite_values.InsertNotRewritten(context, lhs_access->inst_id,
+                                      constraint.rhs_id);
   }
 
   for (auto& constraint : rewrites) {
-    auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(
+    auto lhs_access = context.insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
         GetImplWitnessAccessWithoutSubstitution(context, constraint.lhs_id));
     if (!lhs_access) {
       continue;
     }
 
-    auto* lhs_rewrite_value = rewrite_values.FindRef(context, *lhs_access);
+    auto* lhs_rewrite_value =
+        rewrite_values.FindRef(context, lhs_access->inst_id);
     // Every LHS was added with InsertNotRewritten above.
     CARBON_CHECK(lhs_rewrite_value);
     rewrite_values.SetBeingRewritten(*lhs_rewrite_value);
@@ -658,14 +592,14 @@ auto ResolveFacetTypeRewriteConstraints(
   for (size_t i = 0; i < keep_size;) {
     auto& constraint = rewrites[i];
 
-    auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(
+    auto lhs_access = context.insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
         GetImplWitnessAccessWithoutSubstitution(context, constraint.lhs_id));
     if (!lhs_access) {
       ++i;
       continue;
     }
 
-    auto& rewrite_value = *rewrite_values.FindRef(context, *lhs_access);
+    auto& rewrite_value = *rewrite_values.FindRef(context, lhs_access->inst_id);
     auto rhs_id = std::exchange(rewrite_value.inst_id, SemIR::InstId::None);
     if (rhs_id == SemIR::InstId::None) {
       std::swap(rewrites[i], rewrites[keep_size - 1]);

+ 30 - 4
toolchain/sem_ir/ids.h

@@ -37,11 +37,10 @@ struct InstId : public IdBase<InstId> {
 
 constexpr InstId InstId::InitTombstone = InstId(NoneIndex - 1);
 
-// An InstId whose value is a type. The fact it's a type is CHECKed on
-// construction, and this allows that check to be represented in the type
-// system.
+// An InstId whose value is a type. The fact it's a type must be validated
+// before construction, and this allows that validation to be represented in the
+// type system.
 struct TypeInstId : public InstId {
-  static constexpr llvm::StringLiteral Label = "type_inst";
   static const TypeInstId None;
 
   using InstId::InstId;
@@ -58,6 +57,33 @@ struct TypeInstId : public InstId {
 
 constexpr TypeInstId TypeInstId::None = TypeInstId::UnsafeMake(InstId::None);
 
+// An InstId that is known to refer to an instruction of type T. That fact must
+// be validated before construction, and this allows that validation to be
+// represented in the type system.
+//
+// Unlike TypeInstId, this type can *not* be an operand in instructions, since
+// being a template prevents it from being used in non-generic contexts such as
+// switches.
+template <class T>
+struct KnownInstId : public InstId {
+  static const KnownInstId None;
+
+  using InstId::InstId;
+
+  static constexpr auto UnsafeMake(InstId id) -> KnownInstId {
+    return KnownInstId(UnsafeCtor(), id);
+  }
+
+ private:
+  struct UnsafeCtor {};
+  explicit constexpr KnownInstId(UnsafeCtor /*unsafe*/, InstId id)
+      : InstId(id) {}
+};
+
+template <class T>
+constexpr KnownInstId<T> KnownInstId<T>::None =
+    KnownInstId<T>::UnsafeMake(InstId::None);
+
 // An ID of an instruction that is referenced absolutely by another instruction.
 // This should only be used as the type of a field within a typed instruction
 // class.

+ 42 - 0
toolchain/sem_ir/inst.h

@@ -18,6 +18,7 @@
 #include "toolchain/base/int.h"
 #include "toolchain/base/value_store.h"
 #include "toolchain/sem_ir/id_kind.h"
+#include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/inst_kind.h"
 #include "toolchain/sem_ir/singleton_insts.h"
 #include "toolchain/sem_ir/typed_insts.h"
@@ -465,6 +466,13 @@ class InstStore {
     return result;
   }
 
+  // Returns the requested instruction, which is known to have the specified
+  // type.
+  template <typename InstT>
+  auto Get(KnownInstId<InstT> inst_id) const -> InstT {
+    return Get(static_cast<InstId>(inst_id)).As<InstT>();
+  }
+
   // Returns the requested instruction, preserving its attached type.
   auto GetWithAttachedType(InstId inst_id) const -> Inst {
     return values_.Get(inst_id);
@@ -500,6 +508,10 @@ class InstStore {
     return Get(inst_id).TryAs<InstT>();
   }
 
+  // Use `Get()` when the instruction type is known.
+  template <typename InstT, typename KnownInstT>
+  auto TryGetAs(KnownInstId<KnownInstT> inst_id) const = delete;
+
   // Returns the requested instruction as the specified type, if it is valid and
   // of that type. Otherwise returns nullopt.
   template <typename InstT>
@@ -510,6 +522,36 @@ class InstStore {
     return TryGetAs<InstT>(inst_id);
   }
 
+  template <class InstT>
+  struct GetAsWithIdResult {
+    SemIR::KnownInstId<InstT> inst_id;
+    InstT inst;
+  };
+
+  // Returns the requested instruction, which is known to have the specified
+  // type, along with the original `InstId`, encoding the work of checking its
+  // type in a `KnownInstId`.
+  template <typename InstT>
+  auto GetAsWithId(InstId inst_id) const -> GetAsWithIdResult<InstT> {
+    auto inst = GetAs<InstT>(inst_id);
+    return {.inst_id = SemIR::KnownInstId<InstT>::UnsafeMake(inst_id),
+            .inst = inst};
+  }
+
+  // Returns the requested instruction, if it is of that type, along with the
+  // original `InstId`, encoding the work of checking its type in a
+  // `KnownInstId`.
+  template <typename InstT>
+  auto TryGetAsWithId(InstId inst_id) const
+      -> std::optional<GetAsWithIdResult<InstT>> {
+    auto inst = TryGetAs<InstT>(inst_id);
+    if (!inst) {
+      return std::nullopt;
+    }
+    return {{.inst_id = SemIR::KnownInstId<InstT>::UnsafeMake(inst_id),
+             .inst = *inst}};
+  }
+
   // Attempts to convert the given instruction to the type that contains
   // `member`. If it can be converted, the instruction ID and instruction are
   // replaced by the unwrapped value of that member, and the converted wrapper