Просмотр исходного кода

Basic lowering generic function definitions. (#5016)

Resolve the specific type for the callee, to lower the proper specific
function called.
Alina Sbirlea 1 год назад
Родитель
Сommit
4e21c0c1fc

+ 17 - 4
toolchain/lower/file_context.cpp

@@ -219,6 +219,15 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
   // TODO: Consider tracking whether the function has been used, and only
   // lowering it if it's needed.
 
+  // TODO nit: add is_symbolic() to type_id to forward to
+  // type_id.AsConstantId().is_symbolic(). Update call below too.
+  auto get_llvm_type = [&](SemIR::TypeId type_id) -> llvm::Type* {
+    if (!type_id.has_value()) {
+      return nullptr;
+    }
+    return GetType(SemIR::GetTypeInSpecific(sem_ir(), specific_id, type_id));
+  };
+
   const auto return_info =
       SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id);
   CARBON_CHECK(return_info.is_valid(), "Should not lower invalid functions.");
@@ -229,8 +238,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
   auto param_patterns =
       sem_ir().inst_blocks().GetOrEmpty(function.param_patterns_id);
 
-  auto* return_type =
-      return_info.type_id.has_value() ? GetType(return_info.type_id) : nullptr;
+  auto* return_type = get_llvm_type(return_info.type_id);
 
   llvm::SmallVector<llvm::Type*> param_types;
   // TODO: Consider either storing `param_inst_ids` somewhere so that we can
@@ -259,6 +267,10 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
     }
     auto param_type_id =
         SemIR::GetTypeInSpecific(sem_ir(), specific_id, param_pattern.type_id);
+    CARBON_CHECK(
+        !param_type_id.AsConstantId().is_symbolic(),
+        "Found symbolic type id after resolution when lowering type {0}.",
+        param_pattern.type_id);
     switch (auto value_rep = SemIR::ValueRepr::ForType(sem_ir(), param_type_id);
             value_rep.kind) {
       case SemIR::ValueRepr::Unknown:
@@ -268,7 +280,8 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
       case SemIR::ValueRepr::Copy:
       case SemIR::ValueRepr::Custom:
       case SemIR::ValueRepr::Pointer:
-        param_types.push_back(GetType(value_rep.type_id));
+        auto* param_types_to_add = get_llvm_type(value_rep.type_id);
+        param_types.push_back(param_types_to_add);
         param_inst_ids.push_back(param_pattern_id);
         break;
     }
@@ -349,7 +362,7 @@ auto FileContext::BuildFunctionBody(SemIR::FunctionId function_id,
   CARBON_DCHECK(!body_block_ids.empty(),
                 "No function body blocks found during lowering.");
 
-  FunctionContext function_lowering(*this, llvm_function,
+  FunctionContext function_lowering(*this, llvm_function, specific_id,
                                     BuildDISubprogram(function, llvm_function),
                                     vlog_stream_);
 

+ 2 - 0
toolchain/lower/function_context.cpp

@@ -12,10 +12,12 @@ namespace Carbon::Lower {
 
 FunctionContext::FunctionContext(FileContext& file_context,
                                  llvm::Function* function,
+                                 SemIR::SpecificId specific_id,
                                  llvm::DISubprogram* di_subprogram,
                                  llvm::raw_ostream* vlog_stream)
     : file_context_(&file_context),
       function_(function),
+      specific_id_(specific_id),
       builder_(file_context.llvm_context(), llvm::ConstantFolder(),
                Inserter(file_context.inst_namer())),
       di_subprogram_(di_subprogram),

+ 7 - 0
toolchain/lower/function_context.h

@@ -19,6 +19,7 @@ namespace Carbon::Lower {
 class FunctionContext {
  public:
   explicit FunctionContext(FileContext& file_context, llvm::Function* function,
+                           SemIR::SpecificId specific_id,
                            llvm::DISubprogram* di_subprogram,
                            llvm::raw_ostream* vlog_stream);
 
@@ -45,6 +46,8 @@ class FunctionContext {
 
   // Returns a value for the given instruction.
   auto GetValue(SemIR::InstId inst_id) -> llvm::Value* {
+    // TODO: if(specific_id_.has_value()) may need to update inst_id first.
+
     // All builtins are types, with the same empty lowered value.
     if (SemIR::IsSingletonInstId(inst_id)) {
       return GetTypeAsValue();
@@ -130,6 +133,7 @@ class FunctionContext {
   }
   auto llvm_module() -> llvm::Module& { return file_context_->llvm_module(); }
   auto llvm_function() -> llvm::Function& { return *function_; }
+  auto specific_id() -> SemIR::SpecificId { return specific_id_; }
   auto builder() -> llvm::IRBuilderBase& { return builder_; }
   auto sem_ir() -> const SemIR::File& { return file_context_->sem_ir(); }
 
@@ -174,6 +178,9 @@ class FunctionContext {
   // The IR function we're generating.
   llvm::Function* function_;
 
+  // The specific id, if the function is a specific.
+  SemIR::SpecificId specific_id_;
+
   // Builder for creating code in this function. The insertion point is held at
   // the location of the current SemIR instruction.
   llvm::IRBuilder<llvm::ConstantFolder, Inserter> builder_;

+ 6 - 3
toolchain/lower/handle_call.cpp

@@ -424,8 +424,8 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
   llvm::ArrayRef<SemIR::InstId> arg_ids =
       context.sem_ir().inst_blocks().Get(inst.args_id);
 
-  auto callee_function =
-      SemIR::GetCalleeFunction(context.sem_ir(), inst.callee_id);
+  auto callee_function = SemIR::GetCalleeFunction(
+      context.sem_ir(), inst.callee_id, context.specific_id());
   CARBON_CHECK(callee_function.function_id.has_value());
 
   if (auto builtin_kind = context.sem_ir()
@@ -442,7 +442,10 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
 
   std::vector<llvm::Value*> args;
 
-  if (SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst.type_id)
+  auto inst_type_id = SemIR::GetTypeInSpecific(
+      context.sem_ir(), callee_function.resolved_specific_id, inst.type_id);
+
+  if (SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst_type_id)
           .has_return_slot()) {
     args.push_back(context.GetValue(arg_ids.back()));
     arg_ids = arg_ids.drop_back();

+ 30 - 12
toolchain/lower/testdata/function/generic/call_basic.carbon

@@ -15,8 +15,8 @@ fn H[T:! type](x: T) {
 }
 
 fn G[T:! type](x: T) -> T {
-  // TODO: the call below is crashing because proper type resolution to
-  // use the G specific, not the G generic is not done yet.
+  H(x);
+  // TODO: Call crashes, see TODO in FunctionContext::GetValue()
   // H(T);
   return x;
 }
@@ -72,17 +72,29 @@ fn M() {
 // CHECK:STDOUT:
 // CHECK:STDOUT: define i32 @_CG.Main.b88d1103f417c6d4(i32 %x) !dbg !22 {
 // CHECK:STDOUT: entry:
-// CHECK:STDOUT:   ret i32 %x, !dbg !23
+// CHECK:STDOUT:   call void @_CH.Main.b88d1103f417c6d4(i32 %x), !dbg !23
+// CHECK:STDOUT:   ret i32 %x, !dbg !24
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: define void @_CF.Main.66be507887ceee78(double %x) !dbg !24 {
+// CHECK:STDOUT: define void @_CF.Main.66be507887ceee78(double %x) !dbg !25 {
 // CHECK:STDOUT: entry:
-// CHECK:STDOUT:   ret void, !dbg !25
+// CHECK:STDOUT:   ret void, !dbg !26
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: define double @_CG.Main.66be507887ceee78(double %x) !dbg !26 {
+// CHECK:STDOUT: define double @_CG.Main.66be507887ceee78(double %x) !dbg !27 {
 // CHECK:STDOUT: entry:
-// CHECK:STDOUT:   ret double %x, !dbg !27
+// CHECK:STDOUT:   call void @_CH.Main.66be507887ceee78(double %x), !dbg !28
+// CHECK:STDOUT:   ret double %x, !dbg !29
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CH.Main.b88d1103f417c6d4(i32 %x) !dbg !30 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !31
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CH.Main.66be507887ceee78(double %x) !dbg !32 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !33
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: ; uselistorder directives
@@ -116,8 +128,14 @@ fn M() {
 // CHECK:STDOUT: !20 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.b88d1103f417c6d4", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2)
 // CHECK:STDOUT: !21 = !DILocation(line: 11, column: 1, scope: !20)
 // CHECK:STDOUT: !22 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.b88d1103f417c6d4", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
-// CHECK:STDOUT: !23 = !DILocation(line: 21, column: 3, scope: !22)
-// CHECK:STDOUT: !24 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.66be507887ceee78", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2)
-// CHECK:STDOUT: !25 = !DILocation(line: 11, column: 1, scope: !24)
-// CHECK:STDOUT: !26 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.66be507887ceee78", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
-// CHECK:STDOUT: !27 = !DILocation(line: 21, column: 3, scope: !26)
+// CHECK:STDOUT: !23 = !DILocation(line: 18, column: 3, scope: !22)
+// CHECK:STDOUT: !24 = !DILocation(line: 21, column: 3, scope: !22)
+// CHECK:STDOUT: !25 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.66be507887ceee78", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !26 = !DILocation(line: 11, column: 1, scope: !25)
+// CHECK:STDOUT: !27 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.66be507887ceee78", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !28 = !DILocation(line: 18, column: 3, scope: !27)
+// CHECK:STDOUT: !29 = !DILocation(line: 21, column: 3, scope: !27)
+// CHECK:STDOUT: !30 = distinct !DISubprogram(name: "H", linkageName: "_CH.Main.b88d1103f417c6d4", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !31 = !DILocation(line: 14, column: 1, scope: !30)
+// CHECK:STDOUT: !32 = distinct !DISubprogram(name: "H", linkageName: "_CH.Main.66be507887ceee78", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !33 = !DILocation(line: 14, column: 1, scope: !32)

+ 6 - 1
toolchain/sem_ir/function.cpp

@@ -10,13 +10,18 @@
 
 namespace Carbon::SemIR {
 
-auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction {
+auto GetCalleeFunction(const File& sem_ir, InstId callee_id,
+                       SpecificId specific_id) -> CalleeFunction {
   CalleeFunction result = {.function_id = FunctionId::None,
                            .enclosing_specific_id = SpecificId::None,
                            .resolved_specific_id = SpecificId::None,
                            .self_type_id = InstId::None,
                            .self_id = InstId::None,
                            .is_error = false};
+  if (specific_id.has_value()) {
+    callee_id = sem_ir.constant_values().GetInstIdIfValid(
+        GetConstantValueInSpecific(sem_ir, specific_id, callee_id));
+  }
 
   if (auto specific_function =
           sem_ir.insts().TryGetAs<SpecificFunction>(callee_id)) {

+ 3 - 1
toolchain/sem_ir/function.h

@@ -115,7 +115,9 @@ struct CalleeFunction {
 };
 
 // Returns information for the function corresponding to callee_id.
-auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction;
+auto GetCalleeFunction(const File& sem_ir, InstId callee_id,
+                       SpecificId specific_id = SpecificId::None)
+    -> CalleeFunction;
 
 }  // namespace Carbon::SemIR