Просмотр исходного кода

Decouple PerformCallToFunction from ReturnTypeInfo (#6572)

`ReturnTypeInfo` is built around the assumption that a function call
results in exactly one initializing expression, but with `ref` returns
there may be zero, and in the future composite return forms will enable
there to be more than one. This change removes some usages of
`ReturnTypeInfo`, and restructures the calling code to be prepared for
multiple initializing returns.
Geoff Romer 3 месяцев назад
Родитель
Сommit
87b4ca54e6

+ 29 - 33
toolchain/check/call.cpp

@@ -215,7 +215,7 @@ static auto BuildCalleeSpecificFunction(
 static auto CheckCalleeFunctionReturnType(Context& context, SemIR::LocId loc_id,
                                           SemIR::FunctionId callee_function_id,
                                           SemIR::SpecificId callee_specific_id)
-    -> SemIR::ReturnTypeInfo {
+    -> SemIR::TypeId {
   auto& function = context.functions().Get(callee_function_id);
   Diagnostics::AnnotationScope annotate_diagnostics(
       &context.emitter(), [&](auto& builder) {
@@ -247,43 +247,39 @@ auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
                                             *callee_specific_id);
   }
 
-  // If there is a return slot, build storage for the result.
-  SemIR::ReturnTypeInfo return_info = CheckCalleeFunctionReturnType(
+  auto return_type_id = CheckCalleeFunctionReturnType(
       context, loc_id, callee_function.function_id, *callee_specific_id);
-  SemIR::InstId return_slot_arg_id = SemIR::InstId::None;
-  switch (return_info.init_repr.kind) {
-    case SemIR::InitRepr::InPlace:
-    case SemIR::InitRepr::Dependent:
-      // Tentatively put storage for a temporary in the function's return slot.
-      // This will be replaced if necessary when we perform initialization.
-      return_slot_arg_id = AddInst<SemIR::TemporaryStorage>(
-          context, loc_id, {.type_id = return_info.type_id});
-      break;
-    case SemIR::InitRepr::None:
-      // For functions with an implicit return type, the return type is the
-      // empty tuple type.
-      if (!return_info.type_id.has_value()) {
-        return_info.type_id = GetTupleType(context, {});
-      }
-      break;
-    case SemIR::InitRepr::ByCopy:
-      break;
-    case SemIR::InitRepr::Abstract:
-    case SemIR::InitRepr::Incomplete:
-      // Don't form an initializing expression with an abstract or incomplete
-      // type. CheckFunctionReturnType will have diagnosed this for us if
-      // needed.
-      return_info.type_id = SemIR::ErrorInst::TypeId;
-      break;
-  }
 
   auto& callee = context.functions().Get(callee_function.function_id);
 
+  // Build storage for any output parameters.
+  llvm::SmallVector<SemIR::InstId, 1> return_arg_ids;
+  for (auto return_pattern_id :
+       context.inst_blocks().GetOrEmpty(callee.return_patterns_id)) {
+    auto arg_type_id = SemIR::ExtractScrutineeType(
+        context.sem_ir(),
+        SemIR::GetTypeOfInstInSpecific(context.sem_ir(), *callee_specific_id,
+                                       return_pattern_id));
+    switch (SemIR::InitRepr::ForType(context.sem_ir(), arg_type_id).kind) {
+      case SemIR::InitRepr::InPlace:
+      case SemIR::InitRepr::Dependent:
+        // Tentatively use storage for a temporary as the return argument.
+        // This will be replaced if necessary when we perform initialization.
+        return_arg_ids.push_back(AddInst<SemIR::TemporaryStorage>(
+            context, loc_id, {.type_id = arg_type_id}));
+        break;
+      case SemIR::InitRepr::None:
+      case SemIR::InitRepr::ByCopy:
+      case SemIR::InitRepr::Incomplete:
+      case SemIR::InitRepr::Abstract:
+        return_arg_ids.push_back(SemIR::InstId::None);
+        break;
+    }
+  }
   // Convert the arguments to match the parameters.
   auto converted_args_id = ConvertCallArgs(
-      context, loc_id, callee_function.self_id, arg_ids, return_slot_arg_id,
-      callee, *callee_specific_id, is_operator_syntax);
-
+      context, loc_id, callee_function.self_id, arg_ids, return_arg_ids, callee,
+      *callee_specific_id, is_operator_syntax);
   switch (callee.special_function_kind) {
     case SemIR::Function::SpecialFunctionKind::Thunk: {
       // If we're about to form a direct call to a thunk, inline it.
@@ -310,7 +306,7 @@ auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
     case SemIR::Function::SpecialFunctionKind::None:
     case SemIR::Function::SpecialFunctionKind::Builtin: {
       return GetOrAddInst<SemIR::Call>(context, loc_id,
-                                       {.type_id = return_info.type_id,
+                                       {.type_id = return_type_id,
                                         .callee_id = callee_id,
                                         .args_id = converted_args_id});
     }

+ 2 - 2
toolchain/check/convert.cpp

@@ -1893,7 +1893,7 @@ auto ConvertForExplicitAs(Context& context, Parse::NodeId as_node,
 auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
                      SemIR::InstId self_id,
                      llvm::ArrayRef<SemIR::InstId> arg_refs,
-                     SemIR::InstId return_slot_arg_id,
+                     llvm::ArrayRef<SemIR::InstId> return_arg_ids,
                      const SemIR::Function& callee,
                      SemIR::SpecificId callee_specific_id,
                      bool is_operator_syntax) -> SemIR::InstBlockId {
@@ -1917,7 +1917,7 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
 
   return CallerPatternMatch(context, callee_specific_id, callee.self_param_id,
                             callee.param_patterns_id, return_patterns_id,
-                            self_id, arg_refs, return_slot_arg_id,
+                            self_id, arg_refs, return_arg_ids,
                             is_operator_syntax);
 }
 

+ 1 - 1
toolchain/check/convert.h

@@ -136,7 +136,7 @@ auto ConvertForExplicitAs(Context& context, Parse::NodeId as_node,
 auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
                      SemIR::InstId self_id,
                      llvm::ArrayRef<SemIR::InstId> arg_refs,
-                     SemIR::InstId return_slot_arg_id,
+                     llvm::ArrayRef<SemIR::InstId> return_arg_ids,
                      const SemIR::Function& callee,
                      SemIR::SpecificId callee_specific_id,
                      bool is_operator_syntax) -> SemIR::InstBlockId;

+ 10 - 3
toolchain/check/function.cpp

@@ -6,6 +6,7 @@
 
 #include "common/find.h"
 #include "toolchain/check/merge.h"
+#include "toolchain/check/type.h"
 #include "toolchain/check/type_completion.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/pattern.h"
@@ -93,8 +94,7 @@ auto CheckFunctionTypeMatches(Context& context,
 
 auto CheckFunctionReturnType(Context& context, SemIR::LocId loc_id,
                              const SemIR::Function& function,
-                             SemIR::SpecificId specific_id)
-    -> SemIR::ReturnTypeInfo {
+                             SemIR::SpecificId specific_id) -> SemIR::TypeId {
   auto return_info = SemIR::ReturnTypeInfo::ForFunction(context.sem_ir(),
                                                         function, specific_id);
 
@@ -125,7 +125,14 @@ auto CheckFunctionReturnType(Context& context, SemIR::LocId loc_id,
     }
   }
 
-  return return_info;
+  if (return_info.init_repr.kind == SemIR::InitRepr::Incomplete ||
+      return_info.init_repr.kind == SemIR::InitRepr::Abstract) {
+    return SemIR::ErrorInst::TypeId;
+  }
+  if (!return_info.type_id.has_value()) {
+    return GetTupleType(context, {});
+  }
+  return return_info.type_id;
 }
 
 auto CheckFunctionDefinitionSignature(Context& context,

+ 2 - 4
toolchain/check/function.h

@@ -55,12 +55,10 @@ inline auto CheckFunctionTypeMatches(Context& context,
 
 // Checks that the return type of the specified function is complete, issuing an
 // error if not. This computes the return slot usage for the function if
-// necessary, and returns information about how the function returns its return
-// value.
+// necessary, and returns the function's return type.
 auto CheckFunctionReturnType(Context& context, SemIR::LocId loc_id,
                              const SemIR::Function& function,
-                             SemIR::SpecificId specific_id)
-    -> SemIR::ReturnTypeInfo;
+                             SemIR::SpecificId specific_id) -> SemIR::TypeId;
 
 // Checks that a function declaration's signature is suitable to support a
 // function definition. This requires the parameter types to be complete and the

+ 11 - 6
toolchain/check/pattern_match.cpp

@@ -629,17 +629,22 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         SemIR::InstBlockId return_patterns_id,
                         SemIR::InstId self_arg_id,
                         llvm::ArrayRef<SemIR::InstId> arg_refs,
-                        SemIR::InstId return_slot_arg_id,
+                        llvm::ArrayRef<SemIR::InstId> return_arg_ids,
                         bool is_operator_syntax) -> SemIR::InstBlockId {
   MatchContext match(MatchKind::Caller, specific_id);
 
   auto return_patterns = context.inst_blocks().GetOrEmpty(return_patterns_id);
   // Track the return storage, if present.
-  if (return_slot_arg_id.has_value()) {
-    CARBON_CHECK(return_patterns.size() == 1,
-                 "TODO: implement support for multiple return patterns");
-    match.AddWork(
-        {.pattern_id = return_patterns[0], .scrutinee_id = return_slot_arg_id});
+  for (auto [return_pattern_id, return_arg_id] :
+       llvm::zip_equal(return_patterns, return_arg_ids)) {
+    if (return_arg_id.has_value()) {
+      match.AddWork(
+          {.pattern_id = return_pattern_id, .scrutinee_id = return_arg_id});
+    } else {
+      CARBON_CHECK(return_arg_ids.size() == 1,
+                   "TODO: do the match even if return_arg_id is None, so that "
+                   "subsequent args are at the right index in the arg block");
+    }
   }
 
   // Check type conversions per-element.

+ 1 - 1
toolchain/check/pattern_match.h

@@ -43,7 +43,7 @@ auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
                         SemIR::InstBlockId return_patterns_id,
                         SemIR::InstId self_arg_id,
                         llvm::ArrayRef<SemIR::InstId> arg_refs,
-                        SemIR::InstId return_slot_arg_id,
+                        llvm::ArrayRef<SemIR::InstId> return_arg_ids,
                         bool is_operator_syntax) -> SemIR::InstBlockId;
 
 // Emits the pattern-match IR for a local pattern matching operation with the

+ 9 - 5
toolchain/check/return.cpp

@@ -129,10 +129,13 @@ auto BuildReturnWithExpr(Context& context, SemIR::LocId loc_id,
   const auto& function = GetCurrentFunctionForReturn(context);
   auto returned_var_id = GetCurrentReturnedVar(context);
   auto return_slot_id = SemIR::InstId::None;
-  auto return_info =
-      SemIR::ReturnTypeInfo::ForFunction(context.sem_ir(), function);
 
-  if (!return_info.type_id.has_value()) {
+  auto return_type_id = SemIR::TypeId::None;
+  if (function.return_type_inst_id.has_value()) {
+    return_type_id =
+        context.types().GetTypeIdForTypeInstId(function.return_type_inst_id);
+  }
+  if (!return_type_id.has_value()) {
     CARBON_DIAGNOSTIC(
         ReturnStatementDisallowExpr, Error,
         "no return expression should be provided in this context");
@@ -148,8 +151,9 @@ auto BuildReturnWithExpr(Context& context, SemIR::LocId loc_id,
     NoteReturnedVar(diag, returned_var_id);
     diag.Emit();
     expr_id = SemIR::ErrorInst::InstId;
-  } else if (!return_info.is_valid() ||
-             return_info.type_id == SemIR::ErrorInst::TypeId) {
+  } else if (!SemIR::InitRepr::ForType(context.sem_ir(), return_type_id)
+                  .is_valid() ||
+             return_type_id == SemIR::ErrorInst::TypeId) {
     // We already diagnosed that the return type is invalid. Don't try to
     // convert to it.
     expr_id = SemIR::ErrorInst::InstId;