Просмотр исходного кода

Clean up pattern matching (#6987)

The key changes here are:
- The different kinds of pattern match are represented as alternatives
of a `variant`, instead of enumerators of an `enum`, so that they can
hold their own state instead of having a bunch of conditionally-usable
members of `MatchContext`.
- The public API of `MatchContext` is a `Match` operation that's applied
to a single pattern and scrutinee; the worklist is no longer directly
accessible.
- `Match` has a counterpart `MatchWithResult` that returns the result of
matching the pattern.
- `Context` is now a member of `MatchContext` instead of a parameter to
most of its methods.
Geoff Romer 1 месяц назад
Родитель
Сommit
f4260feee4
1 измененных файлов с 214 добавлено и 261 удалено
  1. 214 261
      toolchain/check/pattern_match.cpp

+ 214 - 261
toolchain/check/pattern_match.cpp

@@ -6,6 +6,7 @@
 
 #include <functional>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "llvm/ADT/STLExtras.h"
@@ -24,24 +25,30 @@ namespace Carbon::Check {
 
 namespace {
 
-// Selects between the different kinds of pattern matching.
-enum class MatchKind : uint8_t {
-  // Caller pattern matching occurs on the caller side of a function call, and
-  // is responsible for matching the argument expression against the portion
-  // of the pattern above the ParamPattern insts.
-  Caller,
-
-  // Callee pattern matching occurs in the function decl block, and is
-  // responsible for matching the function's calling-convention parameters
-  // against the portion of the pattern below the ParamPattern insts.
-  Callee,
-
-  // Local pattern matching is pattern matching outside of a function call,
-  // such as in a let/var declaration.
-  Local,
+// State for caller-side pattern matching.
+struct CallerState {
+  // The in-progress contents of the `Call` arguments block.
+  llvm::SmallVector<SemIR::InstId> call_args;
+
+  // The SpecificId of the function being called (if any).
+  SemIR::SpecificId callee_specific_id;
+};
+
+// State for callee-side pattern matching.
+struct CalleeState {
+  // The in-progress contents of the `Call` parameters block.
+  llvm::SmallVector<SemIR::InstId> call_params;
+
+  // The in-progress contents of the `Call` parameter patterns block.
+  llvm::SmallVector<SemIR::InstId> call_param_patterns;
 };
 
-// The collected state of a pattern-matching operation.
+// State for local pattern matching.
+struct LocalState {};
+
+using State = std::variant<CallerState*, CalleeState*, LocalState*>;
+
+// The worklist and state machine for a pattern-matching operation.
 //
 // Conceptually, pattern matching is a recursive traversal of the pattern inst
 // tree: we match a pattern inst to a scrutinee inst by converting the scrutinee
@@ -98,13 +105,16 @@ class MatchContext {
     }
   };
 
-  // Constructs a MatchContext. If `callee_specific_id` is not `None`, this
-  // pattern match operation is part of implementing the signature of the given
-  // specific.
-  explicit MatchContext(MatchKind kind, SemIR::SpecificId callee_specific_id =
-                                            SemIR::SpecificId::None)
-      : kind_(kind), callee_specific_id_(callee_specific_id) {}
+  // Constructs a MatchContext.
+  explicit MatchContext(Context& context) : context_(context) {}
+
+  // Performs pattern matching for the given work item.
+  auto Match(State state, WorkItem entry) -> void;
+
+  // Performs pattern matching for the given work item, and returns the result.
+  auto MatchWithResult(State state, WorkItem entry) -> SemIR::InstId;
 
+ private:
   // Whether the result of the work item at the top of the stack is needed.
   auto need_subpattern_results() const -> bool {
     return !results_stack_.empty();
@@ -119,62 +129,39 @@ class MatchContext {
     AddWork(entry);
   }
 
-  // Processes all work items on the stack.
-  auto DoWork(Context& context) -> void;
-
-  // Returns an inst block of references to all the emitted `Call` arguments.
-  // Can only be called once, at the end of Caller pattern matching.
-  auto GetCallArgs(Context& context) && -> SemIR::InstBlockId;
-
-  // Returns an inst block of references to all the emitted `Call` params,
-  // and an inst block of references to the `Call` param patterns they were
-  // emitted to match. Can only be called once, at the end of Callee pattern
-  // matching.
-  struct ParamBlocks {
-    SemIR::InstBlockId call_param_patterns_id;
-    SemIR::InstBlockId call_params_id;
-  };
-  auto GetCallParams(Context& context) && -> ParamBlocks;
-
-  // Returns the number of call parameters that have been emitted so far.
-  auto param_count() -> int { return call_params_.size(); }
-  ~MatchContext();
-
- private:
   // Dispatches `entry` to the appropriate DoWork method based on the kinds of
   // `entry.pattern_id` and `entry.work`.
-  auto Dispatch(Context& context, WorkItem entry) -> void;
+  auto Dispatch(State state, WorkItem entry) -> void;
 
   // Do the pre-work for `entry`. `entry.work` must be a `PreWork` containing
   // `scrutinee_id`, and the pattern argument must be the value of
   // `entry.pattern_id` in `context`.
-  auto DoPreWork(Context& context, SemIR::AnyBindingPattern binding_pattern,
+  auto DoPreWork(State state, SemIR::AnyBindingPattern binding_pattern,
                  SemIR::InstId scrutinee_id, WorkItem entry) -> void;
-  auto DoPreWork(Context& context, SemIR::AnyParamPattern param_pattern,
+  auto DoPreWork(State state, SemIR::AnyParamPattern param_pattern,
                  SemIR::InstId scrutinee_id, WorkItem entry) -> void;
-  auto DoPreWork(Context& context, SemIR::ExprPattern expr_pattern,
+  auto DoPreWork(State state, SemIR::ExprPattern expr_pattern,
                  SemIR::InstId scrutinee_id, WorkItem entry) -> void;
-  auto DoPreWork(Context& context, SemIR::ReturnSlotPattern return_slot_pattern,
+  auto DoPreWork(State state, SemIR::ReturnSlotPattern return_slot_pattern,
                  SemIR::InstId scrutinee_id, WorkItem entry) -> void;
-  auto DoPreWork(Context& context, SemIR::VarPattern var_pattern,
+  auto DoPreWork(State state, SemIR::VarPattern var_pattern,
                  SemIR::InstId scrutinee_id, WorkItem entry) -> void;
-  auto DoPreWork(Context& context, SemIR::TuplePattern tuple_pattern,
+  auto DoPreWork(State state, SemIR::TuplePattern tuple_pattern,
                  SemIR::InstId scrutinee_id, WorkItem entry) -> void;
 
   // Do the post-work for `entry`. `entry.work` must be a `PostWork`, and
-  // the pattern argument must be the value of `entry.pattern_id` in `context`.
-  auto DoPostWork(Context& context, SemIR::AnyBindingPattern binding_pattern,
-                  WorkItem entry) -> void;
-  auto DoPostWork(Context& context, SemIR::VarPattern var_pattern,
+  // the pattern argument must be the value of `entry.pattern_id` in `context_`.
+  auto DoPostWork(State state, SemIR::AnyBindingPattern binding_pattern,
                   WorkItem entry) -> void;
-  auto DoPostWork(Context& context, SemIR::AnyParamPattern param_pattern,
-                  WorkItem entry) -> void;
-  auto DoPostWork(Context& context, SemIR::ExprPattern expr_pattern,
+  auto DoPostWork(State state, SemIR::VarPattern var_pattern, WorkItem entry)
+      -> void;
+  auto DoPostWork(State state, SemIR::AnyParamPattern param_pattern,
                   WorkItem entry) -> void;
-  auto DoPostWork(Context& context,
-                  SemIR::ReturnSlotPattern return_slot_pattern, WorkItem entry)
+  auto DoPostWork(State state, SemIR::ExprPattern expr_pattern, WorkItem entry)
       -> void;
-  auto DoPostWork(Context& context, SemIR::TuplePattern tuple_pattern,
+  auto DoPostWork(State state, SemIR::ReturnSlotPattern return_slot_pattern,
+                  WorkItem entry) -> void;
+  auto DoPostWork(State state, SemIR::TuplePattern tuple_pattern,
                   WorkItem entry) -> void;
 
   // Asserts that there is a single inst in the top array in `results_stack_`,
@@ -191,7 +178,7 @@ class MatchContext {
   // matched with, rather than pushing it onto the worklist. This is factored
   // out so it can be reused when handling a `FormBindingPattern` or
   // `FormParamPattern` with an initializing form.
-  auto DoVarPreWorkImpl(Context& context, SemIR::TypeId pattern_type_id,
+  auto DoVarPreWorkImpl(State state, SemIR::TypeId pattern_type_id,
                         SemIR::InstId scrutinee_id, WorkItem entry) const
       -> SemIR::InstId;
 
@@ -202,58 +189,24 @@ class MatchContext {
   // a single result, which may have multiple sub-results.
   ArrayStack<SemIR::InstId> results_stack_;
 
-  // The in-progress contents of the `Call` arguments block. This is populated
-  // only when kind_ is Caller.
-  llvm::SmallVector<SemIR::InstId> call_args_;
-
-  // The in-progress contents of the `Call` parameters block. This is populated
-  // only when kind_ is Callee.
-  llvm::SmallVector<SemIR::InstId> call_params_;
-
-  // The in-progress contents of the `Call` parameter patterns block. This is
-  // populated only when kind_ is Callee.
-  llvm::SmallVector<SemIR::InstId> call_param_patterns_;
-
-  // The kind of pattern match being performed.
-  MatchKind kind_;
-
-  // The SpecificId of the function being called (if any).
-  SemIR::SpecificId callee_specific_id_;
+  Context& context_;
 };
 
 }  // namespace
 
-auto MatchContext::DoWork(Context& context) -> void {
+auto MatchContext::Match(State state, WorkItem entry) -> void {
+  CARBON_CHECK(stack_.empty());
+  stack_.push_back(entry);
   while (!stack_.empty()) {
-    Dispatch(context, stack_.pop_back_val());
+    Dispatch(state, stack_.pop_back_val());
   }
 }
 
-auto MatchContext::GetCallArgs(Context& context) && -> SemIR::InstBlockId {
-  CARBON_CHECK(kind_ == MatchKind::Caller);
-  auto block_id = context.inst_blocks().Add(call_args_);
-  call_args_.clear();
-  return block_id;
-}
-
-auto MatchContext::GetCallParams(Context& context) && -> ParamBlocks {
-  CARBON_CHECK(kind_ == MatchKind::Callee);
-  CARBON_CHECK(call_params_.size() == call_param_patterns_.size());
-  auto call_param_patterns_id = context.inst_blocks().Add(call_param_patterns_);
-  call_param_patterns_.clear();
-  auto call_params_id = context.inst_blocks().Add(call_params_);
-  call_params_.clear();
-  return {.call_param_patterns_id = call_param_patterns_id,
-          .call_params_id = call_params_id};
-}
-
-MatchContext::~MatchContext() {
-  CARBON_CHECK(call_args_.empty() && call_params_.empty() &&
-                   call_param_patterns_.empty(),
-               "Unhandled pattern matching outputs. call_args_.size(): {0}, "
-               "call_params_.size(): {1}, call_param_patterns_.size(): {2}",
-               call_args_.size(), call_params_.size(),
-               call_param_patterns_.size());
+auto MatchContext::MatchWithResult(State state, WorkItem entry)
+    -> SemIR::InstId {
+  results_stack_.PushArray();
+  Match(state, entry);
+  return PopResult();
 }
 
 // Inserts the given region into the current code block. If the region
@@ -379,12 +332,12 @@ static auto ConversionKindFor(Context& context, SemIR::Inst pattern,
   }
 }
 
-auto MatchContext::DoPreWork(Context& /*context*/,
+auto MatchContext::DoPreWork(State state,
                              SemIR::AnyBindingPattern binding_pattern,
-                             SemIR::InstId scrutinee_id,
-                             MatchContext::WorkItem entry) -> void {
+                             SemIR::InstId scrutinee_id, WorkItem entry)
+    -> void {
   bool scheduled_post_work = false;
-  if (kind_ != MatchKind::Caller) {
+  if (!std::holds_alternative<CallerState*>(state)) {
     results_stack_.PushArray();
     AddAsPostWork(entry);
     scheduled_post_work = true;
@@ -403,22 +356,22 @@ auto MatchContext::DoPreWork(Context& /*context*/,
   }
 }
 
-auto MatchContext::DoPostWork(Context& context,
+auto MatchContext::DoPostWork(State /*state*/,
                               SemIR::AnyBindingPattern binding_pattern,
-                              MatchContext::WorkItem entry) -> void {
+                              WorkItem entry) -> void {
   // We're logically consuming this map entry, so we invalidate it in order
   // to avoid accidentally consuming it twice.
   auto [bind_name_id, type_expr_region_id] =
-      std::exchange(context.bind_name_map().Lookup(entry.pattern_id).value(),
+      std::exchange(context_.bind_name_map().Lookup(entry.pattern_id).value(),
                     {.bind_name_id = SemIR::InstId::None,
                      .type_expr_region_id = SemIR::ExprRegionId::None});
   if (type_expr_region_id.has_value()) {
-    InsertHere(context, type_expr_region_id);
+    InsertHere(context_, type_expr_region_id);
   }
   auto value_id = PopResult();
 
   if (value_id.has_value()) {
-    auto conversion_kind = ConversionKindFor(context, binding_pattern, entry);
+    auto conversion_kind = ConversionKindFor(context_, binding_pattern, entry);
     if (!bind_name_id.has_value()) {
       // TODO: Is this appropriate, or should we perform a conversion based on
       // the category of the `_` binding first, and then separately discard the
@@ -426,19 +379,19 @@ auto MatchContext::DoPostWork(Context& context,
       conversion_kind = ConversionTarget::Discarded;
     }
     value_id =
-        Convert(context, SemIR::LocId(value_id), value_id,
+        Convert(context_, SemIR::LocId(value_id), value_id,
                 {.kind = conversion_kind,
-                 .type_id = context.insts().Get(bind_name_id).type_id()});
+                 .type_id = context_.insts().Get(bind_name_id).type_id()});
   } else {
     CARBON_CHECK(binding_pattern.kind == SemIR::SymbolicBindingPattern::Kind);
   }
 
   if (bind_name_id.has_value()) {
-    auto bind_name = context.insts().GetAs<SemIR::AnyBinding>(bind_name_id);
+    auto bind_name = context_.insts().GetAs<SemIR::AnyBinding>(bind_name_id);
     CARBON_CHECK(!bind_name.value_id.has_value());
     bind_name.value_id = value_id;
-    ReplaceInstBeforeConstantUse(context, bind_name_id, bind_name);
-    context.inst_block_stack().AddInstId(bind_name_id);
+    ReplaceInstBeforeConstantUse(context_, bind_name_id, bind_name);
+    context_.inst_block_stack().AddInstId(bind_name_id);
   }
   if (need_subpattern_results()) {
     results_stack_.AppendToTop(value_id);
@@ -480,8 +433,7 @@ static auto ParamKindFor(Context& context, SemIR::Inst param_pattern,
   }
 }
 
-auto MatchContext::DoPreWork(Context& context,
-                             SemIR::AnyParamPattern param_pattern,
+auto MatchContext::DoPreWork(State state, SemIR::AnyParamPattern param_pattern,
                              SemIR::InstId scrutinee_id, WorkItem entry)
     -> void {
   AddAsPostWork(entry);
@@ -491,8 +443,8 @@ auto MatchContext::DoPreWork(Context& context,
   switch (param_pattern.kind) {
     case SemIR::FormParamPattern::Kind: {
       auto form_param_pattern =
-          context.insts().GetAs<SemIR::FormParamPattern>(entry.pattern_id);
-      if (!context.constant_values().InstIs<SemIR::InitForm>(
+          context_.insts().GetAs<SemIR::FormParamPattern>(entry.pattern_id);
+      if (!context_.constant_values().InstIs<SemIR::InitForm>(
               form_param_pattern.form_id)) {
         break;
       }
@@ -500,7 +452,7 @@ auto MatchContext::DoPreWork(Context& context,
     }
     case SemIR::VarParamPattern::Kind: {
       scrutinee_id =
-          DoVarPreWorkImpl(context, param_pattern.type_id, scrutinee_id, entry);
+          DoVarPreWorkImpl(state, param_pattern.type_id, scrutinee_id, entry);
       entry.allow_unmarked_ref = true;
       break;
     }
@@ -508,51 +460,52 @@ auto MatchContext::DoPreWork(Context& context,
       break;
   }
 
-  switch (kind_) {
-    case MatchKind::Caller: {
+  CARBON_KIND_SWITCH(state) {
+    case CARBON_KIND(CallerState* caller_state): {
       CARBON_CHECK(scrutinee_id.has_value());
       if (scrutinee_id == SemIR::ErrorInst::InstId) {
-        call_args_.push_back(SemIR::ErrorInst::InstId);
+        caller_state->call_args.push_back(SemIR::ErrorInst::InstId);
       } else {
         auto scrutinee_type_id = ExtractScrutineeType(
-            context.sem_ir(),
-            SemIR::GetTypeOfInstInSpecific(
-                context.sem_ir(), callee_specific_id_, entry.pattern_id));
-        call_args_.push_back(
-            Convert(context, SemIR::LocId(scrutinee_id), scrutinee_id,
-                    {.kind = ConversionKindFor(context, param_pattern, entry),
+            context_.sem_ir(),
+            SemIR::GetTypeOfInstInSpecific(context_.sem_ir(),
+                                           caller_state->callee_specific_id,
+                                           entry.pattern_id));
+        caller_state->call_args.push_back(
+            Convert(context_, SemIR::LocId(scrutinee_id), scrutinee_id,
+                    {.kind = ConversionKindFor(context_, param_pattern, entry),
                      .type_id = scrutinee_type_id}));
       }
       // Do not traverse farther or schedule PostWork, because the caller side
       // of the pattern ends here.
       break;
     }
-    case MatchKind::Callee: {
-      SemIR::Inst param =
-          SemIR::AnyParam{.kind = ParamKindFor(context, param_pattern, entry),
-                          .type_id = ExtractScrutineeType(
-                              context.sem_ir(), param_pattern.type_id),
-                          .index = SemIR::CallParamIndex(call_params_.size()),
-                          .pretty_name_id = SemIR::GetPrettyNameFromPatternId(
-                              context.sem_ir(), entry.pattern_id)};
+    case CARBON_KIND(CalleeState* callee_state): {
+      SemIR::Inst param = SemIR::AnyParam{
+          .kind = ParamKindFor(context_, param_pattern, entry),
+          .type_id =
+              ExtractScrutineeType(context_.sem_ir(), param_pattern.type_id),
+          .index = SemIR::CallParamIndex(callee_state->call_params.size()),
+          .pretty_name_id = SemIR::GetPrettyNameFromPatternId(
+              context_.sem_ir(), entry.pattern_id)};
       auto loc_id = SemIR::LocId(entry.pattern_id);
       auto param_id = SemIR::InstId::None;
       // TODO: find a way to avoid this boilerplate.
       switch (param.kind()) {
         case SemIR::OutParam::Kind:
-          param_id = AddInst(context, loc_id, param.As<SemIR::OutParam>());
+          param_id = AddInst(context_, loc_id, param.As<SemIR::OutParam>());
           break;
         case SemIR::RefParam::Kind:
-          param_id = AddInst(context, loc_id, param.As<SemIR::RefParam>());
+          param_id = AddInst(context_, loc_id, param.As<SemIR::RefParam>());
           break;
         case SemIR::ValueParam::Kind:
-          param_id = AddInst(context, loc_id, param.As<SemIR::ValueParam>());
+          param_id = AddInst(context_, loc_id, param.As<SemIR::ValueParam>());
           break;
         default:
           CARBON_FATAL("Unexpected parameter kind");
       }
       if (auto var_param_pattern =
-              context.insts().TryGetAs<SemIR::VarParamPattern>(
+              context_.insts().TryGetAs<SemIR::VarParamPattern>(
                   entry.pattern_id)) {
         AddWork({.pattern_id = var_param_pattern->subpattern_id,
                  .work = PreWork{.scrutinee_id = param_id},
@@ -560,17 +513,17 @@ auto MatchContext::DoPreWork(Context& context,
       } else {
         results_stack_.AppendToTop(param_id);
       }
-      call_params_.push_back(param_id);
-      call_param_patterns_.push_back(entry.pattern_id);
+      callee_state->call_params.push_back(param_id);
+      callee_state->call_param_patterns.push_back(entry.pattern_id);
       break;
     }
-    case MatchKind::Local: {
+    case CARBON_KIND(LocalState* _): {
       CARBON_FATAL("Found ValueParamPattern during local pattern match");
     }
   }
 }
 
-auto MatchContext::DoPostWork(Context& /*context*/,
+auto MatchContext::DoPostWork(State /*state*/,
                               SemIR::AnyParamPattern /*param_pattern*/,
                               WorkItem /*entry*/) -> void {
   // No-op: the subpattern's result is this pattern's result. Note that if
@@ -578,22 +531,22 @@ auto MatchContext::DoPostWork(Context& /*context*/,
   // would have to be done here.
 }
 
-auto MatchContext::DoPreWork(Context& context,
+auto MatchContext::DoPreWork(State /*state*/,
                              SemIR::ExprPattern /*expr_pattern*/,
                              SemIR::InstId /*scrutinee_id*/, WorkItem entry)
     -> void {
-  context.TODO(entry.pattern_id, "expression pattern");
+  context_.TODO(entry.pattern_id, "expression pattern");
 }
 
-auto MatchContext::DoPostWork(Context& /*context*/,
+auto MatchContext::DoPostWork(State /*state*/,
                               SemIR::ExprPattern /*expr_pattern*/,
                               WorkItem /*entry*/) -> void {}
 
-auto MatchContext::DoPreWork(Context& /*context*/,
+auto MatchContext::DoPreWork(State state,
                              SemIR::ReturnSlotPattern return_slot_pattern,
                              SemIR::InstId scrutinee_id, WorkItem entry)
     -> void {
-  if (kind_ == MatchKind::Callee) {
+  if (std::holds_alternative<CalleeState*>(state)) {
     CARBON_CHECK(!scrutinee_id.has_value());
     results_stack_.PushArray();
     AddAsPostWork(entry);
@@ -602,19 +555,19 @@ auto MatchContext::DoPreWork(Context& /*context*/,
            .work = PreWork{.scrutinee_id = scrutinee_id}});
 }
 
-auto MatchContext::DoPostWork(Context& context,
+auto MatchContext::DoPostWork(State state,
                               SemIR::ReturnSlotPattern return_slot_pattern,
                               WorkItem entry) -> void {
-  CARBON_CHECK(kind_ == MatchKind::Callee);
+  CARBON_CHECK(std::holds_alternative<CalleeState*>(state));
   auto type_id =
-      ExtractScrutineeType(context.sem_ir(), return_slot_pattern.type_id);
+      ExtractScrutineeType(context_.sem_ir(), return_slot_pattern.type_id);
   auto return_slot_id = AddInst<SemIR::ReturnSlot>(
-      context, SemIR::LocId(entry.pattern_id),
+      context_, SemIR::LocId(entry.pattern_id),
       {.type_id = type_id,
-       .type_inst_id = context.types().GetTypeInstId(type_id),
+       .type_inst_id = context_.types().GetTypeInstId(type_id),
        .storage_id = PopResult()});
   bool already_in_lookup =
-      context.scope_stack()
+      context_.scope_stack()
           .LookupOrAddName(SemIR::NameId::ReturnSlot, return_slot_id)
           .has_value();
   CARBON_CHECK(!already_in_lookup);
@@ -623,11 +576,11 @@ auto MatchContext::DoPostWork(Context& context,
   }
 }
 
-auto MatchContext::DoPreWork(Context& context, SemIR::VarPattern var_pattern,
+auto MatchContext::DoPreWork(State state, SemIR::VarPattern var_pattern,
                              SemIR::InstId scrutinee_id, WorkItem entry)
     -> void {
   auto new_scrutinee_id =
-      DoVarPreWorkImpl(context, var_pattern.type_id, scrutinee_id, entry);
+      DoVarPreWorkImpl(state, var_pattern.type_id, scrutinee_id, entry);
   if (need_subpattern_results()) {
     AddAsPostWork(entry);
   }
@@ -636,30 +589,30 @@ auto MatchContext::DoPreWork(Context& context, SemIR::VarPattern var_pattern,
            .allow_unmarked_ref = true});
 }
 
-auto MatchContext::DoVarPreWorkImpl(Context& context,
-                                    SemIR::TypeId pattern_type_id,
+auto MatchContext::DoVarPreWorkImpl(State state, SemIR::TypeId pattern_type_id,
                                     SemIR::InstId scrutinee_id,
                                     WorkItem entry) const -> SemIR::InstId {
   auto storage_id = SemIR::InstId::None;
-  switch (kind_) {
-    case MatchKind::Callee: {
+  CARBON_KIND_SWITCH(state) {
+    case CARBON_KIND(CalleeState* _): {
       // We're emitting pattern-match IR for the callee, but we're still on
       // the caller side of the pattern, so we traverse without emitting any
       // insts.
       return scrutinee_id;
     }
-    case MatchKind::Local: {
+    case CARBON_KIND(LocalState* _): {
       // In a `var`/`let` declaration, the `VarStorage` inst is created before
       // we start pattern matching.
-      auto lookup_result = context.var_storage_map().Lookup(entry.pattern_id);
+      auto lookup_result = context_.var_storage_map().Lookup(entry.pattern_id);
       CARBON_CHECK(lookup_result);
       storage_id = lookup_result.value();
       break;
     }
-    case MatchKind::Caller: {
+    case CARBON_KIND(CallerState* _): {
       storage_id = AddInst<SemIR::TemporaryStorage>(
-          context, SemIR::LocId(entry.pattern_id),
-          {.type_id = ExtractScrutineeType(context.sem_ir(), pattern_type_id)});
+          context_, SemIR::LocId(entry.pattern_id),
+          {.type_id =
+               ExtractScrutineeType(context_.sem_ir(), pattern_type_id)});
       CARBON_CHECK(scrutinee_id.has_value());
       break;
     }
@@ -667,50 +620,49 @@ auto MatchContext::DoVarPreWorkImpl(Context& context,
   // TODO: Find a more efficient way to put these insts in the global_init
   // block (or drop the distinction between the global_init block and the
   // file scope?)
-  if (context.scope_stack().PeekIndex() == ScopeIndex::Package) {
-    context.global_init().Resume();
+  if (context_.scope_stack().PeekIndex() == ScopeIndex::Package) {
+    context_.global_init().Resume();
   }
   if (scrutinee_id.has_value()) {
-    auto init_id = Initialize(context, SemIR::LocId(entry.pattern_id),
+    auto init_id = Initialize(context_, SemIR::LocId(entry.pattern_id),
                               storage_id, scrutinee_id);
     // If we created a `TemporaryStorage` to hold the var, create a
     // corresponding `Temporary` to model that its initialization is complete.
     // TODO: If the subpattern is a binding, we may want to destroy the
     // parameter variable in the callee instead of the caller so that we can
     // support destructive move from it.
-    if (kind_ == MatchKind::Caller) {
+    if (std::holds_alternative<CallerState*>(state)) {
       storage_id = AddInstWithCleanup<SemIR::Temporary>(
-          context, SemIR::LocId(entry.pattern_id),
-          {.type_id = context.insts().Get(storage_id).type_id(),
+          context_, SemIR::LocId(entry.pattern_id),
+          {.type_id = context_.insts().Get(storage_id).type_id(),
            .storage_id = storage_id,
            .init_id = init_id});
     } else {
       // TODO: Consider using different instruction kinds for assignment
       // versus initialization.
-      AddInst<SemIR::Assign>(context, SemIR::LocId(entry.pattern_id),
+      AddInst<SemIR::Assign>(context_, SemIR::LocId(entry.pattern_id),
                              {.lhs_id = storage_id, .rhs_id = init_id});
     }
   }
-  if (context.scope_stack().PeekIndex() == ScopeIndex::Package) {
-    context.global_init().Suspend();
+  if (context_.scope_stack().PeekIndex() == ScopeIndex::Package) {
+    context_.global_init().Suspend();
   }
   return storage_id;
 }
 
-auto MatchContext::DoPostWork(Context& /*context*/,
+auto MatchContext::DoPostWork(State /*state*/,
                               SemIR::VarPattern /*var_pattern*/,
                               WorkItem /*entry*/) -> void {
   // No-op: the subpattern's result is this pattern's result.
 }
 
-auto MatchContext::DoPreWork(Context& context,
-                             SemIR::TuplePattern tuple_pattern,
+auto MatchContext::DoPreWork(State state, SemIR::TuplePattern tuple_pattern,
                              SemIR::InstId scrutinee_id, WorkItem entry)
     -> void {
   if (tuple_pattern.type_id == SemIR::ErrorInst::TypeId) {
     return;
   }
-  auto subpattern_ids = context.inst_blocks().Get(tuple_pattern.elements_id);
+  auto subpattern_ids = context_.inst_blocks().Get(tuple_pattern.elements_id);
   if (need_subpattern_results()) {
     results_stack_.PushArray();
     AddAsPostWork(entry);
@@ -724,7 +676,7 @@ auto MatchContext::DoPreWork(Context& context,
         }
       };
   if (!scrutinee_id.has_value()) {
-    CARBON_CHECK(kind_ == MatchKind::Callee);
+    CARBON_CHECK(std::holds_alternative<CalleeState*>(state));
     // If we don't have a scrutinee yet, we're still on the caller side of the
     // pattern, so the subpatterns don't have a scrutinee either.
     for (auto subpattern_id : llvm::reverse(subpattern_ids)) {
@@ -733,18 +685,18 @@ auto MatchContext::DoPreWork(Context& context,
     }
     return;
   }
-  auto scrutinee = context.insts().GetWithLocId(scrutinee_id);
+  auto scrutinee = context_.insts().GetWithLocId(scrutinee_id);
   if (auto scrutinee_literal = scrutinee.inst.TryAs<SemIR::TupleLiteral>()) {
     auto subscrutinee_ids =
-        context.inst_blocks().Get(scrutinee_literal->elements_id);
+        context_.inst_blocks().Get(scrutinee_literal->elements_id);
     if (subscrutinee_ids.size() != subpattern_ids.size()) {
       CARBON_DIAGNOSTIC(TuplePatternSizeDoesntMatchLiteral, Error,
                         "tuple pattern expects {0} element{0:s}, but tuple "
                         "literal has {1}",
                         Diagnostics::IntAsSelect, Diagnostics::IntAsSelect);
-      context.emitter().Emit(entry.pattern_id,
-                             TuplePatternSizeDoesntMatchLiteral,
-                             subpattern_ids.size(), subscrutinee_ids.size());
+      context_.emitter().Emit(entry.pattern_id,
+                              TuplePatternSizeDoesntMatchLiteral,
+                              subpattern_ids.size(), subscrutinee_ids.size());
       return;
     }
     add_all_subscrutinees(subscrutinee_ids);
@@ -752,25 +704,25 @@ auto MatchContext::DoPreWork(Context& context,
   }
 
   auto tuple_type_id =
-      ExtractScrutineeType(context.sem_ir(), tuple_pattern.type_id);
+      ExtractScrutineeType(context_.sem_ir(), tuple_pattern.type_id);
   auto converted_scrutinee_id = ConvertToValueOrRefOfType(
-      context, SemIR::LocId(entry.pattern_id), scrutinee_id, tuple_type_id);
-  if (auto scrutinee_value =
-          context.insts().TryGetAs<SemIR::TupleValue>(converted_scrutinee_id)) {
+      context_, SemIR::LocId(entry.pattern_id), scrutinee_id, tuple_type_id);
+  if (auto scrutinee_value = context_.insts().TryGetAs<SemIR::TupleValue>(
+          converted_scrutinee_id)) {
     add_all_subscrutinees(
-        context.inst_blocks().Get(scrutinee_value->elements_id));
+        context_.inst_blocks().Get(scrutinee_value->elements_id));
     return;
   }
 
-  auto tuple_type = context.types().GetAs<SemIR::TupleType>(tuple_type_id);
+  auto tuple_type = context_.types().GetAs<SemIR::TupleType>(tuple_type_id);
   auto element_type_inst_ids =
-      context.inst_blocks().Get(tuple_type.type_elements_id);
+      context_.inst_blocks().Get(tuple_type.type_elements_id);
   llvm::SmallVector<SemIR::InstId> subscrutinee_ids;
   subscrutinee_ids.reserve(element_type_inst_ids.size());
   for (auto [i, element_type_id] : llvm::enumerate(
-           context.types().GetBlockAsTypeIds(element_type_inst_ids))) {
+           context_.types().GetBlockAsTypeIds(element_type_inst_ids))) {
     subscrutinee_ids.push_back(
-        AddInst<SemIR::TupleAccess>(context, scrutinee.loc_id,
+        AddInst<SemIR::TupleAccess>(context_, scrutinee.loc_id,
                                     {.type_id = element_type_id,
                                      .tuple_id = converted_scrutinee_id,
                                      .index = SemIR::ElementIndex(i)}));
@@ -778,32 +730,32 @@ auto MatchContext::DoPreWork(Context& context,
   add_all_subscrutinees(subscrutinee_ids);
 }
 
-auto MatchContext::DoPostWork(Context& context,
+auto MatchContext::DoPostWork(State /*state*/,
                               SemIR::TuplePattern tuple_pattern, WorkItem entry)
     -> void {
-  auto elements_id = context.inst_blocks().Add(results_stack_.PeekArray());
+  auto elements_id = context_.inst_blocks().Add(results_stack_.PeekArray());
   results_stack_.PopArray();
   auto tuple_value_id =
-      AddInst<SemIR::TupleValue>(context, SemIR::LocId(entry.pattern_id),
+      AddInst<SemIR::TupleValue>(context_, SemIR::LocId(entry.pattern_id),
                                  {.type_id = SemIR::ExtractScrutineeType(
-                                      context.sem_ir(), tuple_pattern.type_id),
+                                      context_.sem_ir(), tuple_pattern.type_id),
                                   .elements_id = elements_id});
   results_stack_.AppendToTop(tuple_value_id);
 }
 
-auto MatchContext::Dispatch(Context& context, WorkItem entry) -> void {
+auto MatchContext::Dispatch(State state, WorkItem entry) -> void {
   if (entry.pattern_id == SemIR::ErrorInst::InstId) {
     return;
   }
   Diagnostics::AnnotationScope annotate_diagnostics(
-      &context.emitter(), [&](auto& builder) {
-        if (kind_ == MatchKind::Caller) {
+      &context_.emitter(), [&](auto& builder) {
+        if (std::holds_alternative<CallerState*>(state)) {
           CARBON_DIAGNOSTIC(InCallToFunctionParam, Note,
                             "initializing function parameter");
           builder.Note(entry.pattern_id, InCallToFunctionParam);
         }
       });
-  auto pattern = context.insts().Get(entry.pattern_id);
+  auto pattern = context_.insts().Get(entry.pattern_id);
   CARBON_KIND_SWITCH(entry.work) {
     case CARBON_KIND(PreWork work): {
       // TODO: Require that `work.scrutinee_id` is valid if and only if insts
@@ -811,27 +763,27 @@ auto MatchContext::Dispatch(Context& context, WorkItem entry) -> void {
       // `ParamPattern` case.
       CARBON_KIND_SWITCH(pattern) {
         case CARBON_KIND_ANY(SemIR::AnyBindingPattern, any_binding_pattern): {
-          DoPreWork(context, any_binding_pattern, work.scrutinee_id, entry);
+          DoPreWork(state, any_binding_pattern, work.scrutinee_id, entry);
           break;
         }
         case CARBON_KIND_ANY(SemIR::AnyParamPattern, any_param_pattern): {
-          DoPreWork(context, any_param_pattern, work.scrutinee_id, entry);
+          DoPreWork(state, any_param_pattern, work.scrutinee_id, entry);
           break;
         }
         case CARBON_KIND(SemIR::ExprPattern expr_pattern): {
-          DoPreWork(context, expr_pattern, work.scrutinee_id, entry);
+          DoPreWork(state, expr_pattern, work.scrutinee_id, entry);
           break;
         }
         case CARBON_KIND(SemIR::ReturnSlotPattern return_slot_pattern): {
-          DoPreWork(context, return_slot_pattern, work.scrutinee_id, entry);
+          DoPreWork(state, return_slot_pattern, work.scrutinee_id, entry);
           break;
         }
         case CARBON_KIND(SemIR::VarPattern var_pattern): {
-          DoPreWork(context, var_pattern, work.scrutinee_id, entry);
+          DoPreWork(state, var_pattern, work.scrutinee_id, entry);
           break;
         }
         case CARBON_KIND(SemIR::TuplePattern tuple_pattern): {
-          DoPreWork(context, tuple_pattern, work.scrutinee_id, entry);
+          DoPreWork(state, tuple_pattern, work.scrutinee_id, entry);
           break;
         }
         default: {
@@ -843,27 +795,27 @@ auto MatchContext::Dispatch(Context& context, WorkItem entry) -> void {
     case CARBON_KIND(PostWork _): {
       CARBON_KIND_SWITCH(pattern) {
         case CARBON_KIND_ANY(SemIR::AnyBindingPattern, any_binding_pattern): {
-          DoPostWork(context, any_binding_pattern, entry);
+          DoPostWork(state, any_binding_pattern, entry);
           break;
         }
         case CARBON_KIND_ANY(SemIR::AnyParamPattern, any_param_pattern): {
-          DoPostWork(context, any_param_pattern, entry);
+          DoPostWork(state, any_param_pattern, entry);
           break;
         }
         case CARBON_KIND(SemIR::ExprPattern expr_pattern): {
-          DoPostWork(context, expr_pattern, entry);
+          DoPostWork(state, expr_pattern, entry);
           break;
         }
         case CARBON_KIND(SemIR::ReturnSlotPattern return_slot_pattern): {
-          DoPostWork(context, return_slot_pattern, entry);
+          DoPostWork(state, return_slot_pattern, entry);
           break;
         }
         case CARBON_KIND(SemIR::VarPattern var_pattern): {
-          DoPostWork(context, var_pattern, entry);
+          DoPostWork(state, var_pattern, entry);
           break;
         }
         case CARBON_KIND(SemIR::TuplePattern tuple_pattern): {
-          DoPostWork(context, tuple_pattern, entry);
+          DoPostWork(state, tuple_pattern, entry);
           break;
         }
         default: {
@@ -887,45 +839,45 @@ auto CalleePatternMatch(Context& context,
             .param_ranges = SemIR::Function::CallParamIndexRanges::Empty};
   }
 
-  MatchContext match(MatchKind::Callee);
+  CalleeState state;
+  MatchContext match(context);
 
   // We add work to the stack in reverse so that the results will be produced
   // in the original order.
   if (implicit_param_patterns_id.has_value()) {
     for (SemIR::InstId inst_id :
-         llvm::reverse(context.inst_blocks().Get(implicit_param_patterns_id))) {
-      match.AddWork(
+         context.inst_blocks().Get(implicit_param_patterns_id)) {
+      match.Match(
+          &state,
           {.pattern_id = inst_id,
            .work = MatchContext::PreWork{.scrutinee_id = SemIR::InstId::None}});
     }
   }
-  match.DoWork(context);
-  auto implicit_end = SemIR::CallParamIndex(match.param_count());
+  auto implicit_end = SemIR::CallParamIndex(state.call_params.size());
 
   if (param_patterns_id.has_value()) {
-    for (SemIR::InstId inst_id :
-         llvm::reverse(context.inst_blocks().Get(param_patterns_id))) {
-      match.AddWork(
+    for (SemIR::InstId inst_id : context.inst_blocks().Get(param_patterns_id)) {
+      match.Match(
+          &state,
           {.pattern_id = inst_id,
            .work = MatchContext::PreWork{.scrutinee_id = SemIR::InstId::None}});
     }
   }
-  match.DoWork(context);
-  auto explicit_end = SemIR::CallParamIndex(match.param_count());
+  auto explicit_end = SemIR::CallParamIndex(state.call_params.size());
 
   for (auto return_pattern_id :
        context.inst_blocks().GetOrEmpty(return_patterns_id)) {
-    match.AddWork(
+    match.Match(
+        &state,
         {.pattern_id = return_pattern_id,
          .work = MatchContext::PreWork{.scrutinee_id = SemIR::InstId::None}});
   }
-  match.DoWork(context);
-  auto return_end = SemIR::CallParamIndex(match.param_count());
+  auto return_end = SemIR::CallParamIndex(state.call_params.size());
+  CARBON_CHECK(state.call_params.size() == state.call_param_patterns.size());
 
-  match.DoWork(context);
-  auto blocks = std::move(match).GetCallParams(context);
-  return {.call_param_patterns_id = blocks.call_param_patterns_id,
-          .call_params_id = blocks.call_params_id,
+  return {.call_param_patterns_id =
+              context.inst_blocks().Add(state.call_param_patterns),
+          .call_params_id = context.inst_blocks().Add(state.call_params),
           .param_ranges = {implicit_end, explicit_end, return_end}};
 }
 
@@ -937,16 +889,31 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         llvm::ArrayRef<SemIR::InstId> arg_refs,
                         llvm::ArrayRef<SemIR::InstId> return_arg_ids,
                         bool is_operator_syntax) -> SemIR::InstBlockId {
-  MatchContext match(MatchKind::Caller, specific_id);
+  CallerState state = {.callee_specific_id = specific_id};
+  MatchContext match(context);
+
+  if (self_pattern_id.has_value()) {
+    match.Match(&state,
+                {.pattern_id = self_pattern_id,
+                 .work = MatchContext::PreWork{.scrutinee_id = self_arg_id},
+                 .allow_unmarked_ref = true});
+  }
+
+  for (auto [arg_id, param_pattern_id] : llvm::zip_equal(
+           arg_refs, context.inst_blocks().GetOrEmpty(param_patterns_id))) {
+    match.Match(&state, {.pattern_id = param_pattern_id,
+                         .work = MatchContext::PreWork{.scrutinee_id = arg_id},
+                         .allow_unmarked_ref = is_operator_syntax});
+  }
 
   auto return_patterns = context.inst_blocks().GetOrEmpty(return_patterns_id);
   // Track the return storage, if present.
   for (auto [return_pattern_id, return_arg_id] :
        llvm::zip_equal(return_patterns, return_arg_ids)) {
     if (return_arg_id.has_value()) {
-      match.AddWork(
-          {.pattern_id = return_pattern_id,
-           .work = MatchContext::PreWork{.scrutinee_id = return_arg_id}});
+      match.Match(&state, {.pattern_id = return_pattern_id,
+                           .work = MatchContext::PreWork{.scrutinee_id =
+                                                             return_arg_id}});
     } else {
       CARBON_CHECK(return_arg_ids.size() == 1,
                    "TODO: do the match even if return_arg_id is None, so that "
@@ -954,30 +921,16 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
     }
   }
 
-  // Check type conversions per-element.
-  for (auto [arg_id, param_pattern_id] : llvm::reverse(llvm::zip_equal(
-           arg_refs, context.inst_blocks().GetOrEmpty(param_patterns_id)))) {
-    match.AddWork({.pattern_id = param_pattern_id,
-                   .work = MatchContext::PreWork{.scrutinee_id = arg_id},
-                   .allow_unmarked_ref = is_operator_syntax});
-  }
-
-  if (self_pattern_id.has_value()) {
-    match.AddWork({.pattern_id = self_pattern_id,
-                   .work = MatchContext::PreWork{.scrutinee_id = self_arg_id},
-                   .allow_unmarked_ref = true});
-  }
-
-  match.DoWork(context);
-  return std::move(match).GetCallArgs(context);
+  return context.inst_blocks().Add(state.call_args);
 }
 
 auto LocalPatternMatch(Context& context, SemIR::InstId pattern_id,
                        SemIR::InstId scrutinee_id) -> void {
-  MatchContext match(MatchKind::Local);
-  match.AddWork({.pattern_id = pattern_id,
-                 .work = MatchContext::PreWork{.scrutinee_id = scrutinee_id}});
-  match.DoWork(context);
+  LocalState state;
+  MatchContext match(context);
+  match.Match(&state,
+              {.pattern_id = pattern_id,
+               .work = MatchContext::PreWork{.scrutinee_id = scrutinee_id}});
 }
 
 }  // namespace Carbon::Check