Ver código fonte

Always build ReturnTypeInfo from a function (#6490)

This is a step toward using it to represent the return form, not just
the return type.
Geoff Romer 4 meses atrás
pai
commit
2078721e1c

+ 1 - 1
toolchain/check/testdata/class/fail_abstract.carbon

@@ -875,7 +875,7 @@ fn CallReturnAbstract() {
 // CHECK:STDOUT: fn @CallReturnAbstract() {
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %ReturnAbstract.ref: %ReturnAbstract.type = name_ref ReturnAbstract, file.%ReturnAbstract.decl [concrete = constants.%ReturnAbstract]
-// CHECK:STDOUT:   %ReturnAbstract.call: init <error> = call %ReturnAbstract.ref()
+// CHECK:STDOUT:   %ReturnAbstract.call: init <error> = call %ReturnAbstract.ref(<invalid return info>)
 // CHECK:STDOUT:   return
 // CHECK:STDOUT: }
 // CHECK:STDOUT:

+ 4 - 4
toolchain/check/testdata/function/declaration/fail_import_incomplete_return.carbon

@@ -187,9 +187,9 @@ fn CallFAndGIncomplete() {
 // CHECK:STDOUT: fn @Call() {
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %ReturnCUsed.ref: %ReturnCUsed.type = name_ref ReturnCUsed, file.%ReturnCUsed.decl [concrete = constants.%ReturnCUsed]
-// CHECK:STDOUT:   %ReturnCUsed.call: init <error> = call %ReturnCUsed.ref()
+// CHECK:STDOUT:   %ReturnCUsed.call: init <error> = call %ReturnCUsed.ref(<invalid return info>)
 // CHECK:STDOUT:   %ReturnDUsed.ref: %ReturnDUsed.type = name_ref ReturnDUsed, file.%ReturnDUsed.decl [concrete = constants.%ReturnDUsed]
-// CHECK:STDOUT:   %ReturnDUsed.call: init <error> = call %ReturnDUsed.ref()
+// CHECK:STDOUT:   %ReturnDUsed.call: init <error> = call %ReturnDUsed.ref(<invalid return info>)
 // CHECK:STDOUT:   return
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
@@ -262,9 +262,9 @@ fn CallFAndGIncomplete() {
 // CHECK:STDOUT: fn @CallFAndGIncomplete() {
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %ReturnCUnused.ref: %ReturnCUnused.type = name_ref ReturnCUnused, imports.%Main.ReturnCUnused [concrete = constants.%ReturnCUnused]
-// CHECK:STDOUT:   %ReturnCUnused.call: init <error> = call %ReturnCUnused.ref()
+// CHECK:STDOUT:   %ReturnCUnused.call: init <error> = call %ReturnCUnused.ref(<invalid return info>)
 // CHECK:STDOUT:   %ReturnCUsed.ref: %ReturnCUsed.type = name_ref ReturnCUsed, imports.%Main.ReturnCUsed [concrete = constants.%ReturnCUsed]
-// CHECK:STDOUT:   %ReturnCUsed.call: init <error> = call %ReturnCUsed.ref()
+// CHECK:STDOUT:   %ReturnCUsed.call: init <error> = call %ReturnCUsed.ref(<invalid return info>)
 // CHECK:STDOUT:   %ReturnDUnused.ref: %ReturnDUnused.type = name_ref ReturnDUnused, imports.%Main.ReturnDUnused [concrete = constants.%ReturnDUnused]
 // CHECK:STDOUT:   %.loc33_17.1: ref %D = temporary_storage
 // CHECK:STDOUT:   %ReturnDUnused.call: init %D = call %ReturnDUnused.ref() to %.loc33_17.1

+ 2 - 2
toolchain/check/testdata/namespace/imported_indirect.carbon

@@ -338,7 +338,7 @@ fn G() { Same.F(); }
 // CHECK:STDOUT: fn @G() {
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %F.ref: %F.type = name_ref F, imports.%Same.F [concrete = constants.%F]
-// CHECK:STDOUT:   %F.call: init <error> = call %F.ref()
+// CHECK:STDOUT:   %F.call: init <error> = call %F.ref(<invalid return info>)
 // CHECK:STDOUT:   return
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
@@ -382,7 +382,7 @@ fn G() { Same.F(); }
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %Same.ref: <namespace> = name_ref Same, imports.%Same [concrete = imports.%Same]
 // CHECK:STDOUT:   %F.ref: %F.type = name_ref F, imports.%Same.F [concrete = constants.%F]
-// CHECK:STDOUT:   %F.call: init <error> = call %F.ref()
+// CHECK:STDOUT:   %F.call: init <error> = call %F.ref(<invalid return info>)
 // CHECK:STDOUT:   return
 // CHECK:STDOUT: }
 // CHECK:STDOUT:

+ 3 - 3
toolchain/lower/function_context.cpp

@@ -315,11 +315,11 @@ auto FunctionContext::GetInitRepr(TypeInFile type) -> SemIR::InitRepr {
   return result;
 }
 
-auto FunctionContext::GetReturnTypeInfo(TypeInFile type)
+auto FunctionContext::GetReturnTypeInfo(InstInFile callee)
     -> ReturnTypeInfoInFile {
   ReturnTypeInfoInFile result = {
-      .file = type.file,
-      .info = SemIR::ReturnTypeInfo::ForType(*type.file, type.type_id)};
+      .file = callee.file,
+      .info = SemIR::ReturnTypeInfo::ForCallee(*callee.file, callee.inst_id)};
   AddEnumToCurrentFingerprint(result.info.init_repr.kind);
   return result;
 }

+ 9 - 3
toolchain/lower/function_context.h

@@ -80,6 +80,12 @@ class FunctionContext {
     }
   };
 
+  // An inst in a particular file.
+  struct InstInFile {
+    const SemIR::File* file;
+    SemIR::InstId inst_id;
+  };
+
   // Information about a function's return type in a particular file. By
   // convention, this represents a value whose initializing representation has
   // been added to the fingerprint but whose type has not.
@@ -165,9 +171,9 @@ class FunctionContext {
   // kind of initializing representation to the fingerprint.
   auto GetInitRepr(TypeInFile type) -> SemIR::InitRepr;
 
-  // Returns the return type information for the given type. This adds the
-  // kind of initializing representation to the fingerprint.
-  auto GetReturnTypeInfo(TypeInFile type) -> ReturnTypeInfoInFile;
+  // Returns the return type information for the given callee inst. This adds
+  // the kind of initializing representation to the fingerprint.
+  auto GetReturnTypeInfo(InstInFile callee) -> ReturnTypeInfoInFile;
 
   // Returns a lowered value to use for a value of type `type`.
   auto GetTypeAsValue() -> llvm::Value* {

+ 22 - 22
toolchain/lower/handle_call.cpp

@@ -574,29 +574,30 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
   // TODO: Should the `bound_method` be removed when forming the `call`
   // instruction? The `self` parameter is transferred into the call argument
   // list.
-  auto callee_id = inst.callee_id;
-  if (auto bound_method =
-          context.sem_ir().insts().TryGetAs<SemIR::BoundMethod>(callee_id)) {
-    callee_id = bound_method->function_decl_id;
+  FunctionContext::InstInFile callee = {.file = &context.sem_ir(),
+                                        .inst_id = inst.callee_id};
+  if (auto bound_method = context.sem_ir().insts().TryGetAs<SemIR::BoundMethod>(
+          callee.inst_id)) {
+    callee.inst_id = bound_method->function_decl_id;
   }
 
   // Map to the callee in the specific. This might be in a different file than
   // the one we're currently lowering.
-  const auto* callee_file = &context.sem_ir();
   if (context.specific_id().has_value()) {
     auto [const_file, const_id] = GetConstantValueInSpecific(
         context.specific_sem_ir(), context.specific_id(), context.sem_ir(),
-        callee_id);
-    callee_file = const_file;
-    callee_id = const_file->constant_values().GetInstIdIfValid(const_id);
-    CARBON_CHECK(callee_id.has_value());
+        callee.inst_id);
+    callee.file = const_file;
+    callee.inst_id = const_file->constant_values().GetInstIdIfValid(const_id);
+    CARBON_CHECK(callee.inst_id.has_value());
   }
 
-  auto callee_function = SemIR::GetCalleeAsFunction(*callee_file, callee_id);
+  auto callee_function =
+      SemIR::GetCalleeAsFunction(*callee.file, callee.inst_id);
 
   const SemIR::Function& function =
-      callee_file->functions().Get(callee_function.function_id);
-  context.AddCallToCurrentFingerprint(callee_file->check_ir_id(),
+      callee.file->functions().Get(callee_function.function_id);
+  context.AddCallToCurrentFingerprint(callee.file->check_ir_id(),
                                       callee_function.function_id,
                                       callee_function.resolved_specific_id);
 
@@ -608,11 +609,10 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
 
   std::vector<llvm::Value*> args;
 
-  auto inst_type = context.GetTypeIdOfInst(inst_id);
   bool call_has_return_slot =
-      SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst.type_id)
+      SemIR::ReturnTypeInfo::ForCallee(context.sem_ir(), inst.callee_id)
           .has_return_slot();
-  if (context.GetReturnTypeInfo(inst_type).info.has_return_slot()) {
+  if (context.GetReturnTypeInfo(callee).info.has_return_slot()) {
     CARBON_CHECK(call_has_return_slot);
     args.push_back(context.GetValue(arg_ids.consume_back()));
   } else if (call_has_return_slot) {
@@ -630,14 +630,14 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
 
   llvm::CallInst* call;
   if (function.virtual_modifier == SemIR::Function::VirtualModifier::None) {
-    auto* callee =
-        context.GetFileContext(callee_file)
+    auto* llvm_callee =
+        context.GetFileContext(callee.file)
             .GetOrCreateFunction(callee_function.function_id,
                                  callee_function.resolved_specific_id);
     auto describe_call = [&] {
       RawStringOstream out;
       out << "call ";
-      callee->printAsOperand(out);
+      llvm_callee->printAsOperand(out);
       out << "(";
       llvm::ListSeparator sep;
       for (auto* arg : args) {
@@ -645,14 +645,14 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
         arg->printAsOperand(out);
       }
       out << ")\n";
-      callee->print(out);
+      llvm_callee->print(out);
       return out.TakeStr();
     };
-    CARBON_CHECK(callee->arg_size() == args.size(),
+    CARBON_CHECK(llvm_callee->arg_size() == args.size(),
                  "Argument count mismatch: {0}", describe_call());
-    call = context.builder().CreateCall(callee, args);
+    call = context.builder().CreateCall(llvm_callee, args);
   } else {
-    call = HandleVirtualCall(context, args, callee_file, function,
+    call = HandleVirtualCall(context, args, callee.file, function,
                              callee_function);
   }
 

+ 3 - 2
toolchain/sem_ir/expr_info.cpp

@@ -122,13 +122,14 @@ auto FindReturnSlotArgForInitializer(const File& sem_ir, InstId init_id)
         return init.dest_id;
       }
       case CARBON_KIND(InPlaceInit init): {
-        if (!ReturnTypeInfo::ForType(sem_ir, init.type_id).has_return_slot()) {
+        if (!InitRepr::ForType(sem_ir, init.type_id).MightBeInPlace()) {
           return InstId::None;
         }
         return init.dest_id;
       }
       case CARBON_KIND(Call call): {
-        if (!ReturnTypeInfo::ForType(sem_ir, call.type_id).has_return_slot()) {
+        if (!ReturnTypeInfo::ForCallee(sem_ir, call.callee_id)
+                 .has_return_slot()) {
           return InstId::None;
         }
         if (!call.args_id.has_value()) {

+ 6 - 2
toolchain/sem_ir/formatter.cpp

@@ -1287,12 +1287,16 @@ auto Formatter::FormatCallRhs(Call inst) -> void {
 
   llvm::ArrayRef<InstId> args = sem_ir_->inst_blocks().Get(inst.args_id);
 
-  auto return_info = ReturnTypeInfo::ForType(*sem_ir_, inst.type_id);
+  auto return_info = ReturnTypeInfo::ForCallee(*sem_ir_, inst.callee_id);
   if (!return_info.is_valid()) {
     out_ << "(<invalid return info>)";
     return;
   }
-  bool has_return_slot = return_info.has_return_slot();
+
+  // Error in the inst type may indicate that the return type was incomplete
+  // when the inst was created, and so no return slot was added.
+  bool has_return_slot =
+      return_info.has_return_slot() && inst.type_id != SemIR::ErrorInst::TypeId;
   InstId return_slot_arg_id = InstId::None;
   if (has_return_slot) {
     return_slot_arg_id = args.consume_back();

+ 7 - 6
toolchain/sem_ir/function.cpp

@@ -14,8 +14,8 @@
 
 namespace Carbon::SemIR {
 
-auto GetCallee(const File& sem_ir, InstId callee_id, SpecificId specific_id)
-    -> Callee {
+auto GetCallee(const File& sem_ir, InstId callee_id,
+               SpecificId caller_specific_id) -> Callee {
   CalleeFunction fn = {.function_id = FunctionId::None,
                        .enclosing_specific_id = SpecificId::None,
                        .resolved_specific_id = SpecificId::None,
@@ -26,9 +26,9 @@ auto GetCallee(const File& sem_ir, InstId callee_id, SpecificId specific_id)
     callee_id = bound_method->function_decl_id;
   }
 
-  if (specific_id.has_value()) {
+  if (caller_specific_id.has_value()) {
     callee_id = sem_ir.constant_values().GetInstIdIfValid(
-        GetConstantValueInSpecific(sem_ir, specific_id, callee_id));
+        GetConstantValueInSpecific(sem_ir, caller_specific_id, callee_id));
     CARBON_CHECK(callee_id.has_value(),
                  "Invalid callee id in a specific context");
   }
@@ -80,8 +80,9 @@ auto GetCallee(const File& sem_ir, InstId callee_id, SpecificId specific_id)
 }
 
 auto GetCalleeAsFunction(const File& sem_ir, InstId callee_id,
-                         SpecificId specific_id) -> CalleeFunction {
-  return std::get<CalleeFunction>(GetCallee(sem_ir, callee_id, specific_id));
+                         SpecificId caller_specific_id) -> CalleeFunction {
+  return std::get<CalleeFunction>(
+      GetCallee(sem_ir, callee_id, caller_specific_id));
 }
 
 auto DecomposeVirtualFunction(const File& sem_ir, InstId fn_decl_id,

+ 4 - 3
toolchain/sem_ir/function.h

@@ -232,13 +232,14 @@ struct CalleeNonFunction {};
 using Callee = std::variant<CalleeCppOverloadSet, CalleeError, CalleeFunction,
                             CalleeNonFunction>;
 
-// Returns information for the function corresponding to callee_id.
+// Returns information for the function corresponding to callee_id in
+// caller_specific_id.
 auto GetCallee(const File& sem_ir, InstId callee_id,
-               SpecificId specific_id = SpecificId::None) -> Callee;
+               SpecificId caller_specific_id = SpecificId::None) -> Callee;
 
 // Like `GetCallee`, but restricts to the `Function` callee kind.
 auto GetCalleeAsFunction(const File& sem_ir, InstId callee_id,
-                         SpecificId specific_id = SpecificId::None)
+                         SpecificId caller_specific_id = SpecificId::None)
     -> CalleeFunction;
 
 struct DecomposedVirtualFunction {

+ 9 - 0
toolchain/sem_ir/type_info.cpp

@@ -79,6 +79,15 @@ auto InitRepr::ForType(const File& file, TypeId type_id) -> InitRepr {
   }
 }
 
+auto ReturnTypeInfo::ForCallee(const File& file, InstId callee_id,
+                               SemIR::SpecificId caller_specific_id)
+    -> ReturnTypeInfo {
+  auto callee_function =
+      SemIR::GetCalleeAsFunction(file, callee_id, caller_specific_id);
+  auto function = file.functions().Get(callee_function.function_id);
+  return ForFunction(file, function, callee_function.resolved_specific_id);
+}
+
 auto NumericTypeLiteralInfo::ForType(const File& file, ClassType class_type)
     -> NumericTypeLiteralInfo {
   // Quickly rule out any class that's not a specific.

+ 10 - 8
toolchain/sem_ir/type_info.h

@@ -166,20 +166,22 @@ struct InitRepr : Printable<InitRepr> {
 
 // Information about a function's return type.
 struct ReturnTypeInfo : public Printable<ReturnTypeInfo> {
-  // Builds return type information for a given declared return type.
-  static auto ForType(const File& file, TypeId type_id) -> ReturnTypeInfo {
+  // Builds return type information for a given function.
+  static auto ForFunction(const File& file, const Function& function,
+                          SpecificId specific_id = SpecificId::None)
+      -> ReturnTypeInfo {
+    auto type_id = function.GetDeclaredReturnType(file, specific_id);
     return {.type_id = type_id,
             .init_repr = type_id.has_value()
                              ? InitRepr::ForType(file, type_id)
                              : InitRepr{.kind = InitRepr::None}};
   }
 
-  // Builds return type information for a given function.
-  static auto ForFunction(const File& file, const Function& function,
-                          SpecificId specific_id = SpecificId::None)
-      -> ReturnTypeInfo {
-    return ForType(file, function.GetDeclaredReturnType(file, specific_id));
-  }
+  // Builds return type information for the function corresponding to callee_id
+  // in caller_specific_id.
+  static auto ForCallee(const File& file, InstId callee_id,
+                        SpecificId caller_specific_id = SemIR::SpecificId::None)
+      -> ReturnTypeInfo;
 
   // Returns whether the return information could be fully computed.
   auto is_valid() const -> bool { return init_repr.is_valid(); }