Просмотр исходного кода

Refactor WhereExpr evaluation into smaller helper functions (#7006)

This splits off the functionality to handle the base facet type,
rewrites, and impls constraints into separate functions.

We use the Context instead of EvalContext throughout, as the goal is to
move this code to EvalConstantInst in time. That means we do not apply
specifics to the functions in the requirements inst block. That is fine
because WhereExpr never evaluates to an WhereExpr, so this instruction
never survives as a constant value long enough to be re-evaluated with a
specific applied to it.
Dana Jansens 3 недель назад
Родитель
Сommit
f483a28f2f
2 измененных файлов с 236 добавлено и 94 удалено
  1. 165 92
      toolchain/check/eval.cpp
  2. 71 2
      toolchain/check/testdata/eval/unexpected_runtime.carbon

+ 165 - 92
toolchain/check/eval.cpp

@@ -439,14 +439,11 @@ static auto GetConstantValue(EvalContext& eval_context, SemIR::InstId inst_id,
 
 // Issue a suitable diagnostic for an instruction that evaluated to a
 // non-constant value but was required to evaluate to a constant.
-static auto DiagnoseNonConstantValue(EvalContext& eval_context,
-                                     SemIR::InstId inst_id) -> void {
-  if (inst_id != SemIR::ErrorInst::InstId) {
-    CARBON_DIAGNOSTIC(EvalRequiresConstantValue, Error,
-                      "expression is runtime; expected constant");
-    eval_context.emitter().Emit(eval_context.GetDiagnosticLoc({inst_id}),
-                                EvalRequiresConstantValue);
-  }
+static auto DiagnoseNonConstantValue(Context& context, SemIR::LocId loc_id)
+    -> void {
+  CARBON_DIAGNOSTIC(EvalRequiresConstantValue, Error,
+                    "expression is runtime; expected constant");
+  context.emitter().Emit(loc_id, EvalRequiresConstantValue);
 }
 
 // Gets a constant value for an `inst_id`, diagnosing when the input is not a
@@ -457,6 +454,11 @@ static auto RequireConstantValue(EvalContext& eval_context,
   if (!inst_id.has_value()) {
     return SemIR::InstId::None;
   }
+  if (inst_id == SemIR::ErrorInst::InstId) {
+    *phase = Phase::UnknownDueToError;
+    return SemIR::ErrorInst::InstId;
+  }
+
   auto const_id = eval_context.GetConstantValue(inst_id);
   *phase =
       LatestPhase(*phase, GetPhase(eval_context.constant_values(), const_id));
@@ -464,7 +466,8 @@ static auto RequireConstantValue(EvalContext& eval_context,
     return eval_context.constant_values().GetInstId(const_id);
   }
 
-  DiagnoseNonConstantValue(eval_context, inst_id);
+  DiagnoseNonConstantValue(eval_context.context(),
+                           eval_context.GetDiagnosticLoc({inst_id}));
   *phase = Phase::UnknownDueToError;
   return SemIR::ErrorInst::InstId;
 }
@@ -2547,6 +2550,124 @@ static auto IsSameFacetValue(Context& context, SemIR::ConstantId const_id,
   return canon_const_id == context.constant_values().Get(facet_value_inst_id);
 }
 
+static auto AddRequirementBase(Context& context,
+                               SemIR::RequirementBaseFacetType base,
+                               SemIR::FacetTypeInfo* info, Phase* phase)
+    -> void {
+  auto base_type_inst_id =
+      context.constant_values().GetConstantTypeInstId(base.base_type_inst_id);
+  if (base_type_inst_id == SemIR::ErrorInst::TypeInstId) {
+    *phase = Phase::UnknownDueToError;
+    return;
+  }
+
+  if (auto base_facet_type =
+          context.insts().TryGetAs<SemIR::FacetType>(base_type_inst_id)) {
+    const auto& base_info =
+        context.facet_types().Get(base_facet_type->facet_type_id);
+    info->extend_constraints.append(base_info.extend_constraints);
+    info->self_impls_constraints.append(base_info.self_impls_constraints);
+    info->type_impls_interfaces.append(base_info.type_impls_interfaces);
+    info->type_impls_named_constraints.append(
+        base_info.type_impls_named_constraints);
+    info->rewrite_constraints.append(base_info.rewrite_constraints);
+    info->other_requirements |= base_info.other_requirements;
+  }
+}
+
+static auto AddRequirementRewrite(Context& context,
+                                  SemIR::RequirementRewrite rewrite,
+                                  SemIR::FacetTypeInfo* info, Phase* phase)
+    -> void {
+  auto lhs_id = context.constant_values().GetConstantInstId(rewrite.lhs_id);
+  auto rhs_id = context.constant_values().GetConstantInstId(rewrite.rhs_id);
+  if (lhs_id == SemIR::ErrorInst::InstId ||
+      rhs_id == SemIR::ErrorInst::InstId) {
+    *phase = Phase::UnknownDueToError;
+    return;
+  }
+  if (!rhs_id.has_value()) {
+    // The RHS may be an arbitrary expression, which means it could have a
+    // runtime value, which we reject since we can't evaluate that.
+    DiagnoseNonConstantValue(context, SemIR::LocId(rewrite.rhs_id));
+    *phase = Phase::UnknownDueToError;
+    return;
+  }
+
+  // The FacetTypeInfo must hold canonical IDs for constant comparison, yet here
+  // we must insert the non-canonical IDs:
+  // * Rewrite constraints are resolved once the FacetTypeInfo is fully
+  //   constructed in order to produce the constant value of the facet type.
+  //   That resolution step needs the non-canonical insts to do its job
+  //   correctly. For instance, the LHS may be a `ImplWitnessAccessSubstituted`
+  //   instruction which preserves which element in the witness is being
+  //   assigned to but evaluates to the RHS of some other rewrite. So the
+  //   constant value would be incorrect to use.
+  // * We use the id of the non-canonical RHS instruction as a hint to order
+  //   diagnostics in the resolution of rewrites, so that they can usually refer
+  //   to the rewrites in the same order as they are written in the code. Using
+  //   the constant value of the RHS reorders the diagnostics in a worse way.
+  // * The final step of constructing the facet type from the WhereExpr
+  //   canonicalizes all the instructions, so we don't need to store canonical
+  //   values here. We only need to use canonical values if we need to observe
+  //   the constant value, such as to determine in the RHS has a runtime value
+  //   above.
+  info->rewrite_constraints.push_back(
+      {.lhs_id = rewrite.lhs_id, .rhs_id = rewrite.rhs_id});
+}
+
+static auto AddRequirementImpls(Context& context, SemIR::RequirementImpls impls,
+                                SemIR::InstId period_self_id,
+                                SemIR::FacetTypeInfo* info, Phase* phase)
+    -> void {
+  auto lhs_id = context.constant_values().GetConstantInstId(impls.lhs_id);
+  auto rhs_id = context.constant_values().GetConstantInstId(impls.rhs_id);
+  if (lhs_id == SemIR::ErrorInst::InstId ||
+      rhs_id == SemIR::ErrorInst::InstId) {
+    *phase = Phase::UnknownDueToError;
+    return;
+  }
+
+  if (rhs_id == SemIR::TypeType::TypeInstId) {
+    // `<type> impls type` -> nothing to do.
+    return;
+  }
+
+  if (IsSameFacetValue(context, context.constant_values().Get(lhs_id),
+                       period_self_id)) {
+    auto facet_type = context.insts().GetAs<SemIR::FacetType>(rhs_id);
+    const auto& more_info = context.facet_types().Get(facet_type.facet_type_id);
+    // The way to prevent lookup into the interface requirements of a
+    // facet type is to put it to the right of a `.Self impls`, which we
+    // accomplish by putting them into `self_impls_constraints`.
+    llvm::append_range(info->self_impls_constraints,
+                       more_info.extend_constraints);
+    llvm::append_range(info->self_impls_constraints,
+                       more_info.self_impls_constraints);
+    llvm::append_range(info->self_impls_named_constraints,
+                       more_info.extend_named_constraints);
+    llvm::append_range(info->self_impls_named_constraints,
+                       more_info.self_impls_named_constraints);
+    // If `.Self impls Z` and Z implies `C impls Y`, then the facet type
+    // of `.Self` also knows `C impls Y`.
+    llvm::append_range(info->type_impls_interfaces,
+                       more_info.type_impls_interfaces);
+    llvm::append_range(info->type_impls_named_constraints,
+                       more_info.type_impls_named_constraints);
+    // Other requirements are copied in.
+    llvm::append_range(info->rewrite_constraints,
+                       more_info.rewrite_constraints);
+    info->other_requirements |= more_info.other_requirements;
+    return;
+  }
+
+  // TODO: Handle `impls` constraints beyond `.Self impls`.
+  info->other_requirements = true;
+}
+
+// Add the constraints from the WhereExpr instruction into a FacetTypeInfo in
+// order to construct a FacetType constant value.
+//
 // TODO: Convert this to an EvalConstantInst function. This will require
 // providing a `GetConstantValue` overload for a requirement block.
 template <>
@@ -2562,88 +2683,39 @@ auto TryEvalTypedInst<SemIR::WhereExpr>(EvalContext& eval_context,
     return SemIR::ErrorInst::ConstantId;
   }
 
-  // Add the constraints from the `WhereExpr` instruction into `info`.
-  if (typed_inst.requirements_id.has_value()) {
-    auto insts = eval_context.inst_blocks().Get(typed_inst.requirements_id);
-    // Note that these requirement instructions don't have a constant value, but
-    // they contain only canonical instructions.
-    for (auto inst_id : insts) {
-      if (auto base =
-              eval_context.insts().TryGetAs<SemIR::RequirementBaseFacetType>(
-                  inst_id)) {
-        if (base->base_type_inst_id == SemIR::ErrorInst::TypeInstId) {
-          return SemIR::ErrorInst::ConstantId;
-        }
+  // Note that these requirement instructions don't have a constant value. That
+  // means we have to look for errors inside them, we can't just look to see if
+  // their constant value is an error.
+  for (auto inst_id :
+       eval_context.inst_blocks().GetOrEmpty(typed_inst.requirements_id)) {
+    if (phase == Phase::UnknownDueToError) {
+      // Abandon ship to save work once we've encountered an error.
+      return SemIR::ErrorInst::ConstantId;
+    }
 
-        if (auto base_facet_type =
-                eval_context.insts().TryGetAs<SemIR::FacetType>(
-                    base->base_type_inst_id)) {
-          const auto& base_info =
-              eval_context.facet_types().Get(base_facet_type->facet_type_id);
-          info.extend_constraints.append(base_info.extend_constraints);
-          info.self_impls_constraints.append(base_info.self_impls_constraints);
-          info.type_impls_interfaces.append(base_info.type_impls_interfaces);
-          info.type_impls_named_constraints.append(
-              base_info.type_impls_named_constraints);
-          info.rewrite_constraints.append(base_info.rewrite_constraints);
-          info.other_requirements |= base_info.other_requirements;
-        }
-      } else if (auto rewrite =
-                     eval_context.insts().TryGetAs<SemIR::RequirementRewrite>(
-                         inst_id)) {
-        info.rewrite_constraints.push_back(
-            {.lhs_id = rewrite->lhs_id, .rhs_id = rewrite->rhs_id});
-      } else if (auto impls =
-                     eval_context.insts().TryGetAs<SemIR::RequirementImpls>(
-                         inst_id)) {
-        SemIR::ConstantId lhs_const_id =
-            eval_context.GetConstantValue(impls->lhs_id);
-        SemIR::ConstantId rhs_const_id =
-            eval_context.GetConstantValue(impls->rhs_id);
-        if (IsSameFacetValue(eval_context.context(), lhs_const_id,
-                             typed_inst.period_self_id)) {
-          auto rhs_inst_id =
-              eval_context.constant_values().GetInstId(rhs_const_id);
-          if (rhs_inst_id == SemIR::ErrorInst::InstId) {
-            // `.Self impls <error>`.
-            return SemIR::ErrorInst::ConstantId;
-          } else if (rhs_inst_id == SemIR::TypeType::TypeInstId) {
-            // `.Self impls type` -> nothing to do.
-          } else {
-            auto facet_type = eval_context.insts().GetAs<SemIR::FacetType>(
-                RequireConstantValue(eval_context, impls->rhs_id, &phase));
-            const auto& more_info =
-                eval_context.facet_types().Get(facet_type.facet_type_id);
-            // The way to prevent lookup into the interface requirements of a
-            // facet type is to put it to the right of a `.Self impls`, which we
-            // accomplish by putting them into `self_impls_constraints`.
-            llvm::append_range(info.self_impls_constraints,
-                               more_info.extend_constraints);
-            llvm::append_range(info.self_impls_constraints,
-                               more_info.self_impls_constraints);
-            llvm::append_range(info.self_impls_named_constraints,
-                               more_info.extend_named_constraints);
-            llvm::append_range(info.self_impls_named_constraints,
-                               more_info.self_impls_named_constraints);
-            // If `.Self impls Z` and Z implies `C impls Y`, then the facet type
-            // of `.Self` also knows `C impls Y`.
-            llvm::append_range(info.type_impls_interfaces,
-                               more_info.type_impls_interfaces);
-            llvm::append_range(info.type_impls_named_constraints,
-                               more_info.type_impls_named_constraints);
-            // Other requirements are copied in.
-            llvm::append_range(info.rewrite_constraints,
-                               more_info.rewrite_constraints);
-            info.other_requirements |= more_info.other_requirements;
-          }
-        } else {
-          // TODO: Handle `impls` constraints beyond `.Self impls`.
-          info.other_requirements = true;
-        }
-      } else {
-        // TODO: Handle other requirements.
+    auto inst = eval_context.insts().Get(inst_id);
+    CARBON_KIND_SWITCH(inst) {
+      case CARBON_KIND(SemIR::RequirementBaseFacetType base): {
+        AddRequirementBase(eval_context.context(), base, &info, &phase);
+        break;
+      }
+      case CARBON_KIND(SemIR::RequirementRewrite rewrite): {
+        AddRequirementRewrite(eval_context.context(), rewrite, &info, &phase);
+        break;
+      }
+      case CARBON_KIND(SemIR::RequirementImpls impls): {
+        AddRequirementImpls(eval_context.context(), impls,
+                            typed_inst.period_self_id, &info, &phase);
+        break;
+      }
+      case CARBON_KIND(SemIR::RequirementEquivalent _): {
+        // TODO: Handle equality requirements.
         info.other_requirements = true;
+        break;
       }
+      default:
+        CARBON_FATAL("unexpected inst {0} in WhereExpr requirements block",
+                     inst);
     }
   }
 
@@ -2793,13 +2865,14 @@ class FunctionExecContext : public EvalContext {
 static auto HandleExecResult(FunctionExecContext& eval_context,
                              SemIR::InstId inst_id, SemIR::ConstantId const_id)
     -> SemIR::ConstantId {
-  if (!const_id.has_value() || !const_id.is_constant()) {
-    DiagnoseNonConstantValue(eval_context, inst_id);
-    return SemIR::ErrorInst::ConstantId;
-  }
   if (const_id == SemIR::ErrorInst::ConstantId) {
     return const_id;
   }
+  if (!const_id.has_value() || !const_id.is_constant()) {
+    DiagnoseNonConstantValue(eval_context.context(),
+                             eval_context.GetDiagnosticLoc(inst_id));
+    return SemIR::ErrorInst::ConstantId;
+  }
   eval_context.locals().Update(inst_id, const_id);
   return SemIR::ConstantId::None;
 }

+ 71 - 2
toolchain/check/testdata/eval/unexpected_runtime.carbon

@@ -10,7 +10,8 @@
 // TIP: To dump output, run:
 // TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/eval/unexpected_runtime.carbon
 
-// --- fail_unexpected_runtime.carbon
+// --- fail_unexpected_runtime_base.carbon
+library "[[@TEST_NAME]]";
 
 var x: type;
 
@@ -19,9 +20,77 @@ interface Z {
 }
 
 class D {
-  // CHECK:STDERR: fail_unexpected_runtime.carbon:[[@LINE+4]]:31: error: expression is runtime; expected constant [EvalRequiresConstantValue]
+  // CHECK:STDERR: fail_unexpected_runtime_base.carbon:[[@LINE+4]]:18: error: cannot evaluate type expression [TypeExprEvaluationFailure]
+  // CHECK:STDERR:   extend impl as x where .T = () {}
+  // CHECK:STDERR:                  ^
+  // CHECK:STDERR:
+  extend impl as x where .T = () {}
+}
+
+// --- fail_unexpected_runtime_rewrite_lhs.carbon
+library "[[@TEST_NAME]]";
+
+var x: type;
+
+interface Z {
+  let T:! type;
+}
+
+class D {
+  // CHECK:STDERR: fail_unexpected_runtime_rewrite_lhs.carbon:[[@LINE+8]]:18: error: semantics TODO: `handle invalid parse trees in `check`` [SemanticsTodo]
+  // CHECK:STDERR:   extend impl as Z where x = () {}
+  // CHECK:STDERR:                  ^~~~~~~~~
+  // CHECK:STDERR:
+  // CHECK:STDERR: fail_unexpected_runtime_rewrite_lhs.carbon:[[@LINE+4]]:28: error: requirement can only use `=` after `.member` designator [RequirementEqualAfterNonDesignator]
+  // CHECK:STDERR:   extend impl as Z where x = () {}
+  // CHECK:STDERR:                            ^
+  // CHECK:STDERR:
+  extend impl as Z where x = () {}
+}
+
+// --- fail_unexpected_runtime_rewrite_rhs.carbon
+library "[[@TEST_NAME]]";
+
+var x: type;
+
+interface Z {
+  let T:! type;
+}
+
+class D {
+  // CHECK:STDERR: fail_unexpected_runtime_rewrite_rhs.carbon:[[@LINE+4]]:31: error: expression is runtime; expected constant [EvalRequiresConstantValue]
   // CHECK:STDERR:   extend impl as Z where .T = x {}
   // CHECK:STDERR:                               ^
   // CHECK:STDERR:
   extend impl as Z where .T = x {}
 }
+
+// --- fail_unexpected_runtime_impls_lhs.carbon
+library "[[@TEST_NAME]]";
+
+var x: type;
+
+interface Z {}
+
+class D {
+  // CHECK:STDERR: fail_unexpected_runtime_impls_lhs.carbon:[[@LINE+4]]:29: error: cannot evaluate type expression [TypeExprEvaluationFailure]
+  // CHECK:STDERR:   extend impl as type where x impls Z {}
+  // CHECK:STDERR:                             ^
+  // CHECK:STDERR:
+  extend impl as type where x impls Z {}
+}
+
+// --- fail_unexpected_runtime_impls_rhs.carbon
+library "[[@TEST_NAME]]";
+
+var x: type;
+
+interface Z {}
+
+class D {
+  // CHECK:STDERR: fail_unexpected_runtime_impls_rhs.carbon:[[@LINE+4]]:41: error: cannot evaluate type expression [TypeExprEvaluationFailure]
+  // CHECK:STDERR:   extend impl as type where .Self impls x {}
+  // CHECK:STDERR:                                         ^
+  // CHECK:STDERR:
+  extend impl as type where .Self impls x {}
+}