Ver código fonte

Introduce typed-inst accessors for ConstantValueStore (#6980)

Add `InstIs`, `GetInstAs`, and `TryGetInstAs` which act on the
underlying constant instruction in a constant value, to save an explicit
call to `GetInstId`.

```carbon
context.insts().GetAs<InstT>(context.constant_values().GetInstId(const_id))
```
can now be written as simply
```carbon
context.constant_values().GetInstAs<InstT>(const_id)
```

For future work, we might provide `GetInst()` so that
`context.insts().Get(context.constant_values().GetInstId(const_id)` can
be shortened also.
Dana Jansens 1 mês atrás
pai
commit
5503f643c6

+ 2 - 4
toolchain/check/check_unit.cpp

@@ -541,10 +541,8 @@ auto CheckUnit::CheckPoisonedConcreteImplLookupQueries() -> void {
     }
     if (found_witness_id != poison.witness_id) {
       auto witness_to_impl_id = [&](SemIR::ConstantId witness_id) {
-        // TODO: Add and use constant_values().GetAs<SemIR::FacetType>().
-        auto inst_id = context_.constant_values().GetInstId(witness_id);
-        auto table_id = context_.insts()
-                            .GetAs<SemIR::ImplWitness>(inst_id)
+        auto table_id = context_.constant_values()
+                            .GetInstAs<SemIR::ImplWitness>(witness_id)
                             .witness_table_id;
         return context_.insts()
             .GetAs<SemIR::ImplWitnessTable>(table_id)

+ 3 - 2
toolchain/check/control_flow.cpp

@@ -103,8 +103,9 @@ auto SetBlockArgResultBeforeConstantUse(Context& context,
   auto cond_const_id = context.constant_values().Get(cond_id);
   if (!cond_const_id.is_concrete()) {
     // Symbolic or non-constant condition means a non-constant result.
-  } else if (auto literal = context.insts().TryGetAs<SemIR::BoolLiteral>(
-                 context.constant_values().GetInstId(cond_const_id))) {
+  } else if (auto literal =
+                 context.constant_values().TryGetInstAs<SemIR::BoolLiteral>(
+                     cond_const_id)) {
     const_id = context.constant_values().Get(
         literal.value().value.ToBool() ? if_true : if_false);
   } else {

+ 4 - 5
toolchain/check/convert.cpp

@@ -434,10 +434,9 @@ static auto ConvertTupleToType(Context& context, SemIR::LocId loc_id,
 
   llvm::SmallVector<SemIR::InstId> type_inst_ids;
 
-  auto value_const_inst_id =
-      context.constant_values().GetInstId(value_const_id);
   if (auto tuple_value =
-          context.insts().TryGetAs<SemIR::TupleValue>(value_const_inst_id)) {
+          context.constant_values().TryGetInstAs<SemIR::TupleValue>(
+              value_const_id)) {
     for (auto tuple_inst_id :
          context.inst_blocks().Get(tuple_value->elements_id)) {
       // TODO: This call recurses back into conversion. Switch to an
@@ -448,8 +447,8 @@ static auto ConvertTupleToType(Context& context, SemIR::LocId loc_id,
   } else {
     // A value of type TupleType that isn't a TupleValue must be a symbolic
     // binding.
-    CARBON_CHECK(
-        context.insts().Is<SemIR::SymbolicBinding>(value_const_inst_id));
+    CARBON_CHECK(context.constant_values().InstIs<SemIR::SymbolicBinding>(
+        value_const_id));
     // Form a TupleAccess for each element in the symbolic value, which is then
     // converted to a `type` or diagnosed as an error.
     auto tuple_type = context.types().GetAs<SemIR::TupleType>(value_type_id);

+ 2 - 2
toolchain/check/cpp/impl_lookup.cpp

@@ -26,8 +26,8 @@ namespace Carbon::Check {
 static auto TypeAsClassDecl(Context& context,
                             SemIR::ConstantId query_self_const_id)
     -> clang::CXXRecordDecl* {
-  auto self_inst_id = context.constant_values().GetInstId(query_self_const_id);
-  auto class_type = context.insts().TryGetAs<SemIR::ClassType>(self_inst_id);
+  auto class_type = context.constant_values().TryGetInstAs<SemIR::ClassType>(
+      query_self_const_id);
   if (!class_type) {
     // Not a class.
     return nullptr;

+ 2 - 2
toolchain/check/cpp/type_mapping.cpp

@@ -65,8 +65,8 @@ static auto FindIntLiteralBitWidth(Context& context, SemIR::LocId loc_id,
     // TODO: Add tests for these cases.
     return IntId::None;
   }
-  auto arg = context.insts().TryGetAs<SemIR::IntValue>(
-      context.constant_values().GetInstId(arg_const_id));
+  auto arg =
+      context.constant_values().TryGetInstAs<SemIR::IntValue>(arg_const_id);
   if (!arg) {
     return IntId::None;
   }

+ 2 - 2
toolchain/check/eval_inst.cpp

@@ -727,8 +727,8 @@ auto EvalConstantInst(Context& context, SemIR::UnaryOperatorNot inst)
   // All other uses of unary `not` are non-constant.
   auto const_id = context.constant_values().Get(inst.operand_id);
   if (const_id.is_concrete()) {
-    auto value = context.insts().GetAs<SemIR::BoolLiteral>(
-        context.constant_values().GetInstId(const_id));
+    auto value =
+        context.constant_values().GetInstAs<SemIR::BoolLiteral>(const_id);
     value.value = SemIR::BoolValue::From(!value.value.ToBool());
     return ConstantEvalResult::NewSamePhase(value);
   }

+ 3 - 6
toolchain/check/impl.cpp

@@ -729,11 +729,8 @@ auto CheckRequireDeclsSatisfied(Context& context, SemIR::LocId loc_id,
     // requires LookupImplWitness to return a partial result, or take a
     // diagnostic lambda or something.
     if (!result.has_value()) {
-      auto facet_type_inst_id =
-          context.constant_values().GetInstId(facet_type_const_id);
-
       if (!result.has_error_value() &&
-          facet_type_inst_id != SemIR::ErrorInst::InstId) {
+          facet_type_const_id != SemIR::ErrorInst::ConstantId) {
         CARBON_DIAGNOSTIC(RequireImplsNotImplemented, Error,
                           "interface `{0}` being implemented requires that {1} "
                           "implements {2}",
@@ -742,8 +739,8 @@ auto CheckRequireDeclsSatisfied(Context& context, SemIR::LocId loc_id,
         context.emitter().Emit(
             loc_id, RequireImplsNotImplemented, impl.interface,
             context.types().GetTypeIdForTypeConstantId(self_const_id),
-            context.insts()
-                .GetAs<SemIR::FacetType>(facet_type_inst_id)
+            context.constant_values()
+                .GetInstAs<SemIR::FacetType>(facet_type_const_id)
                 .facet_type_id);
       }
     }

+ 11 - 20
toolchain/check/impl_lookup.cpp

@@ -288,16 +288,15 @@ static auto TryGetSpecificWitnessIdForImpl(
   // type: the `I` in `impl ... as I`. The deduction step may be unable to be
   // fully applied to the types in the constraint and result in an error here,
   // in which case it does not match the query.
-  auto deduced_constraint_id =
-      context.constant_values().GetInstId(SemIR::GetConstantValueInSpecific(
-          context.sem_ir(), specific_id, impl.constraint_id));
-  if (deduced_constraint_id == SemIR::ErrorInst::InstId) {
+  auto deduced_constraint_id = SemIR::GetConstantValueInSpecific(
+      context.sem_ir(), specific_id, impl.constraint_id);
+  if (deduced_constraint_id == SemIR::ErrorInst::ConstantId) {
     return SemIR::ConstantId::None;
   }
 
   auto deduced_constraint_facet_type_id =
-      context.insts()
-          .GetAs<SemIR::FacetType>(deduced_constraint_id)
+      context.constant_values()
+          .GetInstAs<SemIR::FacetType>(deduced_constraint_id)
           .facet_type_id;
   const auto& deduced_constraint_facet_type_info =
       context.facet_types().Get(deduced_constraint_facet_type_id);
@@ -498,14 +497,9 @@ static auto VerifyQueryFacetTypeConstraints(
     SemIR::ConstantId query_facet_type_const_id,
     llvm::ArrayRef<SemIR::IdentifiedFacetType::RequiredImpl> req_impls,
     llvm::ArrayRef<SemIR::InstId> witness_inst_ids) -> bool {
-  SemIR::InstId query_facet_type_inst_id =
-      context.constant_values().GetInstId(query_facet_type_const_id);
-
-  CARBON_CHECK(context.insts().Is<SemIR::FacetType>(query_facet_type_inst_id));
-
   const auto& facet_type_info = context.facet_types().Get(
-      context.insts()
-          .GetAs<SemIR::FacetType>(query_facet_type_inst_id)
+      context.constant_values()
+          .GetInstAs<SemIR::FacetType>(query_facet_type_const_id)
           .facet_type_id);
 
   if (!facet_type_info.rewrite_constraints.empty()) {
@@ -803,11 +797,8 @@ static auto FindFinalWitnessFromSelfFacetValue(
     Context& context, SemIR::ConstantId query_self_const_id,
     SemIR::IdentifiedFacetTypeId query_self_type_identified_id,
     SemIR::SpecificInterface query_specific_interface) -> SemIR::InstId {
-  // TODO: Add and use constant_values().GetAs<SemIR::FacetType>().
-  auto query_self_inst_id =
-      context.constant_values().GetInstId(query_self_const_id);
-  auto facet_value =
-      context.insts().TryGetAs<SemIR::FacetValue>(query_self_inst_id);
+  auto facet_value = context.constant_values().TryGetInstAs<SemIR::FacetValue>(
+      query_self_const_id);
   if (!facet_value) {
     return SemIR::InstId::None;
   }
@@ -912,8 +903,8 @@ auto LookupImplWitness(Context& context, SemIR::LocId loc_id,
     CARBON_CHECK((context.types().IsOneOf<SemIR::TypeType, SemIR::FacetType>(
         query_self_type_id)));
     // The query facet type value is indeed a facet type.
-    CARBON_CHECK(context.insts().Is<SemIR::FacetType>(
-        context.constant_values().GetInstId(query_facet_type_const_id)));
+    CARBON_CHECK(context.constant_values().InstIs<SemIR::FacetType>(
+        query_facet_type_const_id));
   }
 
   auto req_impls_from_constraint =

+ 25 - 26
toolchain/check/import_ref.cpp

@@ -1633,11 +1633,9 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
   } else {
     // In the third phase, compute the associated constant ID from the constant
     // value of the declaration.
-    assoc_const_id =
-        resolver.local_insts()
-            .GetAs<SemIR::AssociatedConstantDecl>(
-                resolver.local_constant_values().GetInstId(const_id))
-            .assoc_const_id;
+    assoc_const_id = resolver.local_constant_values()
+                         .GetInstAs<SemIR::AssociatedConstantDecl>(const_id)
+                         .assoc_const_id;
   }
 
   // Load the values to populate the entity with.
@@ -2634,8 +2632,9 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
   } else {
     // On the third phase, compute the impl ID from the "constant value" of
     // the declaration, which is a reference to the created ImplDecl.
-    auto impl_const_inst = resolver.local_insts().GetAs<SemIR::ImplDecl>(
-        resolver.local_constant_values().GetInstId(impl_const_id));
+    auto impl_const_inst =
+        resolver.local_constant_values().GetInstAs<SemIR::ImplDecl>(
+            impl_const_id);
     impl_id = impl_const_inst.impl_id;
   }
 
@@ -2669,8 +2668,9 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
       resolver, import_impl.interface, specific_interface_data);
   // Create a local IdentifiedFacetType for the imported facet type, since impl
   // declarations always identify the facet type.
-  if (auto facet_type = resolver.local_insts().TryGetAs<SemIR::FacetType>(
-          resolver.local_constant_values().GetInstId(constraint_const_id))) {
+  if (auto facet_type =
+          resolver.local_constant_values().TryGetInstAs<SemIR::FacetType>(
+              constraint_const_id)) {
     // Lookups later will be with the unattached constant, whereas
     // GetLocalConstantId gave us an attached constant.
     auto unattached_self_const_id =
@@ -2968,17 +2968,17 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
     // Get the local interface ID from the constant value of the interface decl,
     // which is either a GenericInterfaceType (if generic) or a FacetType (if
     // not).
-    auto interface_const_inst_id =
-        resolver.local_constant_values().GetInstId(interface_const_id);
-    if (auto struct_value = resolver.local_insts().TryGetAs<SemIR::StructValue>(
-            interface_const_inst_id)) {
+    if (auto struct_value =
+            resolver.local_constant_values().TryGetInstAs<SemIR::StructValue>(
+                interface_const_id)) {
       auto generic_interface_type =
           resolver.local_types().GetAs<SemIR::GenericInterfaceType>(
               struct_value->type_id);
       local_interface_id = generic_interface_type.interface_id;
     } else {
-      auto local_facet_type = resolver.local_insts().GetAs<SemIR::FacetType>(
-          interface_const_inst_id);
+      auto local_facet_type =
+          resolver.local_constant_values().GetInstAs<SemIR::FacetType>(
+              interface_const_id);
       const auto& local_facet_type_info =
           resolver.local_facet_types().Get(local_facet_type.facet_type_id);
       auto single_interface = *local_facet_type_info.TryAsSingleExtend();
@@ -3014,8 +3014,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
   } else {
     // On the third phase, get the interface, decl and generic IDs from the
     // constant value of the decl (which is itself) from the second phase.
-    auto decl = resolver.local_insts().GetAs<SemIR::InterfaceWithSelfDecl>(
-        resolver.local_constant_values().GetInstId(const_id));
+    auto decl = resolver.local_constant_values()
+                    .GetInstAs<SemIR::InterfaceWithSelfDecl>(const_id);
     local_interface_id = decl.interface_id;
     generic_with_self_id = resolver.local_interfaces()
                                .Get(local_interface_id)
@@ -3261,17 +3261,17 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
     // Get the local named constraint ID from the constant value of the named
     // constraint decl, which is either a GenericNamedConstraintType (if
     // generic) or a FacetType (if not).
-    auto constraint_const_inst_id =
-        resolver.local_constant_values().GetInstId(constraint_const_id);
-    if (auto struct_value = resolver.local_insts().TryGetAs<SemIR::StructValue>(
-            constraint_const_inst_id)) {
+    if (auto struct_value =
+            resolver.local_constant_values().TryGetInstAs<SemIR::StructValue>(
+                constraint_const_id)) {
       auto generic_constraint_type =
           resolver.local_types().GetAs<SemIR::GenericNamedConstraintType>(
               struct_value->type_id);
       local_constraint_id = generic_constraint_type.named_constraint_id;
     } else {
-      auto local_facet_type = resolver.local_insts().GetAs<SemIR::FacetType>(
-          constraint_const_inst_id);
+      auto local_facet_type =
+          resolver.local_constant_values().GetInstAs<SemIR::FacetType>(
+              constraint_const_id);
       const auto& local_facet_type_info =
           resolver.local_facet_types().Get(local_facet_type.facet_type_id);
       auto single_interface = *local_facet_type_info.TryAsSingleExtend();
@@ -3308,9 +3308,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
   } else {
     // On the third phase, get the interface, decl and generic IDs from the
     // constant value of the decl (which is itself) from the second phase.
-    auto decl =
-        resolver.local_insts().GetAs<SemIR::NamedConstraintWithSelfDecl>(
-            resolver.local_constant_values().GetInstId(const_id));
+    auto decl = resolver.local_constant_values()
+                    .GetInstAs<SemIR::NamedConstraintWithSelfDecl>(const_id);
     local_constraint_id = decl.named_constraint_id;
     generic_with_self_id = resolver.local_named_constraints()
                                .Get(local_constraint_id)

+ 9 - 9
toolchain/check/member_access.cpp

@@ -120,8 +120,9 @@ auto GetHighestAllowedAccess(Context& context, SemIR::LocId loc_id,
   auto self_class_info = context.classes().Get(self_class_type->class_id);
 
   // TODO: Support other types.
-  if (auto class_type = context.insts().TryGetAs<SemIR::ClassType>(
-          context.constant_values().GetInstId(name_scope_const_id))) {
+  if (auto class_type =
+          context.constant_values().TryGetInstAs<SemIR::ClassType>(
+              name_scope_const_id)) {
     auto class_info = context.classes().Get(class_type->class_id);
 
     if (self_class_info.self_type_id == class_info.self_type_id) {
@@ -725,10 +726,9 @@ auto GetAssociatedValue(Context& context, SemIR::LocId loc_id,
   // TODO: This function shares a code with PerformCompoundMemberAccess(),
   // it would be nice to reduce the duplication.
 
-  auto value_inst_id =
-      context.constant_values().GetInstId(assoc_entity_const_id);
   auto assoc_entity =
-      context.insts().GetAs<SemIR::AssociatedEntity>(value_inst_id);
+      context.constant_values().GetInstAs<SemIR::AssociatedEntity>(
+          assoc_entity_const_id);
   auto decl_id = assoc_entity.decl_id;
   LoadImportRef(context, decl_id);
 
@@ -794,8 +794,8 @@ auto PerformCompoundMemberAccess(
       member.type_id() != SemIR::ErrorInst::TypeId) {
     // As a special case, an integer-valued expression can be used as a member
     // name when indexing a tuple.
-    if (context.insts().Is<SemIR::TupleType>(
-            context.constant_values().GetInstId(base_type_const_id))) {
+    if (context.constant_values().InstIs<SemIR::TupleType>(
+            base_type_const_id)) {
       return PerformTupleAccess(context, loc_id, base_id, member_expr_id);
     }
 
@@ -850,8 +850,8 @@ auto PerformTupleAccess(Context& context, SemIR::LocId loc_id,
     return diag_non_constant_index();
   }
 
-  auto index_literal = context.insts().GetAs<SemIR::IntValue>(
-      context.constant_values().GetInstId(index_const_id));
+  auto index_literal =
+      context.constant_values().GetInstAs<SemIR::IntValue>(index_const_id);
   auto type_block = context.inst_blocks().Get(tuple_type->type_elements_id);
   std::optional<llvm::APInt> index_val = ValidateTupleIndex(
       context, loc_id, tuple_inst_id, index_literal, type_block.size());

+ 3 - 4
toolchain/check/name_lookup.cpp

@@ -214,8 +214,8 @@ static auto DiagnoseInvalidQualifiedNameAccess(
     Context& context, SemIR::LocId loc_id, SemIR::LocId member_loc_id,
     SemIR::NameId name_id, SemIR::AccessKind access_kind, bool is_parent_access,
     AccessInfo access_info) -> void {
-  auto class_type = context.insts().TryGetAs<SemIR::ClassType>(
-      context.constant_values().GetInstId(access_info.constant_id));
+  auto class_type = context.constant_values().TryGetInstAs<SemIR::ClassType>(
+      access_info.constant_id);
   if (!class_type) {
     return;
   }
@@ -307,8 +307,7 @@ static auto GetSelfFacetForInterfaceFromLookupSelfType(
     return context.constant_values().Get(self_specific_args.back());
   }
 
-  if (context.insts().Is<SemIR::FacetType>(
-          context.constant_values().GetInstId(self_type_const_id))) {
+  if (context.constant_values().InstIs<SemIR::FacetType>(self_type_const_id)) {
     // We are looking directly in a facet type, like `I.F` for an interface `I`,
     // which means there is no self-type from the lookup for the
     // interface-with-self specific. So the self-type we use is the abstract

+ 2 - 3
toolchain/check/pattern_match.cpp

@@ -488,9 +488,8 @@ auto MatchContext::DoPreWork(Context& context,
     case SemIR::FormParamPattern::Kind: {
       auto form_param_pattern =
           context.insts().GetAs<SemIR::FormParamPattern>(entry.pattern_id);
-      auto form_inst_id =
-          context.constant_values().GetInstId(form_param_pattern.form_id);
-      if (!context.insts().Is<SemIR::InitForm>(form_inst_id)) {
+      if (!context.constant_values().InstIs<SemIR::InitForm>(
+              form_param_pattern.form_id)) {
         break;
       }
       [[fallthrough]];

+ 8 - 12
toolchain/check/type_completion.cpp

@@ -976,12 +976,10 @@ static auto IdentifyFacetType(Context& context, SemIR::LocId loc_id,
           return SemIR::IdentifiedFacetTypeId::None;
         }
 
-        // TODO: Add and use constant_values().GetAs<SemIR::FacetType>().
-        auto facet_type_inst_id =
-            context.constant_values().GetInstId(require_facet_type);
-        auto facet_type_id = context.insts()
-                                 .GetAs<SemIR::FacetType>(facet_type_inst_id)
-                                 .facet_type_id;
+        auto facet_type_id =
+            context.constant_values()
+                .GetInstAs<SemIR::FacetType>(require_facet_type)
+                .facet_type_id;
         bool extend = facet_type_extends && require.extend_self;
         work.push_back({extend, require_self, facet_type_id});
       }
@@ -1035,12 +1033,10 @@ static auto IdentifyFacetType(Context& context, SemIR::LocId loc_id,
           return SemIR::IdentifiedFacetTypeId::None;
         }
 
-        // TODO: Add and use constant_values().GetAs<SemIR::FacetType>().
-        auto facet_type_inst_id =
-            context.constant_values().GetInstId(require_facet_type);
-        auto facet_type_id = context.insts()
-                                 .GetAs<SemIR::FacetType>(facet_type_inst_id)
-                                 .facet_type_id;
+        auto facet_type_id =
+            context.constant_values()
+                .GetInstAs<SemIR::FacetType>(require_facet_type)
+                .facet_type_id;
         work.push_back({false, require_self, facet_type_id});
       }
     }

+ 2 - 3
toolchain/sem_ir/class.cpp

@@ -45,9 +45,8 @@ auto Class::GetObjectRepr(const File& file, SpecificId specific_id) const
     return ErrorInst::TypeId;
   }
   return file.types().GetTypeIdForTypeInstId(
-      file.insts()
-          .GetAs<CompleteTypeWitness>(
-              file.constant_values().GetInstId(witness_id))
+      file.constant_values()
+          .GetInstAs<CompleteTypeWitness>(witness_id)
           .object_repr_type_inst_id);
 }
 

+ 21 - 0
toolchain/sem_ir/constant.h

@@ -181,6 +181,27 @@ class ConstantValueStore {
     return const_id.has_value() ? GetInstId(const_id) : InstId::None;
   }
 
+  // Returns whether the underlying constant inst for the given constant is the
+  // specified type.
+  template <typename InstT>
+  auto InstIs(ConstantId const_id) const -> bool {
+    return insts_->Is<InstT>(GetInstId(const_id));
+  }
+
+  // Returns the requested instruction from the underlying constant inst, which
+  // is known to have the specified type.
+  template <typename InstT>
+  auto GetInstAs(ConstantId const_id) const -> InstT {
+    return insts_->GetAs<InstT>(GetInstId(const_id));
+  }
+
+  // Returns the requested instruction from the underlying constant inst as the
+  // specified type, if it is of the that type.
+  template <typename InstT>
+  auto TryGetInstAs(ConstantId const_id) const -> std::optional<InstT> {
+    return insts_->TryGetAs<InstT>(GetInstId(const_id));
+  }
+
   // Given an instruction, returns the unique constant instruction that is
   // equivalent to it. Returns `None` for a non-constant instruction.
   auto GetConstantInstId(InstId inst_id) const -> InstId {