Просмотр исходного кода

Refactor PerformCall (#5302)

`PerformCall` has gotten a little long, 130 lines; splitting out some
functions to help size. `PerformCallToFunction` in particular seems
symmetric with `PerformCallToGenericClass` and
`PerformCallToGenericInterface`.
Jon Ross-Perkins 1 год назад
Родитель
Сommit
dc8fab1d81
1 измененных файлов с 110 добавлено и 79 удалено
  1. 110 79
      toolchain/check/call.cpp

+ 110 - 79
toolchain/check/call.cpp

@@ -16,6 +16,7 @@
 #include "toolchain/diagnostics/format_providers.h"
 #include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/entity_with_params_base.h"
+#include "toolchain/sem_ir/function.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/inst.h"
 #include "toolchain/sem_ir/typed_insts.h"
@@ -129,35 +130,77 @@ static auto PerformCallToGenericInterface(
       FacetTypeFromInterface(context, interface_id, *callee_specific_id));
 }
 
-auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
-                 llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::InstId {
-  // Identify the function we're calling.
-  auto callee_function = GetCalleeFunction(context.sem_ir(), callee_id);
-  if (!callee_function.function_id.has_value()) {
-    auto type_inst =
-        context.types().GetAsInst(context.insts().Get(callee_id).type_id());
-    CARBON_KIND_SWITCH(type_inst) {
-      case CARBON_KIND(SemIR::GenericClassType generic_class): {
-        return PerformCallToGenericClass(
-            context, loc_id, generic_class.class_id,
-            generic_class.enclosing_specific_id, arg_ids);
-      }
-      case CARBON_KIND(SemIR::GenericInterfaceType generic_interface): {
-        return PerformCallToGenericInterface(
-            context, loc_id, generic_interface.interface_id,
-            generic_interface.enclosing_specific_id, arg_ids);
-      }
-      default: {
-        if (!callee_function.is_error) {
-          CARBON_DIAGNOSTIC(CallToNonCallable, Error,
-                            "value of type {0} is not callable", TypeOfInstId);
-          context.emitter().Emit(loc_id, CallToNonCallable, callee_id);
-        }
-        return SemIR::ErrorInst::SingletonInstId;
-      }
-    }
+// Builds an appropriate specific function for the callee, also handling
+// instance binding.
+static auto BuildCalleeSpecificFunction(
+    Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
+    SemIR::InstId callee_function_self_type_id,
+    SemIR::SpecificId callee_specific_id) -> SemIR::InstId {
+  auto generic_callee_id = callee_id;
+
+  // Strip off a bound_method so that we can form a constant specific callee.
+  auto bound_method = context.insts().TryGetAs<SemIR::BoundMethod>(callee_id);
+  if (bound_method) {
+    generic_callee_id = bound_method->function_decl_id;
+  }
+
+  // Form a specific callee.
+  if (callee_function_self_type_id.has_value()) {
+    // This is an associated function in an interface; the callee is the
+    // specific function in the impl that corresponds to the specific function
+    // we deduced.
+    callee_id = GetOrAddInst(
+        context, context.insts().GetLocId(generic_callee_id),
+        SemIR::SpecificImplFunction{
+            .type_id = GetSingletonType(
+                context, SemIR::SpecificFunctionType::SingletonInstId),
+            .callee_id = generic_callee_id,
+            .specific_id = callee_specific_id});
+  } else {
+    // This is a regular generic function. The callee is the specific function
+    // we deduced.
+    callee_id = GetOrAddInst(
+        context, context.insts().GetLocId(generic_callee_id),
+        SemIR::SpecificFunction{
+            .type_id = GetSingletonType(
+                context, SemIR::SpecificFunctionType::SingletonInstId),
+            .callee_id = generic_callee_id,
+            .specific_id = callee_specific_id});
+  }
+
+  // Add the `self` argument back if there was one.
+  if (bound_method) {
+    callee_id =
+        GetOrAddInst<SemIR::BoundMethod>(context, loc_id,
+                                         {.type_id = bound_method->type_id,
+                                          .object_id = bound_method->object_id,
+                                          .function_decl_id = callee_id});
   }
 
+  return callee_id;
+}
+
+// Returns the return type, with a scoped annotation for any diagnostics.
+static auto CheckCalleeFunctionReturnType(Context& context, SemIR::LocId loc_id,
+                                          SemIR::FunctionId callee_function_id,
+                                          SemIR::SpecificId callee_specific_id)
+    -> SemIR::ReturnTypeInfo {
+  auto& function = context.functions().Get(callee_function_id);
+  Diagnostics::AnnotationScope annotate_diagnostics(
+      &context.emitter(), [&](auto& builder) {
+        CARBON_DIAGNOSTIC(IncompleteReturnTypeHere, Note,
+                          "return type declared here");
+        builder.Note(function.return_slot_pattern_id, IncompleteReturnTypeHere);
+      });
+  return CheckFunctionReturnType(context, loc_id, function, callee_specific_id);
+}
+
+// Performs a call where the callee is a function.
+static auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
+                                  SemIR::InstId callee_id,
+                                  const SemIR::CalleeFunction& callee_function,
+                                  llvm::ArrayRef<SemIR::InstId> arg_ids)
+    -> SemIR::InstId {
   // If the callee is a generic function, determine the generic argument values
   // for the call.
   auto callee_specific_id = ResolveCalleeInCall(
@@ -169,62 +212,15 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
   }
 
   if (callee_specific_id->has_value()) {
-    auto generic_callee_id = callee_id;
-
-    // Strip off a bound_method so that we can form a constant specific callee.
-    auto bound_method = context.insts().TryGetAs<SemIR::BoundMethod>(callee_id);
-    if (bound_method) {
-      generic_callee_id = bound_method->function_decl_id;
-    }
-
-    // Form a specific callee.
-    if (callee_function.self_type_id.has_value()) {
-      // This is an associated function in an interface; the callee is the
-      // specific function in the impl that corresponds to the specific function
-      // we deduced.
-      callee_id = GetOrAddInst(
-          context, context.insts().GetLocId(generic_callee_id),
-          SemIR::SpecificImplFunction{
-              .type_id = GetSingletonType(
-                  context, SemIR::SpecificFunctionType::SingletonInstId),
-              .callee_id = generic_callee_id,
-              .specific_id = *callee_specific_id});
-    } else {
-      // This is a regular generic function. The callee is the specific function
-      // we deduced.
-      callee_id = GetOrAddInst(
-          context, context.insts().GetLocId(generic_callee_id),
-          SemIR::SpecificFunction{
-              .type_id = GetSingletonType(
-                  context, SemIR::SpecificFunctionType::SingletonInstId),
-              .callee_id = generic_callee_id,
-              .specific_id = *callee_specific_id});
-    }
-
-    // Add the `self` argument back if there was one.
-    if (bound_method) {
-      callee_id = GetOrAddInst<SemIR::BoundMethod>(
-          context, loc_id,
-          {.type_id = bound_method->type_id,
-           .object_id = bound_method->object_id,
-           .function_decl_id = callee_id});
-    }
+    callee_id = BuildCalleeSpecificFunction(context, loc_id, callee_id,
+                                            callee_function.self_type_id,
+                                            *callee_specific_id);
   }
 
   // If there is a return slot, build storage for the result.
+  SemIR::ReturnTypeInfo return_info = CheckCalleeFunctionReturnType(
+      context, loc_id, callee_function.function_id, *callee_specific_id);
   SemIR::InstId return_slot_arg_id = SemIR::InstId::None;
-  SemIR::ReturnTypeInfo return_info = [&] {
-    auto& function = context.functions().Get(callee_function.function_id);
-    Diagnostics::AnnotationScope annotate_diagnostics(
-        &context.emitter(), [&](auto& builder) {
-          CARBON_DIAGNOSTIC(IncompleteReturnTypeHere, Note,
-                            "return type declared here");
-          builder.Note(function.return_slot_pattern_id,
-                       IncompleteReturnTypeHere);
-        });
-    return CheckFunctionReturnType(context, loc_id, function,
-                                   *callee_specific_id);
-  }();
   switch (return_info.init_repr.kind) {
     case SemIR::InitRepr::InPlace:
       // Tentatively put storage for a temporary in the function's return slot.
@@ -261,4 +257,39 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
   return call_inst_id;
 }
 
+auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
+                 llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::InstId {
+  // Try treating the callee as a function first.
+  auto callee_function = GetCalleeFunction(context.sem_ir(), callee_id);
+  if (callee_function.is_error) {
+    return SemIR::ErrorInst::SingletonInstId;
+  }
+  if (callee_function.function_id.has_value()) {
+    return PerformCallToFunction(context, loc_id, callee_id, callee_function,
+                                 arg_ids);
+  }
+
+  // Callee isn't a function, so try treating it as a generic type.
+  auto type_inst =
+      context.types().GetAsInst(context.insts().Get(callee_id).type_id());
+  CARBON_KIND_SWITCH(type_inst) {
+    case CARBON_KIND(SemIR::GenericClassType generic_class): {
+      return PerformCallToGenericClass(context, loc_id, generic_class.class_id,
+                                       generic_class.enclosing_specific_id,
+                                       arg_ids);
+    }
+    case CARBON_KIND(SemIR::GenericInterfaceType generic_interface): {
+      return PerformCallToGenericInterface(
+          context, loc_id, generic_interface.interface_id,
+          generic_interface.enclosing_specific_id, arg_ids);
+    }
+    default: {
+      CARBON_DIAGNOSTIC(CallToNonCallable, Error,
+                        "value of type {0} is not callable", TypeOfInstId);
+      context.emitter().Emit(loc_id, CallToNonCallable, callee_id);
+      return SemIR::ErrorInst::SingletonInstId;
+    }
+  }
+}
+
 }  // namespace Carbon::Check