Quellcode durchsuchen

Store additional information for symbolic constants. (#4102)

When forming a `ConstantId` for a symbolic constant, add storage to
track the generic in which the constant was formed and the index within
that generic. These fields are not yet populated.
Richard Smith vor 1 Jahr
Ursprung
Commit
6ecf4ce9a7

+ 2 - 2
toolchain/check/member_access.cpp

@@ -132,8 +132,8 @@ static auto LookupInterfaceWitness(Context& context,
   // considering impls that are for the same interface we're querying. We can
   // considering impls that are for the same interface we're querying. We can
   // also skip impls that mention any types that aren't part of our impl query.
   // also skip impls that mention any types that aren't part of our impl query.
   for (const auto& impl : context.impls().array_ref()) {
   for (const auto& impl : context.impls().array_ref()) {
-    if (context.types().GetInstId(impl.self_id) !=
-        context.constant_values().GetInstId(type_const_id)) {
+    if (!context.constant_values().EqualAcrossDeclarations(
+            context.types().GetConstantId(impl.self_id), type_const_id)) {
       continue;
       continue;
     }
     }
     auto interface_type =
     auto interface_type =

+ 10 - 10
toolchain/check/testdata/basics/no_prelude/raw_ir.carbon

@@ -35,12 +35,12 @@ fn Foo[T:! type](n: T) -> (T, ()) {
 // CHECK:STDOUT:   generic_instances: {}
 // CHECK:STDOUT:   generic_instances: {}
 // CHECK:STDOUT:   types:
 // CHECK:STDOUT:   types:
 // CHECK:STDOUT:     type0:           {constant: template instNamespaceType, value_rep: {kind: copy, type: type0}}
 // CHECK:STDOUT:     type0:           {constant: template instNamespaceType, value_rep: {kind: copy, type: type0}}
-// CHECK:STDOUT:     type1:           {constant: symbolic inst+3, value_rep: {kind: copy, type: type1}}
+// CHECK:STDOUT:     type1:           {constant: symbolic 0, value_rep: {kind: copy, type: type1}}
 // CHECK:STDOUT:     type2:           {constant: template inst+8, value_rep: {kind: none, type: type2}}
 // CHECK:STDOUT:     type2:           {constant: template inst+8, value_rep: {kind: none, type: type2}}
 // CHECK:STDOUT:     type3:           {constant: template inst+10, value_rep: {kind: unknown, type: type<invalid>}}
 // CHECK:STDOUT:     type3:           {constant: template inst+10, value_rep: {kind: unknown, type: type<invalid>}}
-// CHECK:STDOUT:     type4:           {constant: symbolic inst+13, value_rep: {kind: pointer, type: type6}}
+// CHECK:STDOUT:     type4:           {constant: symbolic 1, value_rep: {kind: pointer, type: type6}}
 // CHECK:STDOUT:     type5:           {constant: template inst+17, value_rep: {kind: none, type: type2}}
 // CHECK:STDOUT:     type5:           {constant: template inst+17, value_rep: {kind: none, type: type2}}
-// CHECK:STDOUT:     type6:           {constant: symbolic inst+19, value_rep: {kind: copy, type: type6}}
+// CHECK:STDOUT:     type6:           {constant: symbolic 2, value_rep: {kind: copy, type: type6}}
 // CHECK:STDOUT:   type_blocks:
 // CHECK:STDOUT:   type_blocks:
 // CHECK:STDOUT:     type_block0:     {}
 // CHECK:STDOUT:     type_block0:     {}
 // CHECK:STDOUT:     type_block1:
 // CHECK:STDOUT:     type_block1:
@@ -84,19 +84,19 @@ fn Foo[T:! type](n: T) -> (T, ()) {
 // CHECK:STDOUT:     'inst+31':         {kind: ReturnExpr, arg0: inst+30, arg1: inst+15}
 // CHECK:STDOUT:     'inst+31':         {kind: ReturnExpr, arg0: inst+30, arg1: inst+15}
 // CHECK:STDOUT:   constant_values:
 // CHECK:STDOUT:   constant_values:
 // CHECK:STDOUT:     'inst+0':          template inst+0
 // CHECK:STDOUT:     'inst+0':          template inst+0
-// CHECK:STDOUT:     'inst+2':          symbolic inst+3
-// CHECK:STDOUT:     'inst+3':          symbolic inst+3
-// CHECK:STDOUT:     'inst+4':          symbolic inst+3
-// CHECK:STDOUT:     'inst+7':          symbolic inst+3
+// CHECK:STDOUT:     'inst+2':          symbolic 0
+// CHECK:STDOUT:     'inst+3':          symbolic 0
+// CHECK:STDOUT:     'inst+4':          symbolic 0
+// CHECK:STDOUT:     'inst+7':          symbolic 0
 // CHECK:STDOUT:     'inst+8':          template inst+8
 // CHECK:STDOUT:     'inst+8':          template inst+8
 // CHECK:STDOUT:     'inst+10':         template inst+10
 // CHECK:STDOUT:     'inst+10':         template inst+10
 // CHECK:STDOUT:     'inst+12':         template inst+8
 // CHECK:STDOUT:     'inst+12':         template inst+8
-// CHECK:STDOUT:     'inst+13':         symbolic inst+13
-// CHECK:STDOUT:     'inst+14':         symbolic inst+13
+// CHECK:STDOUT:     'inst+13':         symbolic 1
+// CHECK:STDOUT:     'inst+14':         symbolic 1
 // CHECK:STDOUT:     'inst+16':         template inst+18
 // CHECK:STDOUT:     'inst+16':         template inst+18
 // CHECK:STDOUT:     'inst+17':         template inst+17
 // CHECK:STDOUT:     'inst+17':         template inst+17
 // CHECK:STDOUT:     'inst+18':         template inst+18
 // CHECK:STDOUT:     'inst+18':         template inst+18
-// CHECK:STDOUT:     'inst+19':         symbolic inst+19
+// CHECK:STDOUT:     'inst+19':         symbolic 2
 // CHECK:STDOUT:     'inst+26':         template inst+27
 // CHECK:STDOUT:     'inst+26':         template inst+27
 // CHECK:STDOUT:     'inst+27':         template inst+27
 // CHECK:STDOUT:     'inst+27':         template inst+27
 // CHECK:STDOUT:     'inst+28':         template inst+27
 // CHECK:STDOUT:     'inst+28':         template inst+27

+ 13 - 3
toolchain/sem_ir/constant.cpp

@@ -12,9 +12,19 @@ auto ConstantStore::GetOrAdd(Inst inst, bool is_symbolic) -> ConstantId {
   auto [it, added] = map_.insert({inst, ConstantId::Invalid});
   auto [it, added] = map_.insert({inst, ConstantId::Invalid});
   if (added) {
   if (added) {
     auto inst_id = sem_ir_.insts().AddInNoBlock(LocIdAndInst::NoLoc(inst));
     auto inst_id = sem_ir_.insts().AddInNoBlock(LocIdAndInst::NoLoc(inst));
-    auto const_id = is_symbolic
-                        ? SemIR::ConstantId::ForSymbolicConstant(inst_id)
-                        : SemIR::ConstantId::ForTemplateConstant(inst_id);
+    ConstantId const_id = ConstantId::Invalid;
+    if (is_symbolic) {
+      // The instruction in the constants store is an abstract symbolic
+      // constant, not associated with any particular generic.
+      auto symbolic_constant =
+          SymbolicConstant{.inst_id = inst_id,
+                           .generic_id = GenericId::Invalid,
+                           .index = GenericInstIndex::Invalid};
+      const_id =
+          sem_ir_.constant_values().AddSymbolicConstant(symbolic_constant);
+    } else {
+      const_id = SemIR::ConstantId::ForTemplateConstant(inst_id);
+    }
     it->second = const_id;
     it->second = const_id;
     sem_ir_.constant_values().Set(inst_id, const_id);
     sem_ir_.constant_values().Set(inst_id, const_id);
     constants_.push_back(inst_id);
     constants_.push_back(inst_id);

+ 49 - 2
toolchain/sem_ir/constant.h

@@ -5,12 +5,25 @@
 #ifndef CARBON_TOOLCHAIN_SEM_IR_CONSTANT_H_
 #ifndef CARBON_TOOLCHAIN_SEM_IR_CONSTANT_H_
 #define CARBON_TOOLCHAIN_SEM_IR_CONSTANT_H_
 #define CARBON_TOOLCHAIN_SEM_IR_CONSTANT_H_
 
 
-#include "llvm/ADT/FoldingSet.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/inst.h"
 #include "toolchain/sem_ir/inst.h"
 
 
 namespace Carbon::SemIR {
 namespace Carbon::SemIR {
 
 
+// Information about a symbolic constant value. These are indexed by
+// `ConstantId`s for which `is_symbolic` is true.
+struct SymbolicConstant {
+  // The constant instruction that defines the value of this symbolic constant.
+  InstId inst_id;
+  // The enclosing generic. If this is invalid, then this is an abstract
+  // symbolic constant, such as a constant instruction in the constants block,
+  // rather than one associated with a particular generic.
+  GenericId generic_id;
+  // The index of this symbolic constant within the generic's list of symbolic
+  // constants, or invalid if `generic_id` is invalid.
+  GenericInstIndex index;
+};
+
 // Provides a ValueStore wrapper for tracking the constant values of
 // Provides a ValueStore wrapper for tracking the constant values of
 // instructions.
 // instructions.
 class ConstantValueStore {
 class ConstantValueStore {
@@ -40,7 +53,13 @@ class ConstantValueStore {
   // Gets the instruction ID that defines the value of the given constant.
   // Gets the instruction ID that defines the value of the given constant.
   // Returns Invalid if the constant ID is non-constant. Requires is_valid.
   // Returns Invalid if the constant ID is non-constant. Requires is_valid.
   auto GetInstId(ConstantId const_id) const -> InstId {
   auto GetInstId(ConstantId const_id) const -> InstId {
-    return const_id.inst_id();
+    if (const_id.is_template()) {
+      return const_id.template_inst_id();
+    }
+    if (const_id.is_symbolic()) {
+      return GetSymbolicConstant(const_id).inst_id;
+    }
+    return InstId::Invalid;
   }
   }
 
 
   // Gets the instruction ID that defines the value of the given constant.
   // Gets the instruction ID that defines the value of the given constant.
@@ -55,6 +74,27 @@ class ConstantValueStore {
     return GetInstId(Get(inst_id));
     return GetInstId(Get(inst_id));
   }
   }
 
 
+  // Returns whether two constant IDs represent the same constant value. This
+  // includes the case where they might be in different generics and thus might
+  // have different ConstantIds, but are still symbolically equal.
+  auto EqualAcrossDeclarations(ConstantId a, ConstantId b) const -> bool {
+    return GetInstId(a) == GetInstId(b);
+  }
+
+  auto AddSymbolicConstant(SymbolicConstant constant) -> ConstantId {
+    symbolic_constants_.push_back(constant);
+    return ConstantId::ForSymbolicConstantIndex(symbolic_constants_.size() - 1);
+  }
+
+  auto GetSymbolicConstant(ConstantId const_id) -> SymbolicConstant& {
+    return symbolic_constants_[const_id.symbolic_index()];
+  }
+
+  auto GetSymbolicConstant(ConstantId const_id) const
+      -> const SymbolicConstant& {
+    return symbolic_constants_[const_id.symbolic_index()];
+  }
+
   // Returns the constant values mapping as an ArrayRef whose keys are
   // Returns the constant values mapping as an ArrayRef whose keys are
   // instruction indexes. Some of the elements in this mapping may be Invalid or
   // instruction indexes. Some of the elements in this mapping may be Invalid or
   // NotConstant.
   // NotConstant.
@@ -70,6 +110,13 @@ class ConstantValueStore {
   // Set inline size to 0 because these will typically be too large for the
   // Set inline size to 0 because these will typically be too large for the
   // stack, while this does make File smaller.
   // stack, while this does make File smaller.
   llvm::SmallVector<ConstantId, 0> values_;
   llvm::SmallVector<ConstantId, 0> values_;
+
+  // A mapping from a symbolic constant ID index to information about the
+  // symbolic constant. For a template constant, the only information that we
+  // track is the instruction ID, which is stored directly within the
+  // `ConstantId`. For a symbolic constant, we also track information about
+  // where the constant was used, which is stored here.
+  llvm::SmallVector<SymbolicConstant, 0> symbolic_constants_;
 };
 };
 
 
 // Provides storage for instructions representing deduplicated global constants.
 // Provides storage for instructions representing deduplicated global constants.

+ 1 - 0
toolchain/sem_ir/formatter.cpp

@@ -452,6 +452,7 @@ class Formatter {
         out_ << " = ";
         out_ << " = ";
         FormatInstName(
         FormatInstName(
             sem_ir_.constant_values().GetInstId(pending_constant_value_));
             sem_ir_.constant_values().GetInstId(pending_constant_value_));
+        // TODO: For a symbolic constant, include the generic and index.
       }
       }
     } else {
     } else {
       out_ << pending_constant_value_;
       out_ << pending_constant_value_;

+ 83 - 18
toolchain/sem_ir/ids.h

@@ -95,6 +95,11 @@ constexpr InstId InstId::PackageNamespace = InstId(BuiltinKind::ValidCount);
 // - a symbolic constant, whose value includes a symbolic parameter, such as
 // - a symbolic constant, whose value includes a symbolic parameter, such as
 //   `Vector(T*)`, or
 //   `Vector(T*)`, or
 // - a runtime expression, such as `Print("hello")`.
 // - a runtime expression, such as `Print("hello")`.
+//
+// Template constants are a thin wrapper around the instruction ID of the
+// constant instruction that defines the constant. Symbolic constants are an
+// index into a separate table of `SymbolicConstant`s maintained by the constant
+// value store.
 struct ConstantId : public IdBase, public Printable<ConstantId> {
 struct ConstantId : public IdBase, public Printable<ConstantId> {
   // An ID for an expression that is not constant.
   // An ID for an expression that is not constant.
   static const ConstantId NotConstant;
   static const ConstantId NotConstant;
@@ -108,14 +113,13 @@ struct ConstantId : public IdBase, public Printable<ConstantId> {
   // either be in the `constants` block in the file or should be known to be
   // either be in the `constants` block in the file or should be known to be
   // unique.
   // unique.
   static constexpr auto ForTemplateConstant(InstId const_id) -> ConstantId {
   static constexpr auto ForTemplateConstant(InstId const_id) -> ConstantId {
-    return ConstantId(const_id.index + IndexOffset);
+    return ConstantId(const_id.index);
   }
   }
 
 
-  // Returns the constant ID corresponding to a symbolic constant, which should
-  // either be in the `constants` block in the file or should be known to be
-  // unique.
-  static constexpr auto ForSymbolicConstant(InstId const_id) -> ConstantId {
-    return ConstantId(-const_id.index - IndexOffset);
+  // Returns the constant ID corresponding to a symbolic constant index.
+  static constexpr auto ForSymbolicConstantIndex(int32_t symbolic_index)
+      -> ConstantId {
+    return ConstantId(FirstSymbolicIndex - symbolic_index);
   }
   }
 
 
   using IdBase::IdBase;
   using IdBase::IdBase;
@@ -128,21 +132,21 @@ struct ConstantId : public IdBase, public Printable<ConstantId> {
   // Returns whether this represents a symbolic constant. Requires is_valid.
   // Returns whether this represents a symbolic constant. Requires is_valid.
   auto is_symbolic() const -> bool {
   auto is_symbolic() const -> bool {
     CARBON_CHECK(is_valid());
     CARBON_CHECK(is_valid());
-    return index <= -IndexOffset;
+    return index <= FirstSymbolicIndex;
   }
   }
   // Returns whether this represents a template constant. Requires is_valid.
   // Returns whether this represents a template constant. Requires is_valid.
   auto is_template() const -> bool {
   auto is_template() const -> bool {
     CARBON_CHECK(is_valid());
     CARBON_CHECK(is_valid());
-    return index >= IndexOffset;
+    return index >= 0;
   }
   }
 
 
   auto Print(llvm::raw_ostream& out) const -> void {
   auto Print(llvm::raw_ostream& out) const -> void {
     if (!is_valid()) {
     if (!is_valid()) {
       IdBase::Print(out);
       IdBase::Print(out);
     } else if (is_template()) {
     } else if (is_template()) {
-      out << "template " << inst_id();
+      out << "template " << template_inst_id();
     } else if (is_symbolic()) {
     } else if (is_symbolic()) {
-      out << "symbolic " << inst_id();
+      out << "symbolic " << symbolic_index();
     } else {
     } else {
       out << "runtime";
       out << "runtime";
     }
     }
@@ -155,18 +159,23 @@ struct ConstantId : public IdBase, public Printable<ConstantId> {
   // logic here. LLVM should still optimize this.
   // logic here. LLVM should still optimize this.
   static constexpr auto Abs(int32_t i) -> int32_t { return i > 0 ? i : -i; }
   static constexpr auto Abs(int32_t i) -> int32_t { return i > 0 ? i : -i; }
 
 
-  // Returns the instruction that describes this constant value, or
-  // InstId::Invalid for a runtime value. This is not part of the public
-  // interface of `ConstantId`. Use `ConstantValueStore::GetInstId` to get the
+  // Returns the instruction that describes this template constant value.
+  // Requires `is_template()`. Use `ConstantValueStore::GetInstId` to get the
   // instruction ID of a `ConstantId`.
   // instruction ID of a `ConstantId`.
-  constexpr auto inst_id() const -> InstId {
-    CARBON_CHECK(is_valid());
-    return InstId(Abs(index) - IndexOffset);
+  constexpr auto template_inst_id() const -> InstId {
+    CARBON_CHECK(is_template());
+    return InstId(index);
+  }
+
+  // Returns the symbolic constant index that describes this symbolic constant
+  // value. Requires `is_symbolic()`.
+  constexpr auto symbolic_index() const -> int32_t {
+    CARBON_CHECK(is_symbolic());
+    return FirstSymbolicIndex - index;
   }
   }
 
 
   static constexpr int32_t NotConstantIndex = InvalidIndex - 1;
   static constexpr int32_t NotConstantIndex = InvalidIndex - 1;
-  // The offset of InstId indices to ConstantId indices.
-  static constexpr int32_t IndexOffset = -NotConstantIndex + 1;
+  static constexpr int32_t FirstSymbolicIndex = InvalidIndex - 2;
 };
 };
 
 
 constexpr ConstantId ConstantId::NotConstant = ConstantId(NotConstantIndex);
 constexpr ConstantId ConstantId::NotConstant = ConstantId(NotConstantIndex);
@@ -319,6 +328,62 @@ struct GenericInstanceId : public IdBase, public Printable<GenericInstanceId> {
 constexpr GenericInstanceId GenericInstanceId::Invalid =
 constexpr GenericInstanceId GenericInstanceId::Invalid =
     GenericInstanceId(InvalidIndex);
     GenericInstanceId(InvalidIndex);
 
 
+// The index, within a generic, of an instruction that depends on the generic's
+// parameters. The same index also identifies the value of that instruction
+// within each instance of the generic. This is a pair of a region and an index,
+// stored in 32 bits.
+struct GenericInstIndex : public IndexBase, public Printable<GenericInstIndex> {
+  // Where the value is first used within the generic.
+  enum Region : uint8_t {
+    // In the declaration.
+    Declaration,
+    // In the definition.
+    Definition,
+  };
+
+  // An explicitly invalid index.
+  static const GenericInstIndex Invalid;
+
+  explicit constexpr GenericInstIndex(Region region, int32_t index)
+      : IndexBase(region == Declaration ? index
+                                        : FirstDefinitionIndex - index) {
+    CARBON_CHECK(index >= 0);
+  }
+
+  // Returns the index of the instruction within the region.
+  auto index() const -> int32_t {
+    CARBON_CHECK(is_valid());
+    return IndexBase::index >= 0 ? IndexBase::index
+                                 : FirstDefinitionIndex - IndexBase::index;
+  }
+
+  // Returns the region within which this instruction was first used.
+  auto region() const -> Region {
+    CARBON_CHECK(is_valid());
+    return IndexBase::index >= 0 ? Declaration : Definition;
+  }
+
+  auto Print(llvm::raw_ostream& out) const -> void {
+    out << "genericInst";
+    if (is_valid()) {
+      out << (region() == Declaration ? "InDecl" : "InDef") << index();
+    } else {
+      out << "<invalid>";
+    }
+  }
+
+ private:
+  static constexpr auto MakeInvalid() -> GenericInstIndex {
+    GenericInstIndex result(Declaration, 0);
+    result.IndexBase::index = InvalidIndex;
+    return result;
+  }
+
+  static constexpr int32_t FirstDefinitionIndex = InvalidIndex - 1;
+};
+
+constexpr GenericInstIndex GenericInstIndex::Invalid =
+    GenericInstIndex::MakeInvalid();
+
 // The ID of an IR within the set of imported IRs, both direct and indirect.
 // The ID of an IR within the set of imported IRs, both direct and indirect.
 struct ImportIRId : public IdBase, public Printable<ImportIRId> {
 struct ImportIRId : public IdBase, public Printable<ImportIRId> {
   using ValueType = ImportIR;
   using ValueType = ImportIR;