Explorar el Código

Model thunk call as a pattern match (#6988)

This makes the thunk-call logic more general and more supportable by
reusing the existing pattern-matching logic.

---------

Co-authored-by: Richard Smith <richard@metafoo.co.uk>
Geoff Romer hace 1 mes
padre
commit
47e9d62fd5
Se han modificado 3 ficheros con 106 adiciones y 76 borrados
  1. 64 3
      toolchain/check/pattern_match.cpp
  2. 28 6
      toolchain/check/pattern_match.h
  3. 14 67
      toolchain/check/thunk.cpp

+ 64 - 3
toolchain/check/pattern_match.cpp

@@ -46,7 +46,14 @@ struct CalleeState {
 // State for local pattern matching.
 struct LocalState {};
 
-using State = std::variant<CallerState*, CalleeState*, LocalState*>;
+// State for thunk pattern matching.
+struct ThunkState {
+  // The not-yet-processed `Call` arguments for the outer call.
+  llvm::ArrayRef<SemIR::InstId> outer_call_args;
+};
+
+using State =
+    std::variant<CallerState*, CalleeState*, LocalState*, ThunkState*>;
 
 // The worklist and state machine for a pattern-matching operation.
 //
@@ -356,9 +363,13 @@ auto MatchContext::DoPreWork(State state,
   }
 }
 
-auto MatchContext::DoPostWork(State /*state*/,
+auto MatchContext::DoPostWork(State state,
                               SemIR::AnyBindingPattern binding_pattern,
                               WorkItem entry) -> void {
+  if (std::holds_alternative<ThunkState*>(state)) {
+    // Pass through the result from the subpattern.
+    return;
+  }
   // We're logically consuming this map entry, so we invalidate it in order
   // to avoid accidentally consuming it twice.
   auto [bind_name_id, type_expr_region_id] =
@@ -517,6 +528,19 @@ auto MatchContext::DoPreWork(State state, SemIR::AnyParamPattern param_pattern,
       callee_state->call_param_patterns.push_back(entry.pattern_id);
       break;
     }
+    case CARBON_KIND(ThunkState* thunk_state): {
+      auto param_id = thunk_state->outer_call_args.consume_front();
+      if (auto var_param_pattern =
+              context_.insts().TryGetAs<SemIR::VarParamPattern>(
+                  entry.pattern_id)) {
+        AddWork({.pattern_id = var_param_pattern->subpattern_id,
+                 .work = PreWork{.scrutinee_id = param_id},
+                 .allow_unmarked_ref = entry.allow_unmarked_ref});
+      } else {
+        results_stack_.AppendToTop(param_id);
+      }
+      break;
+    }
     case CARBON_KIND(LocalState* _): {
       CARBON_FATAL("Found ValueParamPattern during local pattern match");
     }
@@ -600,6 +624,9 @@ auto MatchContext::DoVarPreWorkImpl(State state, SemIR::TypeId pattern_type_id,
       // insts.
       return scrutinee_id;
     }
+    case CARBON_KIND(ThunkState* _): {
+      return scrutinee_id;
+    }
     case CARBON_KIND(LocalState* _): {
       // In a `var`/`let` declaration, the `VarStorage` inst is created before
       // we start pattern matching.
@@ -676,7 +703,8 @@ auto MatchContext::DoPreWork(State state, SemIR::TuplePattern tuple_pattern,
         }
       };
   if (!scrutinee_id.has_value()) {
-    CARBON_CHECK(std::holds_alternative<CalleeState*>(state));
+    CARBON_CHECK(std::holds_alternative<CalleeState*>(state) ||
+                 std::holds_alternative<ThunkState*>(state));
     // If we don't have a scrutinee yet, we're still on the caller side of the
     // pattern, so the subpatterns don't have a scrutinee either.
     for (auto subpattern_id : llvm::reverse(subpattern_ids)) {
@@ -745,6 +773,9 @@ auto MatchContext::DoPostWork(State /*state*/,
 
 auto MatchContext::Dispatch(State state, WorkItem entry) -> void {
   if (entry.pattern_id == SemIR::ErrorInst::InstId) {
+    if (need_subpattern_results()) {
+      results_stack_.AppendToTop(SemIR::ErrorInst::InstId);
+    }
     return;
   }
   Diagnostics::AnnotationScope annotate_diagnostics(
@@ -881,6 +912,36 @@ auto CalleePatternMatch(Context& context,
           .param_ranges = {implicit_end, explicit_end, return_end}};
 }
 
+auto ThunkPatternMatch(Context& context, SemIR::InstId self_pattern_id,
+                       SemIR::InstBlockId param_patterns_id,
+                       llvm::ArrayRef<SemIR::InstId> outer_call_args)
+    -> ThunkPatternMatchResults {
+  ThunkState state = {.outer_call_args = outer_call_args};
+  MatchContext match(context);
+
+  llvm::SmallVector<SemIR::InstId> inner_args;
+  inner_args.reserve(outer_call_args.size() + 1);
+
+  if (self_pattern_id.has_value()) {
+    inner_args.push_back(match.MatchWithResult(
+        &state,
+        {.pattern_id = self_pattern_id,
+         .work = MatchContext::PreWork{.scrutinee_id = SemIR::InstId::None}}));
+  }
+
+  if (param_patterns_id.has_value()) {
+    for (SemIR::InstId inst_id : context.inst_blocks().Get(param_patterns_id)) {
+      inner_args.push_back(match.MatchWithResult(
+          &state, {.pattern_id = inst_id,
+                   .work = MatchContext::PreWork{.scrutinee_id =
+                                                     SemIR::InstId::None}}));
+    }
+  }
+
+  return {.syntactic_args = std::move(inner_args),
+          .ignored_call_args = state.outer_call_args};
+}
+
 auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         SemIR::InstId self_pattern_id,
                         SemIR::InstBlockId param_patterns_id,

+ 28 - 6
toolchain/check/pattern_match.h

@@ -19,6 +19,14 @@ namespace Carbon::Check {
 // are matched by the callee, and pattern insts that have a `ParamPattern`
 // as a descendant are matched by the caller.
 
+// Return type for CalleePatternMatch.
+struct CalleePatternMatchResults {
+  SemIR::InstBlockId call_param_patterns_id;
+  SemIR::InstBlockId call_params_id;
+
+  SemIR::Function::CallParamIndexRanges param_ranges;
+};
+
 // Emits the pattern-match IR for the declaration of a parameterized entity with
 // the given implicit and explicit parameter patterns, and the given return
 // patterns (any of which may be `None` if not applicable). This IR performs the
@@ -28,18 +36,32 @@ namespace Carbon::Check {
 // Returns the IDs of inst blocks consisting of references to the `Call`
 // parameter patterns and `Call` parameters of the function, as well as
 // the implicit, explicit, and return index ranges of those blocks.
-struct CalleePatternMatchResults {
-  SemIR::InstBlockId call_param_patterns_id;
-  SemIR::InstBlockId call_params_id;
-
-  SemIR::Function::CallParamIndexRanges param_ranges;
-};
 auto CalleePatternMatch(Context& context,
                         SemIR::InstBlockId implicit_param_patterns_id,
                         SemIR::InstBlockId param_patterns_id,
                         SemIR::InstBlockId return_patterns_id)
     -> CalleePatternMatchResults;
 
+// Return type for ThunkPatternMatch.
+struct ThunkPatternMatchResults {
+  // The syntactic argument list. If `self_pattern_id` is not `None`, the first
+  // element will be the corresponding argument.
+  llvm::SmallVector<SemIR::InstId> syntactic_args;
+
+  // The trailing elements of `outer_call_args` that were not used in
+  // `syntactic_args`. These presumably represent the output arguments for the
+  // return.
+  llvm::ArrayRef<SemIR::InstId> ignored_call_args;
+};
+
+// Given the `Call` arguments for the outer part of a thunked function call,
+// computes the corresponding syntactic argument list, suitable for passing to
+// the inner part of the thunked function call.
+auto ThunkPatternMatch(Context& context, SemIR::InstId self_pattern_id,
+                       SemIR::InstBlockId param_patterns_id,
+                       llvm::ArrayRef<SemIR::InstId> outer_call_args)
+    -> ThunkPatternMatchResults;
+
 // Emits the pattern-match IR for matching the given arguments with the given
 // parameter patterns, and returns an inst block of the arguments that should
 // be passed to the `Call` inst. `is_operator_syntax` indicates that this call

+ 14 - 67
toolchain/check/thunk.cpp

@@ -265,73 +265,20 @@ auto PerformThunkCall(Context& context, SemIR::LocId loc_id,
                       SemIR::InstId callee_id) -> SemIR::InstId {
   auto& function = context.functions().Get(function_id);
 
-  auto param_pattern_ids =
-      context.inst_blocks().Get(function.call_param_patterns_id);
-
-  // Maps each `Call` parameter pattern ID to its index.
-  // TODO: is it possible to arrange for the param patterns to be created in
-  // order, so that we could use `param_pattern_ids` for this directly?
-  struct InstWithIndex {
-    SemIR::InstId inst_id;
-    int index;
-
-    auto operator<(InstWithIndex other) const -> bool {
-      return inst_id.index < other.inst_id.index;
-    }
-  };
-  llvm::SmallVector<InstWithIndex> param_to_index;
-
-  param_to_index.reserve(param_pattern_ids.size());
-  for (auto [index, inst_id] : llvm::enumerate(param_pattern_ids)) {
-    param_to_index.push_back({inst_id, static_cast<int>(index)});
-  }
-  llvm::sort(param_to_index);
-
-  // Given that `call_arg_ids` is a list of the _`Call`_ arguments for a call to
-  // `function_id`, this returns the _syntactic_ argument that was passed for
-  // param_pattern_id in that call.
-  auto build_syntactic_arg = [&](SemIR::InstId param_pattern_id) {
-    if (auto at_binding_pattern =
-            context.insts().TryGetAs<SemIR::WrapperBindingPattern>(
-                param_pattern_id)) {
-      param_pattern_id = at_binding_pattern->subpattern_id;
-    }
-    // NOLINTNEXTLINE(readability-qualified-auto)
-    auto result =
-        llvm::lower_bound(param_to_index, InstWithIndex{param_pattern_id, -1});
-    if (result < param_to_index.end() && result->inst_id == param_pattern_id) {
-      return call_arg_ids[result->index];
-    } else {
-      if (param_pattern_id != SemIR::ErrorInst::InstId) {
-        context.TODO(param_pattern_id,
-                     "don't know how to reconstruct the syntactic argument for "
-                     "this pattern in thunk");
-      }
-      return SemIR::ErrorInst::InstId;
-    }
-  };
-
-  llvm::SmallVector<SemIR::InstId> args;
-
-  // If we have a self parameter, form `self.<callee_id>`.
-  if (function.self_param_id.has_value()) {
-    auto self_arg_id = build_syntactic_arg(function.self_param_id);
-    if (IsCppConstructorOrNonMethodOperator(context, callee_id)) {
-      // When calling a C++ constructor to implement `Copy`, or calling a C++
-      // non-method operator to implement a Carbon operator, the interface has a
-      // `self` parameter but C++ models that parameter as an explicit argument
-      // instead, so add the `self` to the argument list instead in that case.
-      args.push_back(self_arg_id);
-    } else {
-      callee_id =
-          PerformCompoundMemberAccess(context, loc_id, self_arg_id, callee_id);
-    }
-  }
-
-  // Form an argument list.
-  for (auto pattern_id :
-       context.inst_blocks().Get(function.param_patterns_id)) {
-    args.push_back(build_syntactic_arg(pattern_id));
+  auto [args_vec, ignored_call_args] =
+      ThunkPatternMatch(context, function.self_param_id,
+                        function.param_patterns_id, call_arg_ids);
+  llvm::ArrayRef<SemIR::InstId> args = args_vec;
+
+  // If we have a self parameter, form `self.<callee_id>` if needed.
+  // When calling a C++ constructor to implement `Copy`, or calling a C++
+  // non-method operator to implement a Carbon operator, the interface has a
+  // `self` parameter but C++ models that parameter as an explicit argument
+  // instead, so leave `self` at the front of the argument list in that case.
+  if (function.self_param_id.has_value() &&
+      !IsCppConstructorOrNonMethodOperator(context, callee_id)) {
+    callee_id = PerformCompoundMemberAccess(context, loc_id,
+                                            args.consume_front(), callee_id);
   }
 
   return PerformCall(context, loc_id, callee_id, args);