Răsfoiți Sursa

Implement virtual call dispatch (#5308)

Adds a `virtual_index` to `SemIR::Function` used to determine which
vtable slot
to use when calling the given function.

Then use that to lower the function call to use the vtable and
specifically the
relative vtable ABI to match the vtable entries.
David Blaikie 1 an în urmă
părinte
comite
f45a632d77

+ 4 - 2
toolchain/check/class.cpp

@@ -151,7 +151,7 @@ static auto BuildVtable(Context& context, Parse::NodeId node_id,
     // elements of the top of `vtable_stack`.
     for (auto fn_decl_id : base_vtable_inst_block) {
       auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
-      const auto& fn = context.functions().Get(fn_decl.function_id);
+      auto& fn = context.functions().Get(fn_decl.function_id);
       for (auto override_fn_decl_id : vtable_contents) {
         auto override_fn_decl =
             context.insts().GetAs<SemIR::FunctionDecl>(override_fn_decl_id);
@@ -169,14 +169,16 @@ static auto BuildVtable(Context& context, Parse::NodeId node_id,
           fn_decl_id = override_fn_decl_id;
         }
       }
+      fn.virtual_index = vtable.size();
       vtable.push_back(fn_decl_id);
     }
   }
 
   for (auto inst_id : vtable_contents) {
     auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
-    const auto& fn = context.functions().Get(fn_decl.function_id);
+    auto& fn = context.functions().Get(fn_decl.function_id);
     if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
+      fn.virtual_index = vtable.size();
       vtable.push_back(inst_id);
     }
   }

+ 1 - 1
toolchain/lower/file_context.h

@@ -106,7 +106,6 @@ class FileContext {
     printf_int_format_string_ = printf_int_format_string;
   }
 
- private:
   struct FunctionTypeInfo {
     llvm::FunctionType* type;
     llvm::SmallVector<SemIR::InstId> param_inst_ids;
@@ -121,6 +120,7 @@ class FileContext {
   auto BuildFunctionTypeInfo(const SemIR::Function& function,
                              SemIR::SpecificId specific_id) -> FunctionTypeInfo;
 
+ private:
   // Builds the declaration for the given function, which should then be cached
   // by the caller.
   auto BuildFunctionDecl(SemIR::FunctionId function_id,

+ 7 - 0
toolchain/lower/function_context.h

@@ -80,6 +80,13 @@ class FunctionContext {
     return file_context_->GetOrCreateFunction(function_id, specific_id);
   }
 
+  // Builds LLVM function type information for the specified function.
+  auto BuildFunctionTypeInfo(const SemIR::Function& function,
+                             SemIR::SpecificId specific_id)
+      -> FileContext::FunctionTypeInfo {
+    return file_context_->BuildFunctionTypeInfo(function, specific_id);
+  }
+
   // Returns a lowered type for the given type_id.
   auto GetType(SemIR::TypeId type_id) -> llvm::Type* {
     return file_context_->GetType(type_id);

+ 34 - 4
toolchain/lower/handle_call.cpp

@@ -439,9 +439,6 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
     return;
   }
 
-  auto* callee = context.GetOrCreateFunction(
-      callee_function.function_id, callee_function.resolved_specific_id);
-
   std::vector<llvm::Value*> args;
 
   auto inst_type_id = SemIR::GetTypeOfInstInSpecific(
@@ -461,7 +458,40 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
     }
   }
 
-  context.SetLocal(inst_id, context.builder().CreateCall(callee, args));
+  llvm::CallInst* call;
+  const auto& function =
+      context.sem_ir().functions().Get(callee_function.function_id);
+  if (function.virtual_index != -1) {
+    CARBON_CHECK(!args.empty(),
+                 "Virtual functions must have at least one parameter");
+    auto* ptr_type =
+        llvm::PointerType::get(context.llvm_context(), /*AddressSpace=*/0);
+    // The vtable pointer is always at the start of the object in the Carbon
+    // ABI, so a pointer to the object is a pointer to the vtable pointer - load
+    // that to get a pointer to the vtable.
+    auto* vtable =
+        context.builder().CreateLoad(ptr_type, args.front(), "vtable");
+    auto* i32_type = llvm::IntegerType::getInt32Ty(context.llvm_context());
+    auto function_type_info = context.BuildFunctionTypeInfo(
+        function, callee_function.resolved_specific_id);
+    call = context.builder().CreateCall(
+        function_type_info.type,
+        context.builder().CreateCall(
+            llvm::Intrinsic::getOrInsertDeclaration(
+                &context.llvm_module(), llvm::Intrinsic::load_relative,
+                {i32_type}),
+            {vtable,
+             llvm::ConstantInt::get(
+                 i32_type, static_cast<uint64_t>(function.virtual_index) * 4)}),
+        args);
+  } else {
+    call = context.builder().CreateCall(
+        context.GetOrCreateFunction(callee_function.function_id,
+                                    callee_function.resolved_specific_id),
+        args);
+  }
+
+  context.SetLocal(inst_id, call);
 }
 
 }  // namespace Carbon::Lower

+ 45 - 0
toolchain/lower/testdata/class/virtual.carbon

@@ -75,6 +75,19 @@ fn Use() {
   var v : Derived = {.base = {}};
 }
 
+// --- call.carbon
+
+library "[[@TEST_NAME]]";
+
+base class Base {
+  virtual fn F[self: Self]();
+}
+
+fn Use(b: Base) {
+  b.F();
+}
+
+
 // CHECK:STDOUT: ; ModuleID = 'classes.carbon'
 // CHECK:STDOUT: source_filename = "classes.carbon"
 // CHECK:STDOUT:
@@ -271,3 +284,35 @@ fn Use() {
 // CHECK:STDOUT: !8 = !DILocation(line: 13, column: 21, scope: !4)
 // CHECK:STDOUT: !9 = !DILocation(line: 13, column: 30, scope: !4)
 // CHECK:STDOUT: !10 = !DILocation(line: 12, column: 1, scope: !4)
+// CHECK:STDOUT: ; ModuleID = 'call.carbon'
+// CHECK:STDOUT: source_filename = "call.carbon"
+// CHECK:STDOUT:
+// CHECK:STDOUT: @"_CBase.Main.$vtable" = unnamed_addr constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr @_CF.Base.Main to i64), i64 ptrtoint (ptr @"_CBase.Main.$vtable" to i64)) to i32)]
+// CHECK:STDOUT:
+// CHECK:STDOUT: declare void @_CF.Base.Main(ptr)
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CUse.Main(ptr %b) !dbg !4 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   %F.call.vtable = load ptr, ptr %b, align 8, !dbg !7
+// CHECK:STDOUT:   %F.call = call ptr @llvm.load.relative.i32(ptr %F.call.vtable, i32 0), !dbg !7
+// CHECK:STDOUT:   call void %F.call(ptr %b), !dbg !7
+// CHECK:STDOUT:   ret void, !dbg !8
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: read)
+// CHECK:STDOUT: declare ptr @llvm.load.relative.i32(ptr, i32) #0
+// CHECK:STDOUT:
+// CHECK:STDOUT: attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: read) }
+// CHECK:STDOUT:
+// CHECK:STDOUT: !llvm.module.flags = !{!0, !1}
+// CHECK:STDOUT: !llvm.dbg.cu = !{!2}
+// CHECK:STDOUT:
+// CHECK:STDOUT: !0 = !{i32 7, !"Dwarf Version", i32 5}
+// CHECK:STDOUT: !1 = !{i32 2, !"Debug Info Version", i32 3}
+// CHECK:STDOUT: !2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "carbon", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug)
+// CHECK:STDOUT: !3 = !DIFile(filename: "call.carbon", directory: "")
+// CHECK:STDOUT: !4 = distinct !DISubprogram(name: "Use", linkageName: "_CUse.Main", scope: null, file: !3, line: 8, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !5 = !DISubroutineType(types: !6)
+// CHECK:STDOUT: !6 = !{}
+// CHECK:STDOUT: !7 = !DILocation(line: 9, column: 3, scope: !4)
+// CHECK:STDOUT: !8 = !DILocation(line: 8, column: 1, scope: !4)

+ 4 - 0
toolchain/sem_ir/function.h

@@ -44,6 +44,10 @@ struct FunctionFields {
   // this function.
   VirtualModifier virtual_modifier;
 
+  // The index of the vtable slot for this virtual function. -1 if the function
+  // is not virtual (ie: (virtual_modifier == None) == (virtual_index == -1)).
+  int32_t virtual_index = -1;
+
   // The implicit self parameter, if any, in implicit_param_patterns_id from
   // EntityWithParamsBase.
   InstId self_param_id = SemIR::InstId::None;