Jelajahi Sumber

Don't require `ref` tags in thunks (#7115)

This enables thunking to work when the function has `ref` parameters,
without jumping through hoops to add `ref` tags in the desugared
function body.

This also renames `is_operator_syntax` to `is_desugared`, which is more
general and more accurate.
Geoff Romer 3 hari lalu
induk
melakukan
bd6aeae9d4

+ 9 - 9
toolchain/check/call.cpp

@@ -212,7 +212,7 @@ auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
                            SemIR::InstId callee_id,
                            const SemIR::CalleeFunction& callee_function,
                            llvm::ArrayRef<SemIR::InstId> arg_ids,
-                           bool is_operator_syntax) -> SemIR::InstId {
+                           bool is_desugared) -> SemIR::InstId {
   // If the callee is a generic function, determine the generic argument values
   // for the call.
   auto callee_specific_id = ResolveCalleeInCall(
@@ -266,9 +266,9 @@ auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
     }
   }
   // Convert the arguments to match the parameters.
-  auto converted_args_id = ConvertCallArgs(
-      context, loc_id, callee_function.self_id, arg_ids, return_arg_id, callee,
-      *callee_specific_id, is_operator_syntax);
+  auto converted_args_id =
+      ConvertCallArgs(context, loc_id, callee_function.self_id, arg_ids,
+                      return_arg_id, callee, *callee_specific_id, is_desugared);
   switch (callee.special_function_kind) {
     case SemIR::Function::SpecialFunctionKind::Thunk: {
       // If we're about to form a direct call to a thunk, inline it.
@@ -347,7 +347,7 @@ static auto PerformCallToNonFunction(Context& context, SemIR::LocId loc_id,
 }
 
 auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
-                 llvm::ArrayRef<SemIR::InstId> arg_ids, bool is_operator_syntax)
+                 llvm::ArrayRef<SemIR::InstId> arg_ids, bool is_desugared)
     -> SemIR::InstId {
   // Try treating the callee as a function first.
   auto callee = GetCallee(context.sem_ir(), callee_id);
@@ -357,16 +357,16 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
     }
     case CARBON_KIND(SemIR::CalleeFunction fn): {
       return PerformCallToFunction(context, loc_id, callee_id, fn, arg_ids,
-                                   is_operator_syntax);
+                                   is_desugared);
     }
     case CARBON_KIND(SemIR::CalleeNonFunction _): {
       return PerformCallToNonFunction(context, loc_id, callee_id, arg_ids);
     }
 
     case CARBON_KIND(SemIR::CalleeCppOverloadSet overload): {
-      return PerformCallToCppFunction(
-          context, loc_id, overload.cpp_overload_set_id, overload.self_id,
-          arg_ids, is_operator_syntax);
+      return PerformCallToCppFunction(context, loc_id,
+                                      overload.cpp_overload_set_id,
+                                      overload.self_id, arg_ids, is_desugared);
     }
   }
 }

+ 5 - 5
toolchain/check/call.h

@@ -11,22 +11,22 @@
 namespace Carbon::Check {
 
 // Checks and builds SemIR for a call to `callee_id` with arguments `args_id`,
-// where the callee is a function. `is_operator_syntax` indicates that this call
-// was generated from an operator rather than from function call syntax, so
+// where the callee is a function. `is_desugared` indicates that this call
+// was produced by desugaring, not written as a function call in user code, so
 // arguments to `ref` parameters aren't required to have `ref` tags.
 auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
                            SemIR::InstId callee_id,
                            const SemIR::CalleeFunction& callee_function,
                            llvm::ArrayRef<SemIR::InstId> arg_ids,
-                           bool is_operator_syntax) -> SemIR::InstId;
+                           bool is_desugared) -> SemIR::InstId;
 
 // Checks and builds SemIR for a call to `callee_id` with arguments `args_id`.
-// `is_operator_syntax` indicates that this call
+// `is_desugared` indicates that this call
 // was generated from an operator rather than from function call syntax, so
 // arguments to `ref` parameters aren't required to have `ref` tags.
 auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
                  llvm::ArrayRef<SemIR::InstId> arg_ids,
-                 bool is_operator_syntax = false) -> SemIR::InstId;
+                 bool is_desugared = false) -> SemIR::InstId;
 
 }  // namespace Carbon::Check
 

+ 3 - 4
toolchain/check/convert.cpp

@@ -2166,8 +2166,8 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
                      SemIR::InstId self_id,
                      llvm::ArrayRef<SemIR::InstId> arg_refs,
                      SemIR::InstId return_arg_id, const SemIR::Function& callee,
-                     SemIR::SpecificId callee_specific_id,
-                     bool is_operator_syntax) -> SemIR::InstBlockId {
+                     SemIR::SpecificId callee_specific_id, bool is_desugared)
+    -> SemIR::InstBlockId {
   auto param_patterns =
       context.inst_blocks().GetOrEmpty(callee.param_patterns_id);
   auto return_pattern_id = callee.return_pattern_id;
@@ -2188,8 +2188,7 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
 
   return CallerPatternMatch(context, callee_specific_id, callee.self_param_id,
                             callee.param_patterns_id, return_pattern_id,
-                            self_id, arg_refs, return_arg_id,
-                            is_operator_syntax);
+                            self_id, arg_refs, return_arg_id, is_desugared);
 }
 
 auto TypeExpr::ForUnsugared(Context& context, SemIR::TypeId type_id)

+ 5 - 5
toolchain/check/convert.h

@@ -206,15 +206,15 @@ auto ConvertForExplicitAs(Context& context, Parse::NodeId as_node,
 
 // Implicitly converts a set of arguments to match the parameter types in a
 // function call. Returns a block containing the converted implicit and explicit
-// argument values for runtime parameters. `is_operator_syntax` indicates that
-// this call was generated from an operator rather than from function call
-// syntax, so arguments to `ref` parameters aren't required to have `ref` tags.
+// argument values for runtime parameters. `is_desugared` indicates that this
+// call was produced by desugaring, not written as a function call in user code,
+// so arguments to `ref` parameters aren't required to have `ref` tags.
 auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
                      SemIR::InstId self_id,
                      llvm::ArrayRef<SemIR::InstId> arg_refs,
                      SemIR::InstId return_arg_id, const SemIR::Function& callee,
-                     SemIR::SpecificId callee_specific_id,
-                     bool is_operator_syntax) -> SemIR::InstBlockId;
+                     SemIR::SpecificId callee_specific_id, bool is_desugared)
+    -> SemIR::InstBlockId;
 
 // A type that has been converted for use as a type expression.
 struct TypeExpr {

+ 2 - 2
toolchain/check/cpp/call.cpp

@@ -55,7 +55,7 @@ auto PerformCallToCppFunction(Context& context, SemIR::LocId loc_id,
                               SemIR::CppOverloadSetId overload_set_id,
                               SemIR::InstId self_id,
                               llvm::ArrayRef<SemIR::InstId> arg_ids,
-                              bool is_operator_syntax) -> SemIR::InstId {
+                              bool is_desugared) -> SemIR::InstId {
   auto [template_arg_ids, function_arg_ids] =
       SplitCallArgumentList(context, arg_ids);
   auto callee_id = PerformCppOverloadResolution(
@@ -73,7 +73,7 @@ auto PerformCallToCppFunction(Context& context, SemIR::LocId loc_id,
         fn.self_id = self_id;
       }
       return PerformCallToFunction(context, loc_id, callee_id, fn,
-                                   function_arg_ids, is_operator_syntax);
+                                   function_arg_ids, is_desugared);
     }
     case CARBON_KIND(SemIR::CalleeCppOverloadSet _): {
       CARBON_FATAL("overloads can't be recursive");

+ 5 - 5
toolchain/check/cpp/call.h

@@ -21,10 +21,10 @@ auto ConvertArgsToTemplateArgs(Context& context,
                                bool diagnose = true) -> bool;
 
 // Checks and builds SemIR for a call to a C++ function in the given overload
-// set with self `self_id` and arguments `arg_ids`. `is_operator_syntax`
-// indicates that this call was generated from an operator rather than from
-// function call syntax, so arguments to `ref` parameters aren't required to
-// have `ref` tags.
+// set with self `self_id` and arguments `arg_ids`. `is_desugared`
+// indicates that this call was was produced by desugaring, not written as a
+// function call in user code, so arguments to `ref` parameters aren't required
+// to have `ref` tags.
 //
 // Chooses the best viable C++ function by performing Clang overloading
 // resolution over the overload set.
@@ -42,7 +42,7 @@ auto PerformCallToCppFunction(Context& context, SemIR::LocId loc_id,
                               SemIR::CppOverloadSetId overload_set_id,
                               SemIR::InstId self_id,
                               llvm::ArrayRef<SemIR::InstId> arg_ids,
-                              bool is_operator_syntax) -> SemIR::InstId;
+                              bool is_desugared) -> SemIR::InstId;
 
 // Checks and builds SemIR for a call to a C++ template name with arguments
 // `arg_ids`.

+ 4 - 4
toolchain/check/operator.cpp

@@ -86,7 +86,7 @@ auto BuildUnaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
     // argument. Otherwise fall through to call it with a self argument.
     if (op_fn_id.has_value() && !IsCppOperatorMethod(context, op_fn_id)) {
       return PerformCall(context, loc_id, op_fn_id, {operand_id},
-                         /*is_operator_syntax=*/true);
+                         /*is_desugared=*/true);
     }
   }
 
@@ -105,7 +105,7 @@ auto BuildUnaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
 
   // Form `bound_op()`.
   return PerformCall(context, loc_id, bound_op_id, {},
-                     /*is_operator_syntax=*/true);
+                     /*is_desugared=*/true);
 }
 
 auto BuildBinaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
@@ -139,7 +139,7 @@ auto BuildBinaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
     // call argument.
     if (op_fn_id.has_value() && !IsCppOperatorMethod(context, op_fn_id)) {
       return PerformCall(context, loc_id, op_fn_id, {lhs_id, rhs_id},
-                         /*is_operator_syntax=*/true);
+                         /*is_desugared=*/true);
     }
   }
 
@@ -158,7 +158,7 @@ auto BuildBinaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
 
   // Form `bound_op(rhs)`.
   return PerformCall(context, loc_id, bound_op_id, {rhs_id},
-                     /*is_operator_syntax=*/true);
+                     /*is_desugared=*/true);
 }
 
 }  // namespace Carbon::Check

+ 2 - 2
toolchain/check/pattern_match.cpp

@@ -995,7 +995,7 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         SemIR::InstId return_pattern_id,
                         SemIR::InstId self_arg_id,
                         llvm::ArrayRef<SemIR::InstId> arg_refs,
-                        SemIR::InstId return_arg_id, bool is_operator_syntax)
+                        SemIR::InstId return_arg_id, bool is_desugared)
     -> SemIR::InstBlockId {
   CallerState state = {.callee_specific_id = specific_id};
   MatchContext match(context);
@@ -1011,7 +1011,7 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
            arg_refs, context.inst_blocks().GetOrEmpty(param_patterns_id))) {
     match.Match(&state, {.pattern_id = param_pattern_id,
                          .work = MatchContext::PreWork{.scrutinee_id = arg_id},
-                         .allow_unmarked_ref = is_operator_syntax});
+                         .allow_unmarked_ref = is_desugared});
   }
 
   // Track the return storage, if present.

+ 3 - 3
toolchain/check/pattern_match.h

@@ -64,8 +64,8 @@ auto ThunkPatternMatch(Context& context, SemIR::InstId self_pattern_id,
 
 // Emits the pattern-match IR for matching the given arguments with the given
 // parameter patterns, and returns an inst block of the arguments that should
-// be passed to the `Call` inst. `is_operator_syntax` indicates that this call
-// was generated from an operator rather than from function call syntax, so
+// be passed to the `Call` inst. `is_desugared` indicates that this call
+// was produced by desugaring, not written as a function call in user code, so
 // arguments to `ref` parameters aren't required to have `ref` tags.
 auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         SemIR::InstId self_pattern_id,
@@ -73,7 +73,7 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         SemIR::InstId return_pattern_id,
                         SemIR::InstId self_arg_id,
                         llvm::ArrayRef<SemIR::InstId> arg_refs,
-                        SemIR::InstId return_arg_id, bool is_operator_syntax)
+                        SemIR::InstId return_arg_id, bool is_desugared)
     -> SemIR::InstBlockId;
 
 // Emits the pattern-match IR for a local pattern matching operation with the

+ 68 - 0
toolchain/check/testdata/impl/impl_thunk.carbon

@@ -108,6 +108,23 @@ impl B as X {
   //@dump-sem-ir-end
 }
 
+// --- inheritance_ref_conversion.carbon
+
+library "[[@TEST_NAME]]";
+
+base class A {}
+base class B { extend base: A; }
+
+interface X {
+  fn F[self: Self](ref other: Self);
+}
+
+impl B as X {
+  //@dump-sem-ir-begin
+  fn F[self: A](ref other: A);
+  //@dump-sem-ir-end
+}
+
 // --- fail_inheritance_value_conversion_copy_return.carbon
 
 library "[[@TEST_NAME]]";
@@ -758,6 +775,57 @@ impl () as I({}) {
 // CHECK:STDOUT:   return %ptr.as.Copy.impl.Op.call
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
+// CHECK:STDOUT: --- inheritance_ref_conversion.carbon
+// CHECK:STDOUT:
+// CHECK:STDOUT: constants {
+// CHECK:STDOUT:   %A: type = class_type @A [concrete]
+// CHECK:STDOUT:   %B: type = class_type @B [concrete]
+// CHECK:STDOUT:   %empty_tuple.type: type = tuple_type () [concrete]
+// CHECK:STDOUT:   %pattern_type.1ab: type = pattern_type %A [concrete]
+// CHECK:STDOUT:   %B.as.X.impl.F.type.421a66.1: type = fn_type @B.as.X.impl.F.loc13_30.1 [concrete]
+// CHECK:STDOUT:   %B.as.X.impl.F.8ec460.1: %B.as.X.impl.F.type.421a66.1 = struct_value () [concrete]
+// CHECK:STDOUT:   %B.as.X.impl.F.type.421a66.2: type = fn_type @B.as.X.impl.F.loc13_30.2 [concrete]
+// CHECK:STDOUT:   %B.as.X.impl.F.8ec460.2: %B.as.X.impl.F.type.421a66.2 = struct_value () [concrete]
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: impl @B.as.X.impl: %B.ref as %X.ref {
+// CHECK:STDOUT:   %B.as.X.impl.F.decl.loc13_30.1: %B.as.X.impl.F.type.421a66.1 = fn_decl @B.as.X.impl.F.loc13_30.1 [concrete = constants.%B.as.X.impl.F.8ec460.1] {
+// CHECK:STDOUT:     %self.param_patt: %pattern_type.1ab = value_param_pattern [concrete]
+// CHECK:STDOUT:     %self.patt: %pattern_type.1ab = at_binding_pattern self, %self.param_patt [concrete]
+// CHECK:STDOUT:     %other.param_patt: %pattern_type.1ab = ref_param_pattern [concrete]
+// CHECK:STDOUT:     %other.patt: %pattern_type.1ab = at_binding_pattern other, %other.param_patt [concrete]
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     %self.param: %A = value_param call_param0
+// CHECK:STDOUT:     %A.ref.loc13_14: type = name_ref A, file.%A.decl [concrete = constants.%A]
+// CHECK:STDOUT:     %self: %A = value_binding self, %self.param
+// CHECK:STDOUT:     %other.param: ref %A = ref_param call_param1
+// CHECK:STDOUT:     %A.ref.loc13_28: type = name_ref A, file.%A.decl [concrete = constants.%A]
+// CHECK:STDOUT:     %other: ref %A = ref_binding other, %other.param
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %B.as.X.impl.F.decl.loc13_30.2: %B.as.X.impl.F.type.421a66.2 = fn_decl @B.as.X.impl.F.loc13_30.2 [concrete = constants.%B.as.X.impl.F.8ec460.2] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:
+// CHECK:STDOUT: !members:
+// CHECK:STDOUT:   .A = <poisoned>
+// CHECK:STDOUT:   .F = %B.as.X.impl.F.decl.loc13_30.1
+// CHECK:STDOUT:   witness = %X.impl_witness
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @B.as.X.impl.F.loc13_30.1(%self.param: %A, %other.param: ref %A);
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @B.as.X.impl.F.loc13_30.2(%self.param: %B, %other.param: ref %B) [thunk @B.as.X.impl.%B.as.X.impl.F.decl.loc13_30.1] {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   %F.ref: %B.as.X.impl.F.type.421a66.1 = name_ref F, @B.as.X.impl.%B.as.X.impl.F.decl.loc13_30.1 [concrete = constants.%B.as.X.impl.F.8ec460.1]
+// CHECK:STDOUT:   %B.as.X.impl.F.bound: <bound method> = bound_method %self.param, %F.ref
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %B.as.X.impl.F.call: init %empty_tuple.type = call %B.as.X.impl.F.bound(%.loc8_12.3, %.loc8_29.2)
+// CHECK:STDOUT:   return
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
 // CHECK:STDOUT: --- fail_inheritance_value_conversion_copy_return.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {

+ 4 - 22
toolchain/check/thunk.cpp

@@ -296,7 +296,10 @@ auto PerformThunkCall(Context& context, SemIR::LocId loc_id,
                                             args.consume_front(), callee_id);
   }
 
-  return PerformCall(context, loc_id, callee_id, args);
+  // We treat this as an operator call because it's a call that's synthesized
+  // by the toolchain, not written by the user.
+  return PerformCall(context, loc_id, callee_id, args,
+                     /*is_desugared=*/true);
 }
 
 // Build a call to a function that forwards the arguments of the enclosing
@@ -406,27 +409,6 @@ auto BuildThunkDefinitionForExport(Context& context,
     call_param_ids.pop_back();
   }
 
-  auto callee_param_ids =
-      context.inst_blocks().Get(callee_function.call_param_patterns_id);
-
-  // If any explicit parameters of the callee are `ref` parameters,
-  // modify the corresponding call arguments to be `ref` tagged.
-  for (auto index = thunk_function.call_param_ranges.explicit_begin().index;
-       index < thunk_function.call_param_ranges.explicit_end().index; index++) {
-    if (context.insts().Is<SemIR::RefParamPattern>(callee_param_ids[index])) {
-      auto& call_param_id = call_param_ids[index];
-      auto type = context.insts().Get(call_param_id).type_id();
-      SemIR::LocId loc_id(thunk_id);
-      call_param_id =
-          AddInst(context, SemIR::LocIdAndInst::RuntimeVerified(
-                               context.sem_ir(), SemIR::LocId(call_param_id),
-                               SemIR::RefTagExpr{
-                                   .type_id = type,
-                                   .expr_id = call_param_id,
-                               }));
-    }
-  }
-
   auto call_id = BuildThunkCall(context, thunk_function_id, callee_id,
                                 param_pattern_ids, call_param_ids);
   if (thunk_has_return_param) {