Просмотр исходного кода

Remove return_type_id from Function. (#4051)

Instead of redundantly storing both the `return_type_id` and
`return_storage_id`, where the declared return type is just the type of
the return storage, store only the `return_storage_id`.

Add a convenience property to get the declared return type of the
function.

In addition to avoiding storing redundant information, this is a
preparatory step for an upcoming change for generics support that will
make it more expensive and awkward to store `TypeId`s in places other
than the type of an instruction.
Richard Smith 1 год назад
Родитель
Сommit
e3c15edb92

+ 2 - 2
toolchain/check/call.cpp

@@ -89,7 +89,7 @@ auto PerformCall(Context& context, Parse::NodeId node_id,
 
   // For functions with an implicit return type, the return type is the empty
   // tuple type.
-  SemIR::TypeId type_id = callable.return_type_id;
+  SemIR::TypeId type_id = callable.declared_return_type(context.sem_ir());
   if (!type_id.is_valid()) {
     type_id = context.GetTupleType({});
   }
@@ -110,7 +110,7 @@ auto PerformCall(Context& context, Parse::NodeId node_id,
       // Tentatively put storage for a temporary in the function's return slot.
       // This will be replaced if necessary when we perform initialization.
       return_storage_id = context.AddInst<SemIR::TemporaryStorage>(
-          node_id, {.type_id = callable.return_type_id});
+          node_id, {.type_id = type_id});
       break;
     case SemIR::Function::ReturnSlot::Absent:
       break;

+ 0 - 1
toolchain/check/context.cpp

@@ -631,7 +631,6 @@ auto Context::FinalizeGlobalInit() -> void {
          .decl_id = SemIR::InstId::Invalid,
          .implicit_param_refs_id = SemIR::InstBlockId::Invalid,
          .param_refs_id = SemIR::InstBlockId::Empty,
-         .return_type_id = SemIR::TypeId::Invalid,
          .return_storage_id = SemIR::InstId::Invalid,
          .is_extern = false,
          .return_slot = SemIR::Function::ReturnSlot::Absent,

+ 20 - 13
toolchain/check/function.cpp

@@ -19,15 +19,18 @@ auto CheckFunctionTypeMatches(Context& context,
     return false;
   }
 
-  if (new_function.return_type_id == SemIR::TypeId::Error ||
-      prev_function.return_type_id == SemIR::TypeId::Error) {
+  auto new_return_type_id = new_function.declared_return_type(context.sem_ir());
+  auto prev_return_type_id =
+      prev_function.declared_return_type(context.sem_ir());
+  if (new_return_type_id == SemIR::TypeId::Error ||
+      prev_return_type_id == SemIR::TypeId::Error) {
     return false;
   }
-  auto prev_return_type_id =
-      prev_function.return_type_id.is_valid()
-          ? SubstType(context, prev_function.return_type_id, substitutions)
-          : SemIR::TypeId::Invalid;
-  if (new_function.return_type_id != prev_return_type_id) {
+  if (prev_return_type_id.is_valid()) {
+    prev_return_type_id =
+        SubstType(context, prev_return_type_id, substitutions);
+  }
+  if (new_return_type_id != prev_return_type_id) {
     CARBON_DIAGNOSTIC(
         FunctionRedeclReturnTypeDiffers, Error,
         "Function redeclaration differs because return type is `{0}`.",
@@ -36,10 +39,10 @@ auto CheckFunctionTypeMatches(Context& context,
         FunctionRedeclReturnTypeDiffersNoReturn, Error,
         "Function redeclaration differs because no return type is provided.");
     auto diag =
-        new_function.return_type_id.is_valid()
+        new_return_type_id.is_valid()
             ? context.emitter().Build(new_function.decl_id,
                                       FunctionRedeclReturnTypeDiffers,
-                                      new_function.return_type_id)
+                                      new_return_type_id)
             : context.emitter().Build(new_function.decl_id,
                                       FunctionRedeclReturnTypeDiffersNoReturn);
     if (prev_return_type_id.is_valid()) {
@@ -68,27 +71,31 @@ auto CheckFunctionReturnType(Context& context, SemIRLoc loc,
     return;
   }
 
-  if (!function.return_type_id.is_valid()) {
+  if (!function.return_storage_id.is_valid()) {
     // Implicit `-> ()` has no return slot.
     function.return_slot = SemIR::Function::ReturnSlot::Absent;
     return;
   }
 
+  auto return_type_id = function.declared_return_type(context.sem_ir());
+  CARBON_CHECK(return_type_id.is_valid())
+      << "Have return storage but no return type.";
+
   // Check the return type is complete. Only diagnose incompleteness if we've
   // not already done so.
   auto diagnose_incomplete_return_type = [&] {
     CARBON_DIAGNOSTIC(IncompleteTypeInFunctionReturnType, Error,
                       "Function returns incomplete type `{0}`.", SemIR::TypeId);
     return context.emitter().Build(loc, IncompleteTypeInFunctionReturnType,
-                                   function.return_type_id);
+                                   return_type_id);
   };
   if (!context.TryToCompleteType(
-          function.return_type_id,
+          return_type_id,
           function.return_slot == SemIR::Function::ReturnSlot::Error
               ? std::nullopt
               : std::optional(diagnose_incomplete_return_type))) {
     function.return_slot = SemIR::Function::ReturnSlot::Error;
-  } else if (SemIR::GetInitRepr(context.sem_ir(), function.return_type_id)
+  } else if (SemIR::GetInitRepr(context.sem_ir(), return_type_id)
                  .has_return_slot()) {
     function.return_slot = SemIR::Function::ReturnSlot::Present;
   } else {

+ 6 - 9
toolchain/check/handle_function.cpp

@@ -118,7 +118,6 @@ static auto MergeFunctionRedecl(Context& context, SemIRLoc new_loc,
     prev_function.definition_id = new_function.definition_id;
     prev_function.implicit_param_refs_id = new_function.implicit_param_refs_id;
     prev_function.param_refs_id = new_function.param_refs_id;
-    prev_function.return_type_id = new_function.return_type_id;
     prev_function.return_storage_id = new_function.return_storage_id;
   }
   // The new function might have return slot information if it was imported.
@@ -199,13 +198,11 @@ static auto BuildFunctionDecl(Context& context,
     -> std::pair<SemIR::FunctionId, SemIR::InstId> {
   auto decl_block_id = context.inst_block_stack().Pop();
 
-  auto return_type_id = SemIR::TypeId::Invalid;
   auto return_storage_id = SemIR::InstId::Invalid;
   auto return_slot = SemIR::Function::ReturnSlot::NotComputed;
   if (auto [return_node, maybe_return_storage_id] =
           context.node_stack().PopWithNodeIdIf<Parse::NodeKind::ReturnType>();
       maybe_return_storage_id) {
-    return_type_id = context.insts().Get(*maybe_return_storage_id).type_id();
     return_storage_id = *maybe_return_storage_id;
   } else {
     // If there's no return type, there's no return slot.
@@ -250,7 +247,6 @@ static auto BuildFunctionDecl(Context& context,
           SemIR::LocIdAndInst(node_id, function_decl)),
       .implicit_param_refs_id = name.implicit_params_id,
       .param_refs_id = name.params_id,
-      .return_type_id = return_type_id,
       .return_storage_id = return_storage_id,
       .is_extern = is_extern,
       .return_slot = return_slot};
@@ -291,14 +287,15 @@ static auto BuildFunctionDecl(Context& context,
   }
 
   if (SemIR::IsEntryPoint(context.sem_ir(), function_decl.function_id)) {
+    auto return_type_id = function_info.declared_return_type(context.sem_ir());
     // TODO: Update this once valid signatures for the entry point are decided.
     if (function_info.implicit_param_refs_id.is_valid() ||
         !function_info.param_refs_id.is_valid() ||
         !context.inst_blocks().Get(function_info.param_refs_id).empty() ||
-        (function_info.return_type_id.is_valid() &&
-         function_info.return_type_id !=
+        (return_type_id.is_valid() &&
+         return_type_id !=
              context.GetBuiltinType(SemIR::BuiltinKind::IntType) &&
-         function_info.return_type_id != context.GetTupleType({}))) {
+         return_type_id != context.GetTupleType({}))) {
       CARBON_DIAGNOSTIC(InvalidMainRunSignature, Error,
                         "Invalid signature for `Main.Run` function. Expected "
                         "`fn ()` or `fn () -> i32`.");
@@ -398,7 +395,7 @@ auto HandleFunctionDefinition(Context& context,
   // If the `}` of the function is reachable, reject if we need a return value
   // and otherwise add an implicit `return;`.
   if (context.is_current_position_reachable()) {
-    if (context.functions().Get(function_id).return_type_id.is_valid()) {
+    if (context.functions().Get(function_id).return_storage_id.is_valid()) {
       CARBON_DIAGNOSTIC(
           MissingReturnStatement, Error,
           "Missing `return` at end of function with declared return type.");
@@ -467,7 +464,7 @@ static auto IsValidBuiltinDeclaration(Context& context,
   }
 
   // Get the return type. This is `()` if none was specified.
-  auto return_type_id = function.return_type_id;
+  auto return_type_id = function.declared_return_type(context.sem_ir());
   if (!return_type_id.is_valid()) {
     return_type_id = context.GetTupleType({});
   }

+ 4 - 8
toolchain/check/import_ref.cpp

@@ -1034,8 +1034,9 @@ class ImportRefResolver {
 
     const auto& function = import_ir_.functions().Get(inst.function_id);
     auto return_type_const_id = SemIR::ConstantId::Invalid;
-    if (function.return_type_id.is_valid()) {
-      return_type_const_id = GetLocalConstantId(function.return_type_id);
+    if (function.return_storage_id.is_valid()) {
+      return_type_const_id =
+          GetLocalConstantId(function.declared_return_type(import_ir_));
     }
     auto parent_scope_id = GetLocalNameScopeId(function.parent_scope_id);
     llvm::SmallVector<SemIR::ConstantId> implicit_param_const_ids =
@@ -1059,10 +1060,6 @@ class ImportRefResolver {
     auto function_decl_id = context_.AddPlaceholderInstInNoBlock(
         SemIR::LocIdAndInst(import_ir_inst_id, function_decl));
 
-    auto new_return_type_id =
-        return_type_const_id.is_valid()
-            ? context_.GetTypeIdForTypeConstant(return_type_const_id)
-            : SemIR::TypeId::Invalid;
     auto new_return_storage = SemIR::InstId::Invalid;
     if (function.return_storage_id.is_valid()) {
       // Recreate the return slot from scratch.
@@ -1070,7 +1067,7 @@ class ImportRefResolver {
       // use the same return storage variable in the declaration and definition.
       new_return_storage = context_.AddInstInNoBlock<SemIR::VarStorage>(
           AddImportIRInst(function.return_storage_id),
-          {.type_id = new_return_type_id,
+          {.type_id = context_.GetTypeIdForTypeConstant(return_type_const_id),
            .name_id = SemIR::NameId::ReturnSlot});
     }
     function_decl.function_id = context_.functions().Add(
@@ -1081,7 +1078,6 @@ class ImportRefResolver {
              function.implicit_param_refs_id, implicit_param_const_ids),
          .param_refs_id =
              GetLocalParamRefsId(function.param_refs_id, param_const_ids),
-         .return_type_id = new_return_type_id,
          .return_storage_id = new_return_storage,
          .is_extern = function.is_extern,
          .return_slot = function.return_slot,

+ 13 - 11
toolchain/check/return.cpp

@@ -38,11 +38,11 @@ static auto NoteNoReturnTypeProvided(Context::DiagnosticBuilder& diag,
 
 // Produces a note describing the return type of the given function.
 static auto NoteReturnType(Context::DiagnosticBuilder& diag,
-                           const SemIR::Function& function) {
+                           const SemIR::Function& function,
+                           SemIR::TypeId return_type_id) {
   CARBON_DIAGNOSTIC(ReturnTypeHereNote, Note,
                     "Return type of function is `{0}`.", SemIR::TypeId);
-  diag.Note(function.return_storage_id, ReturnTypeHereNote,
-            function.return_type_id);
+  diag.Note(function.return_storage_id, ReturnTypeHereNote, return_type_id);
 }
 
 // Produces a note pointing at the currently in scope `returned var`.
@@ -58,7 +58,8 @@ auto CheckReturnedVar(Context& context, Parse::NodeId returned_node,
     -> SemIR::InstId {
   // A `returned var` requires an explicit return type.
   auto& function = GetCurrentFunction(context);
-  if (!function.return_type_id.is_valid()) {
+  auto return_type_id = function.declared_return_type(context.sem_ir());
+  if (!return_type_id.is_valid()) {
     CARBON_DIAGNOSTIC(ReturnedVarWithNoReturnType, Error,
                       "Cannot declare a `returned var` in this function.");
     auto diag =
@@ -69,14 +70,14 @@ auto CheckReturnedVar(Context& context, Parse::NodeId returned_node,
   }
 
   // The declared type of the var must match the return type of the function.
-  if (function.return_type_id != type_id) {
+  if (return_type_id != type_id) {
     CARBON_DIAGNOSTIC(ReturnedVarWrongType, Error,
                       "Type `{0}` of `returned var` does not match "
                       "return type of enclosing function.",
                       SemIR::TypeId);
     auto diag =
         context.emitter().Build(type_node, ReturnedVarWrongType, type_id);
-    NoteReturnType(diag, function);
+    NoteReturnType(diag, function, return_type_id);
     diag.Emit();
     return SemIR::InstId::BuiltinError;
   }
@@ -105,12 +106,13 @@ auto RegisterReturnedVar(Context& context, SemIR::InstId bind_id) -> void {
 auto BuildReturnWithNoExpr(Context& context, Parse::ReturnStatementId node_id)
     -> void {
   const auto& function = GetCurrentFunction(context);
+  auto return_type_id = function.declared_return_type(context.sem_ir());
 
-  if (function.return_type_id.is_valid()) {
+  if (return_type_id.is_valid()) {
     CARBON_DIAGNOSTIC(ReturnStatementMissingExpr, Error,
                       "Missing return value.");
     auto diag = context.emitter().Build(node_id, ReturnStatementMissingExpr);
-    NoteReturnType(diag, function);
+    NoteReturnType(diag, function, return_type_id);
     diag.Emit();
   }
 
@@ -122,8 +124,9 @@ auto BuildReturnWithExpr(Context& context, Parse::ReturnStatementId node_id,
   const auto& function = GetCurrentFunction(context);
   auto returned_var_id = GetCurrentReturnedVar(context);
   auto return_slot_id = SemIR::InstId::Invalid;
+  auto return_type_id = function.declared_return_type(context.sem_ir());
 
-  if (!function.return_type_id.is_valid()) {
+  if (!return_type_id.is_valid()) {
     CARBON_DIAGNOSTIC(
         ReturnStatementDisallowExpr, Error,
         "No return expression should be provided in this context.");
@@ -146,8 +149,7 @@ auto BuildReturnWithExpr(Context& context, Parse::ReturnStatementId node_id,
     // Don't produce a second error complaining the return type is incomplete.
     expr_id = SemIR::InstId::BuiltinError;
   } else {
-    expr_id = ConvertToValueOfType(context, node_id, expr_id,
-                                   function.return_type_id);
+    expr_id = ConvertToValueOfType(context, node_id, expr_id, return_type_id);
   }
 
   context.AddInst<SemIR::ReturnExpr>(

+ 1 - 1
toolchain/check/testdata/basics/no_prelude/raw_and_textual_ir.carbon

@@ -25,7 +25,7 @@ fn Foo(n: ()) -> ((), ()) {
 // CHECK:STDOUT:   bind_names:
 // CHECK:STDOUT:     bindName0:       {name: name1, parent_scope: name_scope<invalid>, index: compTimeBind<invalid>}
 // CHECK:STDOUT:   functions:
-// CHECK:STDOUT:     function0:       {name: name0, parent_scope: name_scope0, param_refs: block3, return_type: type2, return_storage: inst+13, return_slot: present, body: [block6]}
+// CHECK:STDOUT:     function0:       {name: name0, parent_scope: name_scope0, param_refs: block3, return_storage: inst+13, return_slot: present, body: [block6]}
 // CHECK:STDOUT:   classes:         {}
 // CHECK:STDOUT:   types:
 // CHECK:STDOUT:     type0:           {constant: template instNamespaceType, value_rep: {kind: copy, type: type0}}

+ 1 - 1
toolchain/check/testdata/basics/no_prelude/raw_ir.carbon

@@ -25,7 +25,7 @@ fn Foo(n: ()) -> ((), ()) {
 // CHECK:STDOUT:   bind_names:
 // CHECK:STDOUT:     bindName0:       {name: name1, parent_scope: name_scope<invalid>, index: compTimeBind<invalid>}
 // CHECK:STDOUT:   functions:
-// CHECK:STDOUT:     function0:       {name: name0, parent_scope: name_scope0, param_refs: block3, return_type: type2, return_storage: inst+13, return_slot: present, body: [block6]}
+// CHECK:STDOUT:     function0:       {name: name0, parent_scope: name_scope0, param_refs: block3, return_storage: inst+13, return_slot: present, body: [block6]}
 // CHECK:STDOUT:   classes:         {}
 // CHECK:STDOUT:   types:
 // CHECK:STDOUT:     type0:           {constant: template instNamespaceType, value_rep: {kind: copy, type: type0}}

+ 6 - 5
toolchain/lower/file_context.cpp

@@ -142,10 +142,11 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id)
       sem_ir().inst_blocks().GetOrEmpty(function.implicit_param_refs_id);
   // TODO: Include parameters corresponding to positional parameters.
   auto param_refs = sem_ir().inst_blocks().GetOrEmpty(function.param_refs_id);
+  auto return_type_id = function.declared_return_type(sem_ir());
 
   SemIR::InitRepr return_rep =
-      function.return_type_id.is_valid()
-          ? SemIR::GetInitRepr(sem_ir(), function.return_type_id)
+      return_type_id.is_valid()
+          ? SemIR::GetInitRepr(sem_ir(), return_type_id)
           : SemIR::InitRepr{.kind = SemIR::InitRepr::None};
   CARBON_CHECK(return_rep.has_return_slot() == has_return_slot);
 
@@ -160,7 +161,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id)
   param_types.reserve(max_llvm_params);
   param_inst_ids.reserve(max_llvm_params);
   if (has_return_slot) {
-    param_types.push_back(GetType(function.return_type_id)->getPointerTo());
+    param_types.push_back(GetType(return_type_id)->getPointerTo());
     param_inst_ids.push_back(function.return_storage_id);
   }
   for (auto param_ref_id :
@@ -187,7 +188,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id)
   // If the initializing representation doesn't produce a value, set the return
   // type to void.
   llvm::Type* return_type = return_rep.kind == SemIR::InitRepr::ByCopy
-                                ? GetType(function.return_type_id)
+                                ? GetType(return_type_id)
                                 : llvm::Type::getVoidTy(llvm_context());
 
   std::string mangled_name;
@@ -216,7 +217,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id)
     if (inst_id == function.return_storage_id) {
       name_id = SemIR::NameId::ReturnSlot;
       arg.addAttr(llvm::Attribute::getWithStructRetType(
-          llvm_context(), GetType(function.return_type_id)));
+          llvm_context(), GetType(return_type_id)));
     } else {
       name_id = SemIR::Function::GetParamFromParamRefId(sem_ir(), inst_id)
                     .second.name_id;

+ 2 - 2
toolchain/sem_ir/formatter.cpp

@@ -246,13 +246,13 @@ class Formatter {
       out_ << ")";
     }
 
-    if (fn.return_type_id.is_valid()) {
+    if (fn.return_storage_id.is_valid()) {
       out_ << " -> ";
       if (!fn.body_block_ids.empty() && fn.has_return_slot()) {
         FormatInstName(fn.return_storage_id);
         out_ << ": ";
       }
-      FormatType(fn.return_type_id);
+      FormatType(sem_ir_.insts().Get(fn.return_storage_id).type_id());
     }
 
     if (fn.builtin_kind != BuiltinFunctionKind::None) {

+ 7 - 0
toolchain/sem_ir/function.cpp

@@ -38,4 +38,11 @@ auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction {
   return result;
 }
 
+auto Function::declared_return_type(const File& file) const -> TypeId {
+  if (!return_storage_id.is_valid()) {
+    return TypeId::Invalid;
+  }
+  return file.insts().Get(return_storage_id).type_id();
+}
+
 }  // namespace Carbon::SemIR

+ 6 - 6
toolchain/sem_ir/function.h

@@ -30,9 +30,6 @@ struct Function : public Printable<Function> {
   auto Print(llvm::raw_ostream& out) const -> void {
     out << "{name: " << name_id << ", parent_scope: " << parent_scope_id
         << ", param_refs: " << param_refs_id;
-    if (return_type_id.is_valid()) {
-      out << ", return_type: " << return_type_id;
-    }
     if (return_storage_id.is_valid()) {
       out << ", return_storage: " << return_storage_id;
       out << ", return_slot: ";
@@ -65,6 +62,10 @@ struct Function : public Printable<Function> {
   static auto GetParamFromParamRefId(const File& sem_ir, InstId param_ref_id)
       -> std::pair<InstId, Param>;
 
+  // Gets the declared return type of the function. Returns `Invalid` if no
+  // return type was specified,
+  auto declared_return_type(const File& file) const -> TypeId;
+
   // Returns whether the function has a return slot. Can only be called for a
   // function that has either been called or defined, otherwise this is not
   // known.
@@ -87,11 +88,10 @@ struct Function : public Printable<Function> {
   InstBlockId implicit_param_refs_id;
   // A block containing a single reference instruction per parameter.
   InstBlockId param_refs_id;
-  // The return type. This will be invalid if the return type wasn't specified.
-  TypeId return_type_id;
   // The storage for the return value, which is a reference expression whose
   // type is the return type of the function. This may or may not be used by the
-  // function, depending on whether the return type needs a return slot.
+  // function, depending on whether the return type needs a return slot, but is
+  // always present if the function has a declared return type.
   InstId return_storage_id;
   // Whether the declaration is extern.
   bool is_extern;