Parcourir la source

Rework validation of require constraints (#7081)

Use constant values in code that can work with a canonical value
(`TypeStructureReferencesSelf`).

Give explicit location ids for the full require decl and the constraint
to `ValidateRequire`.

Go directly from `InstId` to `TypeId` in `ValidateRequire`.
Correctly/explicitly handle constraint instructions which are not types
instead of calling `SemIR::TypeId::ForTypeConstant` and hoping for the
best.

Leave a clear spot where we will subst `.Self` out of the contraint.
Dana Jansens il y a 1 semaine
Parent
commit
b8bc7bffd2
3 fichiers modifiés avec 54 ajouts et 33 suppressions
  1. 35 33
      toolchain/check/handle_require.cpp
  2. 5 0
      toolchain/sem_ir/type.cpp
  3. 14 0
      toolchain/sem_ir/type.h

+ 35 - 33
toolchain/check/handle_require.cpp

@@ -104,7 +104,7 @@ auto HandleParseNode(Context& context, Parse::RequireTypeImplsId node_id)
 }
 }
 
 
 static auto TypeStructureReferencesSelf(
 static auto TypeStructureReferencesSelf(
-    Context& context, SemIR::LocId loc_id, SemIR::TypeInstId inst_id,
+    Context& context, SemIR::LocId loc_id, SemIR::ConstantId const_id,
     const SemIR::IdentifiedFacetType& identified_facet_type) -> bool {
     const SemIR::IdentifiedFacetType& identified_facet_type) -> bool {
   auto find_self = [&](SemIR::TypeIterator& type_iter) -> bool {
   auto find_self = [&](SemIR::TypeIterator& type_iter) -> bool {
     while (true) {
     while (true) {
@@ -133,7 +133,7 @@ static auto TypeStructureReferencesSelf(
 
 
   {
   {
     SemIR::TypeIterator type_iter(&context.sem_ir());
     SemIR::TypeIterator type_iter(&context.sem_ir());
-    type_iter.Add(context.constant_values().GetConstantTypeInstId(inst_id));
+    type_iter.Add(context.constant_values().GetInstId(const_id));
     if (find_self(type_iter)) {
     if (find_self(type_iter)) {
       return true;
       return true;
     }
     }
@@ -172,40 +172,37 @@ static auto TypeStructureReferencesSelf(
 }
 }
 
 
 struct ValidateRequireResult {
 struct ValidateRequireResult {
-  // The TypeId of a FacetType.
-  SemIR::TypeId constraint_type_id;
   const SemIR::IdentifiedFacetType* identified_facet_type;
   const SemIR::IdentifiedFacetType* identified_facet_type;
 };
 };
 
 
 // Returns nullopt if a diagnostic has been emitted and the `require` decl is
 // Returns nullopt if a diagnostic has been emitted and the `require` decl is
 // not valid.
 // not valid.
-static auto ValidateRequire(Context& context, SemIR::LocId loc_id,
-                            SemIR::TypeInstId self_inst_id,
+static auto ValidateRequire(Context& context, SemIR::LocId full_require_loc_id,
+                            SemIR::LocId constraint_loc_id,
+                            SemIR::InstId self_inst_id,
                             SemIR::InstId constraint_inst_id,
                             SemIR::InstId constraint_inst_id,
                             SemIR::InstId scope_inst_id)
                             SemIR::InstId scope_inst_id)
     -> std::optional<ValidateRequireResult> {
     -> std::optional<ValidateRequireResult> {
-  auto self_constant_value_id = context.constant_values().Get(self_inst_id);
-  auto constraint_constant_value_id =
-      context.constant_values().Get(constraint_inst_id);
+  auto self_type_id = context.types().GetTypeIdForTypeInstId(self_inst_id);
+  auto constraint_type_id =
+      context.types().TryGetTypeIdForTypeInstId(constraint_inst_id);
 
 
-  if (self_constant_value_id == SemIR::ErrorInst::ConstantId ||
-      constraint_constant_value_id == SemIR::ErrorInst::ConstantId ||
+  if (self_type_id == SemIR::ErrorInst::TypeId ||
+      constraint_type_id == SemIR::ErrorInst::TypeId ||
       scope_inst_id == SemIR::ErrorInst::InstId) {
       scope_inst_id == SemIR::ErrorInst::InstId) {
     // An error was already diagnosed, don't diagnose another. We can't build a
     // An error was already diagnosed, don't diagnose another. We can't build a
     // useful `require` with an error, it couldn't do anything.
     // useful `require` with an error, it couldn't do anything.
     return std::nullopt;
     return std::nullopt;
   }
   }
 
 
-  auto constraint_type_id =
-      SemIR::TypeId::ForTypeConstant(constraint_constant_value_id);
   auto constraint_facet_type =
   auto constraint_facet_type =
-      context.types().TryGetAs<SemIR::FacetType>(constraint_type_id);
+      context.types().TryGetAsIfValid<SemIR::FacetType>(constraint_type_id);
   if (!constraint_facet_type) {
   if (!constraint_facet_type) {
     CARBON_DIAGNOSTIC(
     CARBON_DIAGNOSTIC(
         RequireImplsMissingFacetType, Error,
         RequireImplsMissingFacetType, Error,
         "`require` declaration constrained by a non-facet type; "
         "`require` declaration constrained by a non-facet type; "
         "expected an `interface` or `constraint` name after `impls`");
         "expected an `interface` or `constraint` name after `impls`");
-    context.emitter().Emit(constraint_inst_id, RequireImplsMissingFacetType);
+    context.emitter().Emit(constraint_loc_id, RequireImplsMissingFacetType);
     // Can't continue without a constraint to use.
     // Can't continue without a constraint to use.
     return std::nullopt;
     return std::nullopt;
   }
   }
@@ -218,7 +215,7 @@ static auto ValidateRequire(Context& context, SemIR::LocId loc_id,
     // TODO: Handle other impls named constraints for the
     // TODO: Handle other impls named constraints for the
     // RequireImplsReferenceCycle diagnostic.
     // RequireImplsReferenceCycle diagnostic.
     if (constraint_facet_type_info.other_requirements) {
     if (constraint_facet_type_info.other_requirements) {
-      context.TODO(constraint_inst_id,
+      context.TODO(constraint_loc_id,
                    "facet type has constraints that we don't handle yet");
                    "facet type has constraints that we don't handle yet");
       return std::nullopt;
       return std::nullopt;
     }
     }
@@ -237,7 +234,7 @@ static auto ValidateRequire(Context& context, SemIR::LocId loc_id,
                           "facet type in `require` declaration refers to the "
                           "facet type in `require` declaration refers to the "
                           "named constraint `{0}` from within its definition",
                           "named constraint `{0}` from within its definition",
                           SemIR::NameId);
                           SemIR::NameId);
-        context.emitter().Emit(constraint_inst_id, RequireImplsReferenceCycle,
+        context.emitter().Emit(constraint_loc_id, RequireImplsReferenceCycle,
                                named_constraint.name_id);
                                named_constraint.name_id);
         return std::nullopt;
         return std::nullopt;
       }
       }
@@ -245,14 +242,14 @@ static auto ValidateRequire(Context& context, SemIR::LocId loc_id,
   }
   }
 
 
   auto identified_facet_type_id = RequireIdentifiedFacetType(
   auto identified_facet_type_id = RequireIdentifiedFacetType(
-      context, SemIR::LocId(constraint_inst_id), self_constant_value_id,
+      context, constraint_loc_id, self_type_id.AsConstantId(),
       *constraint_facet_type, [&](auto& builder) {
       *constraint_facet_type, [&](auto& builder) {
         CARBON_DIAGNOSTIC(
         CARBON_DIAGNOSTIC(
             RequireImplsUnidentifiedFacetType, Context,
             RequireImplsUnidentifiedFacetType, Context,
             "facet type {0} cannot be identified in `require` declaration",
             "facet type {0} cannot be identified in `require` declaration",
-            InstIdAsType);
-        builder.Context(constraint_inst_id, RequireImplsUnidentifiedFacetType,
-                        constraint_inst_id);
+            SemIR::TypeId);
+        builder.Context(constraint_loc_id, RequireImplsUnidentifiedFacetType,
+                        constraint_type_id);
       });
       });
   if (!identified_facet_type_id.has_value()) {
   if (!identified_facet_type_id.has_value()) {
     // The constraint can't be used. A diagnostic was emitted by
     // The constraint can't be used. A diagnostic was emitted by
@@ -262,12 +259,12 @@ static auto ValidateRequire(Context& context, SemIR::LocId loc_id,
   const auto& identified =
   const auto& identified =
       context.identified_facet_types().Get(identified_facet_type_id);
       context.identified_facet_types().Get(identified_facet_type_id);
 
 
-  if (!TypeStructureReferencesSelf(context, loc_id, self_inst_id, identified)) {
+  if (!TypeStructureReferencesSelf(context, full_require_loc_id,
+                                   self_type_id.AsConstantId(), identified)) {
     return std::nullopt;
     return std::nullopt;
   }
   }
 
 
-  return ValidateRequireResult{.constraint_type_id = constraint_type_id,
-                               .identified_facet_type = &identified};
+  return ValidateRequireResult{.identified_facet_type = &identified};
 }
 }
 
 
 auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
 auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
@@ -287,8 +284,9 @@ auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
   auto scope_inst_id =
   auto scope_inst_id =
       context.node_stack().Pop<Parse::NodeKind::RequireIntroducer>();
       context.node_stack().Pop<Parse::NodeKind::RequireIntroducer>();
 
 
-  auto validated = ValidateRequire(context, node_id, self_inst_id,
-                                   constraint_inst_id, scope_inst_id);
+  auto validated =
+      ValidateRequire(context, node_id, constraint_node_id, self_inst_id,
+                      constraint_inst_id, scope_inst_id);
   if (!validated) {
   if (!validated) {
     // In an `extend` decl, errors get propagated into the parent scope just as
     // In an `extend` decl, errors get propagated into the parent scope just as
     // names do.
     // names do.
@@ -300,7 +298,7 @@ auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
     return true;
     return true;
   }
   }
 
 
-  auto [constraint_type_id, identified_facet_type] = *validated;
+  auto [identified_facet_type] = *validated;
   if (identified_facet_type->required_impls().empty()) {
   if (identified_facet_type->required_impls().empty()) {
     // A `require T impls type` adds no actual constraints, so nothing to do.
     // A `require T impls type` adds no actual constraints, so nothing to do.
     // This is not an error though.
     // This is not an error though.
@@ -308,6 +306,10 @@ auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
     return true;
     return true;
   }
   }
 
 
+  // TODO: Substitute .Self here.
+  auto constraint_type_inst_id =
+      context.types().GetAsTypeInstId(constraint_inst_id);
+
   auto require_impls_decl =
   auto require_impls_decl =
       SemIR::RequireImplsDecl{// To be filled in after.
       SemIR::RequireImplsDecl{// To be filled in after.
                               .require_impls_id = SemIR::RequireImplsId::None,
                               .require_impls_id = SemIR::RequireImplsId::None,
@@ -315,8 +317,7 @@ auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
   auto decl_id = AddPlaceholderInst(context, node_id, require_impls_decl);
   auto decl_id = AddPlaceholderInst(context, node_id, require_impls_decl);
   auto require_impls_id = context.require_impls().Add(
   auto require_impls_id = context.require_impls().Add(
       {.self_id = self_inst_id,
       {.self_id = self_inst_id,
-       .facet_type_inst_id =
-           context.types().GetAsTypeInstId(constraint_inst_id),
+       .facet_type_inst_id = constraint_type_inst_id,
        .extend_self = extend,
        .extend_self = extend,
        .decl_id = decl_id,
        .decl_id = decl_id,
        .parent_scope_id = context.scope_stack().PeekNameScopeId(),
        .parent_scope_id = context.scope_stack().PeekNameScopeId(),
@@ -333,14 +334,15 @@ auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
   // monomorphization errors that result.
   // monomorphization errors that result.
   if (extend) {
   if (extend) {
     if (!RequireCompleteType(
     if (!RequireCompleteType(
-            context, constraint_type_id, SemIR::LocId(constraint_inst_id),
-            [&](auto& builder) {
+            context,
+            context.types().GetTypeIdForTypeInstId(constraint_type_inst_id),
+            constraint_node_id, [&](auto& builder) {
               CARBON_DIAGNOSTIC(RequireImplsIncompleteFacetType, Context,
               CARBON_DIAGNOSTIC(RequireImplsIncompleteFacetType, Context,
                                 "`extend require` of incomplete facet type {0}",
                                 "`extend require` of incomplete facet type {0}",
                                 InstIdAsType);
                                 InstIdAsType);
-              builder.Context(constraint_inst_id,
+              builder.Context(constraint_node_id,
                               RequireImplsIncompleteFacetType,
                               RequireImplsIncompleteFacetType,
-                              constraint_inst_id);
+                              constraint_type_inst_id);
             })) {
             })) {
       return true;
       return true;
     }
     }

+ 5 - 0
toolchain/sem_ir/type.cpp

@@ -57,6 +57,11 @@ auto TypeStore::GetTypeIdForTypeInstId(TypeInstId inst_id) const -> TypeId {
   return TypeId::ForTypeConstant(constant_id);
   return TypeId::ForTypeConstant(constant_id);
 }
 }
 
 
+auto TypeStore::TryGetTypeIdForTypeInstId(InstId inst_id) const -> TypeId {
+  auto constant_id = file_->constant_values().Get(inst_id);
+  return TryGetTypeIdForTypeConstantId(constant_id);
+}
+
 auto TypeStore::GetAsTypeInstId(InstId inst_id) const -> TypeInstId {
 auto TypeStore::GetAsTypeInstId(InstId inst_id) const -> TypeInstId {
   auto constant_id = file_->constant_values().Get(inst_id);
   auto constant_id = file_->constant_values().Get(inst_id);
   CheckTypeOfConstantIsTypeType(*file_, constant_id);
   CheckTypeOfConstantIsTypeType(*file_, constant_id);

+ 14 - 0
toolchain/sem_ir/type.h

@@ -77,6 +77,10 @@ class TypeStore : public Yaml::Printable<TypeStore> {
   auto GetTypeIdForTypeInstId(InstId inst_id) const -> TypeId;
   auto GetTypeIdForTypeInstId(InstId inst_id) const -> TypeId;
   auto GetTypeIdForTypeInstId(TypeInstId inst_id) const -> TypeId;
   auto GetTypeIdForTypeInstId(TypeInstId inst_id) const -> TypeId;
 
 
+  // Like GetTypeIdForTypeInstId() but returns None if the constant is not a
+  // value of type `TypeType`.
+  auto TryGetTypeIdForTypeInstId(InstId inst_id) const -> TypeId;
+
   // Converts an `InstId` to a `TypeInstId` of the same id value. This process
   // Converts an `InstId` to a `TypeInstId` of the same id value. This process
   // involves checking that the type of the instruction's value is `TypeType`,
   // involves checking that the type of the instruction's value is `TypeType`,
   // and then this check is encoded in the type system via `TypeInstId`.
   // and then this check is encoded in the type system via `TypeInstId`.
@@ -137,6 +141,16 @@ class TypeStore : public Yaml::Printable<TypeStore> {
     return GetAsInst(type_id).TryAs<InstT>();
     return GetAsInst(type_id).TryAs<InstT>();
   }
   }
 
 
+  // Like TryGetAs() but also handles the case where `type_id` has no value, and
+  // then returns nullopt.
+  template <typename InstT>
+  auto TryGetAsIfValid(TypeId type_id) const -> std::optional<InstT> {
+    if (!type_id.has_value()) {
+      return {};
+    }
+    return GetAsInst(type_id).TryAs<InstT>();
+  }
+
   // Returns whether two type IDs represent the same type. This includes the
   // Returns whether two type IDs represent the same type. This includes the
   // case where they might be in different generics and thus might have
   // case where they might be in different generics and thus might have
   // different ConstantIds, but are still symbolically equal.
   // different ConstantIds, but are still symbolically equal.