Ver código fonte

Allocate `CallParamIndex`es eagerly (#6540)

This approach is more robust because there's no intermediate state where
the `ParamPattern` insts have been created, but don't yet have their
final values.
Geoff Romer 3 meses atrás
pai
commit
b72bfb918b

+ 11 - 3
toolchain/check/cpp/import.cpp

@@ -22,6 +22,7 @@
 #include "common/ostream.h"
 #include "common/raw_string_ostream.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
 #include "toolchain/base/int.h"
@@ -1257,9 +1258,10 @@ static auto GetReturnInfo(Context& context, SemIR::LocId loc_id,
       context,
       MakeImportedLocIdAndInst(
           context, return_type_import_ir_inst_id,
-          SemIR::OutParamPattern({.type_id = pattern_type_id,
-                                  .subpattern_id = return_slot_pattern_id,
-                                  .index = SemIR::CallParamIndex::None})));
+          SemIR::OutParamPattern(
+              {.type_id = pattern_type_id,
+               .subpattern_id = return_slot_pattern_id,
+               .index = context.full_pattern_stack().NextCallParamIndex()})));
   auto return_patterns_id = context.inst_blocks().Add({param_pattern_id});
   return {.return_type_inst_id = type_inst_id,
           .return_patterns_id = return_patterns_id};
@@ -1288,11 +1290,16 @@ static auto CreateFunctionSignatureInsts(Context& context, SemIR::LocId loc_id,
                                          clang::FunctionDecl* clang_decl,
                                          int num_params)
     -> std::optional<FunctionSignatureInsts> {
+  context.full_pattern_stack().PushFullPattern(
+      FullPatternStack::Kind::ImplicitParamList);
+  std::optional pop = llvm::make_scope_exit(
+      [&context] { context.full_pattern_stack().PopFullPattern(); });
   auto implicit_param_patterns_id =
       MakeImplicitParamPatternsBlockId(context, loc_id, *clang_decl);
   if (!implicit_param_patterns_id.has_value()) {
     return std::nullopt;
   }
+  context.full_pattern_stack().EndImplicitParamList();
   auto param_patterns_id =
       MakeParamPatternsBlockId(context, loc_id, *clang_decl, num_params);
   if (!param_patterns_id.has_value()) {
@@ -1303,6 +1310,7 @@ static auto CreateFunctionSignatureInsts(Context& context, SemIR::LocId loc_id,
   if (return_type_inst_id == SemIR::ErrorInst::TypeInstId) {
     return std::nullopt;
   }
+  pop.reset();
 
   auto call_params_id =
       CalleePatternMatch(context, implicit_param_patterns_id, param_patterns_id,

+ 3 - 0
toolchain/check/custom_witness.cpp

@@ -47,6 +47,8 @@ static auto MakeNoOpFunction(Context& context, SemIR::LocId loc_id,
   context.scope_stack().PushForDeclName();
   context.inst_block_stack().Push();
   context.pattern_block_stack().Push();
+  context.full_pattern_stack().PushFullPattern(
+      FullPatternStack::Kind::ExplicitParamList);
 
   BeginSubpattern(context);
   auto type_id = GetFacetAsType(context, loc_id, self_const_id);
@@ -61,6 +63,7 @@ static auto MakeNoOpFunction(Context& context, SemIR::LocId loc_id,
                          /*param_patterns_id=*/SemIR::InstBlockId::Empty,
                          /*return_patterns_id=*/SemIR::InstBlockId::None);
 
+  context.full_pattern_stack().PopFullPattern();
   auto pattern_block_id = context.pattern_block_stack().Pop();
   auto decl_block_id = context.inst_block_stack().Pop();
   context.scope_stack().Pop();

+ 11 - 0
toolchain/check/full_pattern_stack.h

@@ -46,6 +46,7 @@ class FullPatternStack {
   auto PushFullPattern(Kind kind) -> void {
     kind_stack_.push_back(kind);
     bind_name_stack_.PushArray();
+    param_index_stack_.push_back(SemIR::CallParamIndex(0));
   }
 
   // Marks the end of an implicit parameter list, and the presumptive start
@@ -87,6 +88,7 @@ class FullPatternStack {
   auto PopFullPattern() -> void {
     kind_stack_.pop_back();
     bind_name_stack_.PopArray();
+    param_index_stack_.pop_back();
   }
 
   // Records that `name_id` was introduced by the current full-pattern.
@@ -102,6 +104,13 @@ class FullPatternStack {
                  kind_stack_.size());
   }
 
+  // Allocates the next unallocated CallParamIndex, starting from 0.
+  auto NextCallParamIndex() -> SemIR::CallParamIndex {
+    auto result = param_index_stack_.back();
+    ++param_index_stack_.back().index;
+    return result;
+  }
+
  private:
   LexicalLookup* lookup_;
 
@@ -112,6 +121,8 @@ class FullPatternStack {
     SemIR::InstId inst_id;
   };
   ArrayStack<LookupEntry> bind_name_stack_;
+
+  llvm::SmallVector<SemIR::CallParamIndex> param_index_stack_;
 };
 
 }  // namespace Carbon::Check

+ 2 - 2
toolchain/check/handle_binding_pattern.cpp

@@ -237,13 +237,13 @@ static auto HandleAnyBindingPattern(Context& context, Parse::NodeId node_id,
               context, node_id,
               {.type_id = type_id,
                .subpattern_id = result_inst_id,
-               .index = SemIR::CallParamIndex::None});
+               .index = context.full_pattern_stack().NextCallParamIndex()});
         } else {
           result_inst_id = AddPatternInst<SemIR::ValueParamPattern>(
               context, node_id,
               {.type_id = type_id,
                .subpattern_id = result_inst_id,
-               .index = SemIR::CallParamIndex::None});
+               .index = context.full_pattern_stack().NextCallParamIndex()});
         }
       }
       context.node_stack().Push(node_id, result_inst_id);

+ 1 - 1
toolchain/check/handle_function.cpp

@@ -79,7 +79,7 @@ auto HandleParseNode(Context& context, Parse::ReturnTypeId node_id) -> bool {
       context, node_id,
       {.type_id = pattern_type_id,
        .subpattern_id = return_slot_pattern_id,
-       .index = SemIR::CallParamIndex::None});
+       .index = context.full_pattern_stack().NextCallParamIndex()});
   context.node_stack().Push(node_id, param_pattern_id);
   return true;
 }

+ 1 - 1
toolchain/check/handle_let_and_var.cpp

@@ -118,7 +118,7 @@ auto HandleParseNode(Context& context, Parse::VariablePatternId node_id)
           context, node_id,
           {.type_id = type_id,
            .subpattern_id = subpattern_id,
-           .index = SemIR::CallParamIndex::None});
+           .index = context.full_pattern_stack().NextCallParamIndex()});
       break;
     case FullPatternStack::Kind::NameBindingDecl:
       break;

+ 6 - 5
toolchain/check/pattern.cpp

@@ -156,11 +156,12 @@ auto AddParamPattern(Context& context, SemIR::LocId loc_id,
   pattern_id = AddPatternInst(
       context,
       SemIR::LocIdAndInst::UncheckedLoc(
-          loc_id, SemIR::AnyParamPattern{
-                      .kind = param_pattern_kind,
-                      .type_id = context.insts().Get(pattern_id).type_id(),
-                      .subpattern_id = pattern_id,
-                      .index = SemIR::CallParamIndex::None}));
+          loc_id,
+          SemIR::AnyParamPattern{
+              .kind = param_pattern_kind,
+              .type_id = context.insts().Get(pattern_id).type_id(),
+              .subpattern_id = pattern_id,
+              .index = context.full_pattern_stack().NextCallParamIndex()}));
 
   return pattern_id;
 }

+ 1 - 21
toolchain/check/pattern_match.cpp

@@ -61,7 +61,7 @@ class MatchContext {
   // specific.
   explicit MatchContext(MatchKind kind, SemIR::SpecificId callee_specific_id =
                                             SemIR::SpecificId::None)
-      : next_index_(0), kind_(kind), callee_specific_id_(callee_specific_id) {}
+      : kind_(kind), callee_specific_id_(callee_specific_id) {}
 
   // Adds a work item to the stack.
   auto AddWork(WorkItem work_item) -> void { stack_.push_back(work_item); }
@@ -73,13 +73,6 @@ class MatchContext {
   auto DoWork(Context& context) -> SemIR::InstBlockId;
 
  private:
-  // Allocates the next unallocated RuntimeParamIndex, starting from 0.
-  auto NextRuntimeIndex() -> SemIR::CallParamIndex {
-    auto result = next_index_;
-    ++next_index_.index;
-    return result;
-  }
-
   // Emits the pattern-match insts necessary to match the pattern inst
   // `entry.pattern_id` against the scrutinee value `entry.scrutinee_id`, and
   // adds to `stack_` any work necessary to traverse into its subpatterns. This
@@ -120,9 +113,6 @@ class MatchContext {
   // The stack of work to be processed.
   llvm::SmallVector<WorkItem> stack_;
 
-  // The next index to be allocated by `NextRuntimeIndex`.
-  SemIR::CallParamIndex next_index_;
-
   // The pending results that will be returned by the current `DoWork` call.
   // It represents the contents of the `Call` arguments block when kind_
   // is Caller, or the `Call` parameters block when kind_ is Callee
@@ -277,10 +267,6 @@ auto MatchContext::DoEmitPatternMatch(Context& context,
       break;
     }
     case MatchKind::Callee: {
-      CARBON_CHECK(!param_pattern.index.has_value(),
-                   "ValueParamPattern index set before callee pattern match");
-      param_pattern.index = NextRuntimeIndex();
-      ReplaceInstBeforeConstantUse(context, entry.pattern_id, param_pattern);
       auto param_id = AddInst<SemIR::ValueParam>(
           context, SemIR::LocId(entry.pattern_id),
           {.type_id =
@@ -329,9 +315,6 @@ auto MatchContext::DoEmitPatternMatch(Context& context,
       break;
     }
     case MatchKind::Callee: {
-      CARBON_CHECK(!param_pattern.index.has_value());
-      param_pattern.index = NextRuntimeIndex();
-      ReplaceInstBeforeConstantUse(context, entry.pattern_id, param_pattern);
       auto param_id = AddInst<SemIR::RefParam>(
           context, SemIR::LocId(entry.pattern_id),
           {.type_id =
@@ -374,9 +357,6 @@ auto MatchContext::DoEmitPatternMatch(Context& context,
     case MatchKind::Callee: {
       // TODO: Consider ways to address near-duplication with the
       // other ParamPattern cases.
-      CARBON_CHECK(!param_pattern.index.has_value());
-      param_pattern.index = NextRuntimeIndex();
-      ReplaceInstBeforeConstantUse(context, entry.pattern_id, param_pattern);
       auto param_id = AddInst<SemIR::OutParam>(
           context, SemIR::LocId(entry.pattern_id),
           {.type_id =

+ 4 - 4
toolchain/check/testdata/function/declaration/fail_pattern_in_signature.carbon

@@ -36,14 +36,14 @@ fn F((a: {}, b: {}), c: {});
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %F.decl: %F.type = fn_decl @F [concrete = constants.%F] {
 // CHECK:STDOUT:     %a.patt: %pattern_type.a96 = value_binding_pattern a [concrete]
-// CHECK:STDOUT:     %a.param_patt: %pattern_type.a96 = value_param_pattern %a.patt, call_param<none> [concrete]
+// CHECK:STDOUT:     %a.param_patt: %pattern_type.a96 = value_param_pattern %a.patt, call_param0 [concrete]
 // CHECK:STDOUT:     %b.patt: %pattern_type.a96 = value_binding_pattern b [concrete]
-// CHECK:STDOUT:     %b.param_patt: %pattern_type.a96 = value_param_pattern %b.patt, call_param<none> [concrete]
+// CHECK:STDOUT:     %b.param_patt: %pattern_type.a96 = value_param_pattern %b.patt, call_param1 [concrete]
 // CHECK:STDOUT:     %.loc19_19: %pattern_type.de4 = tuple_pattern (%a.param_patt, %b.param_patt) [concrete]
 // CHECK:STDOUT:     %c.patt: %pattern_type.a96 = value_binding_pattern c [concrete]
-// CHECK:STDOUT:     %c.param_patt: %pattern_type.a96 = value_param_pattern %c.patt, call_param0 [concrete]
+// CHECK:STDOUT:     %c.param_patt: %pattern_type.a96 = value_param_pattern %c.patt, call_param2 [concrete]
 // CHECK:STDOUT:   } {
-// CHECK:STDOUT:     %c.param: %empty_struct_type = value_param call_param0
+// CHECK:STDOUT:     %c.param: %empty_struct_type = value_param call_param2
 // CHECK:STDOUT:     %.loc19_26.1: type = splice_block %.loc19_26.3 [concrete = constants.%empty_struct_type] {
 // CHECK:STDOUT:       %.loc19_26.2: %empty_struct_type = struct_literal () [concrete = constants.%empty_struct]
 // CHECK:STDOUT:       %.loc19_26.3: type = converted %.loc19_26.2, constants.%empty_struct_type [concrete = constants.%empty_struct_type]

+ 1 - 1
toolchain/check/thunk.cpp

@@ -127,7 +127,7 @@ static auto ClonePattern(Context& context, SemIR::SpecificId specific_id,
         {.kind = param->kind,
          .type_id = get_type(param_id),
          .subpattern_id = new_pattern_id,
-         .index = SemIR::CallParamIndex::None});
+         .index = param->index});
   }
 
   return new_pattern_id;