Przeglądaj źródła

Model return slot as parameter in lowering (#4457)

Co-authored-by: Richard Smith <richard@metafoo.co.uk>
Geoff Romer 1 rok temu
rodzic
commit
ac5cc33da4

+ 85 - 28
toolchain/lower/file_context.cpp

@@ -226,9 +226,14 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
                          implicit_param_patterns.size() + param_patterns.size();
   param_types.reserve(max_llvm_params);
   param_inst_ids.reserve(max_llvm_params);
+  auto return_param_id = SemIR::InstId::Invalid;
   if (return_info.has_return_slot()) {
     param_types.push_back(return_type->getPointerTo());
-    param_inst_ids.push_back(function.return_slot_id);
+    return_param_id = sem_ir()
+                          .insts()
+                          .GetAs<SemIR::ReturnSlot>(function.return_slot_id)
+                          .storage_id;
+    param_inst_ids.push_back(return_param_id);
   }
   for (auto param_pattern_id : llvm::concat<const SemIR::InstId>(
            implicit_param_patterns, param_patterns)) {
@@ -280,7 +285,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
   for (auto [inst_id, arg] :
        llvm::zip_equal(param_inst_ids, llvm_function->args())) {
     auto name_id = SemIR::NameId::Invalid;
-    if (inst_id == function.return_slot_id) {
+    if (inst_id == return_param_id) {
       name_id = SemIR::NameId::ReturnSlot;
       arg.addAttr(
           llvm::Attribute::getWithStructRetType(llvm_context(), return_type));
@@ -324,51 +329,103 @@ auto FileContext::BuildFunctionDefinition(SemIR::FunctionId function_id)
       sem_ir().inst_blocks().GetOrEmpty(function.implicit_param_refs_id);
   auto param_refs = sem_ir().inst_blocks().GetOrEmpty(function.param_refs_id);
   int param_index = 0;
-  if (SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id)
-          .has_return_slot()) {
-    function_lowering.SetLocal(function.return_slot_id,
-                               llvm_function->getArg(param_index));
-    ++param_index;
-  }
+  // The SemIR calling-convention parameters of the function, in order of
+  // runtime index. This is a transitional step toward generating this list
+  // in the check phase, which is why we're using the runtime index order
+  // even though it's less convenient for this usage.
+  llvm::SmallVector<SemIR::InstId> calling_convention_param_ids;
+  // This is an upper bound on the size because `self` and the return slot
+  // are the only runtime parameters that don't appear in the explicit
+  // parameter list.
+  calling_convention_param_ids.reserve(param_refs.size() + 2);
+  bool has_return_slot =
+      SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id)
+          .has_return_slot();
   for (auto param_ref_id :
        llvm::concat<const SemIR::InstId>(implicit_param_refs, param_refs)) {
     auto param_info =
         SemIR::Function::GetParamFromParamRefId(sem_ir(), param_ref_id);
-    if (!param_info.inst.runtime_index.is_valid()) {
-      continue;
+    if (param_info.inst.runtime_index.is_valid()) {
+      calling_convention_param_ids.push_back(param_info.inst_id);
     }
+  }
+  if (has_return_slot) {
+    auto return_slot =
+        sem_ir().insts().GetAs<SemIR::ReturnSlot>(function.return_slot_id);
+    calling_convention_param_ids.push_back(return_slot.storage_id);
+  }
+
+  // TODO: find a way to ensure this code and the function-call lowering use
+  // the same parameter ordering.
 
+  // Lowers the given parameter. Must be called in LLVM calling convention
+  // parameter order.
+  auto lower_param = [&](SemIR::InstId param_id) {
     // Get the value of the parameter from the function argument.
-    auto param_type_id = param_info.inst.type_id;
-    llvm::Value* param_value = llvm::PoisonValue::get(GetType(param_type_id));
-    if (SemIR::ValueRepr::ForType(sem_ir(), param_type_id).kind !=
+    auto param_inst = sem_ir().insts().GetAs<SemIR::AnyParam>(param_id);
+    llvm::Value* param_value =
+        llvm::PoisonValue::get(GetType(param_inst.type_id));
+    if (SemIR::ValueRepr::ForType(sem_ir(), param_inst.type_id).kind !=
         SemIR::ValueRepr::None) {
       param_value = llvm_function->getArg(param_index);
       ++param_index;
     }
-
     // The value of the parameter is the value of the argument.
-    function_lowering.SetLocal(param_info.inst_id, param_value);
-
-    // Match the portion of the pattern corresponding to the parameter against
-    // the parameter value. For now this is always a single name binding,
-    // possibly wrapped in `addr`.
-    //
-    // TODO: Support general patterns here.
-    auto bind_name_id = param_ref_id;
-    auto bind_name = sem_ir().insts().Get(bind_name_id);
-    CARBON_CHECK(bind_name.Is<SemIR::BindName>());
-    function_lowering.SetLocal(bind_name_id, param_value);
+    function_lowering.SetLocal(param_id, param_value);
+  };
+
+  // The subset of calling_convention_param_id that is in sequential order.
+  llvm::ArrayRef<SemIR::InstId> sequential_param_ids =
+      calling_convention_param_ids;
+
+  // The LLVM calling convention has the return slot first rather than last.
+  if (has_return_slot) {
+    lower_param(calling_convention_param_ids.back());
+
+    sequential_param_ids = sequential_param_ids.drop_back();
+  }
+  for (auto param_id : sequential_param_ids) {
+    lower_param(param_id);
   }
 
-  // Lower all blocks.
-  for (auto block_id : body_block_ids) {
+  auto decl_block_id = SemIR::InstBlockId::Invalid;
+  if (function_id == sem_ir().global_ctor_id()) {
+    decl_block_id = SemIR::InstBlockId::Empty;
+  } else {
+    decl_block_id = sem_ir()
+                        .insts()
+                        .GetAs<SemIR::FunctionDecl>(function.latest_decl_id())
+                        .decl_block_id;
+  }
+
+  // Lowers the contents of block_id into the corresponding LLVM block,
+  // creating it if it doesn't already exist.
+  auto lower_block = [&](SemIR::InstBlockId block_id) {
     CARBON_VLOG("Lowering {0}\n", block_id);
     auto* llvm_block = function_lowering.GetBlock(block_id);
     // Keep the LLVM blocks in lexical order.
     llvm_block->moveBefore(llvm_function->end());
     function_lowering.builder().SetInsertPoint(llvm_block);
-    function_lowering.LowerBlock(block_id);
+    function_lowering.LowerBlockContents(block_id);
+  };
+
+  lower_block(decl_block_id);
+
+  // If the decl block is empty, reuse it as the first body block. We don't do
+  // this when the decl block is non-empty so that any branches back to the
+  // first body block don't also re-execute the decl.
+  llvm::BasicBlock* block = function_lowering.builder().GetInsertBlock();
+  if (block->empty() &&
+      function_lowering.TryToReuseBlock(body_block_ids.front(), block)) {
+    // Reuse this block as the first block of the function body.
+  } else {
+    function_lowering.builder().CreateBr(
+        function_lowering.GetBlock(body_block_ids.front()));
+  }
+
+  // Lower all blocks.
+  for (auto block_id : body_block_ids) {
+    lower_block(block_id);
   }
 
   // LLVM requires that the entry block has no predecessors.

+ 1 - 1
toolchain/lower/function_context.cpp

@@ -49,7 +49,7 @@ auto FunctionContext::TryToReuseBlock(SemIR::InstBlockId block_id,
   return true;
 }
 
-auto FunctionContext::LowerBlock(SemIR::InstBlockId block_id) -> void {
+auto FunctionContext::LowerBlockContents(SemIR::InstBlockId block_id) -> void {
   for (auto inst_id : sem_ir().inst_blocks().Get(block_id)) {
     LowerInst(inst_id);
   }

+ 1 - 1
toolchain/lower/function_context.h

@@ -33,7 +33,7 @@ class FunctionContext {
       -> bool;
 
   // Builds LLVM IR for the sequence of instructions in `block_id`.
-  auto LowerBlock(SemIR::InstBlockId block_id) -> void;
+  auto LowerBlockContents(SemIR::InstBlockId block_id) -> void;
 
   // Builds LLVM IR for the specified instruction.
   auto LowerInst(SemIR::InstId inst_id) -> void;

+ 9 - 6
toolchain/lower/handle.cpp

@@ -179,17 +179,20 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
 
 auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
                 SemIR::OutParam /*inst*/) -> void {
-  CARBON_FATAL("Parameters should be lowered by `BuildFunctionDefinition`");
+  // Parameters are lowered by `BuildFunctionDefinition`.
 }
 
 auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
                 SemIR::ValueParam /*inst*/) -> void {
-  CARBON_FATAL("Parameters should be lowered by `BuildFunctionDefinition`");
+  // Parameters are lowered by `BuildFunctionDefinition`.
 }
 
-auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
-                SemIR::ReturnSlot /*inst*/) -> void {
-  CARBON_FATAL("Return slots should be lowered by `BuildFunctionDefinition`");
+auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
+                SemIR::ReturnSlot inst) -> void {
+  if (SemIR::InitRepr::ForType(context.sem_ir(), inst.type_id).kind ==
+      SemIR::InitRepr::InPlace) {
+    context.SetLocal(inst_id, context.GetValue(inst.storage_id));
+  }
 }
 
 auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
@@ -226,7 +229,7 @@ auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
 
 auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
                 SemIR::SpliceBlock inst) -> void {
-  context.LowerBlock(inst.block_id);
+  context.LowerBlockContents(inst.block_id);
   context.SetLocal(inst_id, context.GetValue(inst.result_id));
 }