Przeglądaj źródła

Support C++ calling Carbon functions with ref parameters (#7107)

When creating the C++ thunk, make the parameters references if the
corresponding callee parameters are `ref`s.

When creating the Carbon thunk, tag the call arguments as `ref` if the
corresponding callee parameters are `ref`s.
Nicholas Bishop 1 tydzień temu
rodzic
commit
3f63cf4b10

+ 28 - 12
toolchain/check/cpp/export.cpp

@@ -168,6 +168,14 @@ auto ExportClassToCpp(Context& context, SemIR::LocId loc_id,
 
 namespace {
 struct FunctionInfo {
+  struct Param {
+    // Type of the parameter's scrutinee.
+    SemIR::TypeId type_id;
+
+    // Whether this is a `ref` param.
+    bool is_ref;
+  };
+
   explicit FunctionInfo(Context& context, SemIR::FunctionId function_id,
                         const SemIR::Function& function,
                         clang::DeclContext* decl_context)
@@ -188,15 +196,17 @@ struct FunctionInfo {
       self_type_id = scrutinee_type_id;
     }
 
-    // Get the function's explicit parameter types.
+    // Get the function's explicit parameters.
     function_params =
         function_params.drop_front(function.call_param_ranges.implicit_size());
     function_params =
         function_params.drop_back(function.call_param_ranges.return_size());
     for (auto param_inst_id : function_params) {
-      auto scrutinee_type_id = ExtractScrutineeType(
-          context.sem_ir(), context.insts().Get(param_inst_id).type_id());
-      param_type_ids.push_back(scrutinee_type_id);
+      explicit_params.push_back(
+          {.type_id = ExtractScrutineeType(
+               context.sem_ir(), context.insts().Get(param_inst_id).type_id()),
+           .is_ref =
+               context.insts().Is<SemIR::RefParamPattern>(param_inst_id)});
     }
   }
 
@@ -220,9 +230,9 @@ struct FunctionInfo {
   // `CXXRecordDecl`.
   clang::DeclContext* decl_context;
 
-  // Types of the function's explicit parameters (excludes implicit
-  // parameters and return parameters).
-  llvm::SmallVector<SemIR::TypeId> param_type_ids;
+  // For each of the function's explicit parameters, the scrutinee type
+  // and whether the parameter is a reference.
+  llvm::SmallVector<Param> explicit_params;
 
   // Type of the function's `self` parameter, or `None` if the function
   // is not a method.
@@ -256,8 +266,8 @@ static auto BuildCppFunctionDeclForCarbonFn(Context& context,
     cpp_type = context.ast_context().getLValueReferenceType(cpp_type);
     cpp_param_types.push_back(cpp_type);
   }
-  for (auto param_type_id : callee.param_type_ids) {
-    auto cpp_type = MapToCppType(context, param_type_id);
+  for (auto param : callee.explicit_params) {
+    auto cpp_type = MapToCppType(context, param.type_id);
     if (cpp_type.isNull()) {
       context.TODO(loc_id, "failed to map Carbon type to C++");
       return nullptr;
@@ -471,12 +481,15 @@ static auto BuildCppToCarbonThunk(Context& context, SemIR::LocId loc_id,
   auto& thunk_ident = context.ast_context().Idents.get(thunk_name);
 
   llvm::SmallVector<clang::QualType> param_types;
-  for (auto type_id : target.param_type_ids) {
-    auto cpp_type = MapToCppType(context, type_id);
+  for (auto param : target.explicit_params) {
+    auto cpp_type = MapToCppType(context, param.type_id);
     if (cpp_type.isNull()) {
       context.TODO(loc_id, "failed to map C++ type to Carbon");
       return nullptr;
     }
+    if (param.is_ref) {
+      cpp_type = context.ast_context().getLValueReferenceType(cpp_type);
+    }
     param_types.push_back(cpp_type);
   }
 
@@ -513,7 +526,10 @@ static auto BuildCarbonToCarbonThunk(Context& context, SemIR::LocId loc_id,
   // Get the thunk's parameters. These match the callee parameters, with
   // the addition of an output parameter for the callee's return value
   // (if it has one).
-  llvm::SmallVector<SemIR::TypeId> thunk_param_type_ids(target.param_type_ids);
+  llvm::SmallVector<SemIR::TypeId> thunk_param_type_ids;
+  for (const auto& param : target.explicit_params) {
+    thunk_param_type_ids.push_back(param.type_id);
+  }
   auto callee_return_type_id =
       target.function.GetDeclaredReturnType(context.sem_ir());
   if (callee_return_type_id != SemIR::TypeId::None) {

+ 84 - 0
toolchain/check/testdata/interop/cpp/class/export/method.carbon

@@ -0,0 +1,84 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/int.carbon
+//
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/interop/cpp/class/export/method.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/interop/cpp/class/export/method.carbon
+
+// --- static_method.carbon
+library "[[@TEST_NAME]]";
+import Cpp;
+
+class C {
+  fn M();
+}
+
+inline Cpp '''
+void F() {
+  Carbon::C::M();
+}
+''';
+
+// --- method_lvalue.carbon
+library "[[@TEST_NAME]]";
+import Cpp;
+
+class C {
+  fn M[self: Self]();
+}
+
+inline Cpp '''
+void F() {
+  Carbon::C c;
+  c.M();
+}
+''';
+
+// --- method_rvalue.carbon
+library "[[@TEST_NAME]]";
+import Cpp;
+
+class C {
+  fn M[self: Self]();
+}
+
+inline Cpp '''
+void F() {
+  Carbon::C().M();
+}
+''';
+
+// --- ref_method_lvalue.carbon
+library "[[@TEST_NAME]]";
+import Cpp;
+
+class C {
+  fn M[ref self: Self]();
+}
+
+inline Cpp '''
+void F() {
+  Carbon::C c;
+  c.M();
+}
+''';
+
+// --- todo_ref_method_rvalue.carbon
+library "[[@TEST_NAME]]";
+import Cpp;
+
+class C {
+  fn M[ref self: Self]();
+}
+
+inline Cpp '''
+void F() {
+  // TODO: this should be disallowed.
+  Carbon::C().M();
+}
+''';

+ 15 - 15
toolchain/check/testdata/interop/cpp/function/export/function.carbon

@@ -20,10 +20,6 @@ fn F3(_: i32) {}
 fn HasGenericArg(T:! type, a: T) { a; }
 fn HasDeducedArg[T:! type](a: T) { a; }
 
-class C {
-  fn Method[self: Self]() { self; }
-}
-
 // --- function.carbon
 
 library "[[@TEST_NAME]]";
@@ -57,6 +53,21 @@ void G() {
 }
 ''';
 
+// --- ref_arg.carbon
+
+library "[[@TEST_NAME]]";
+
+import Cpp;
+
+fn F(ref n: i32);
+
+inline Cpp '''
+void G() {
+  int n = 0;
+  Carbon::F(n);
+}
+''';
+
 // --- using.carbon
 
 library "[[@TEST_NAME]]";
@@ -113,14 +124,3 @@ void G() {
   Carbon::Other::HasDeducedArg(123);
 }
 ''';
-
-// --- method.carbon
-
-library "[[@TEST_NAME]]";
-
-import Other;
-import Cpp inline '''
-void G() {
-  Carbon::Other::C().Method();
-}
-''';

+ 24 - 3
toolchain/check/thunk.cpp

@@ -399,12 +399,33 @@ auto BuildThunkDefinitionForExport(Context& context,
     param_pattern_ids =
         context.inst_blocks().Get(thunk_function.param_patterns_id);
   }
-  auto call_param_ids =
-      context.inst_blocks().Get(thunk_function.call_params_id);
+  llvm::SmallVector<SemIR::InstId> call_param_ids(
+      context.inst_blocks().Get(thunk_function.call_params_id));
 
   if (thunk_has_return_param) {
     param_pattern_ids = param_pattern_ids.drop_back();
-    call_param_ids = call_param_ids.drop_back();
+    call_param_ids.pop_back();
+  }
+
+  auto callee_param_ids =
+      context.inst_blocks().Get(callee_function.call_param_patterns_id);
+
+  // If any explicit parameters of the callee are `ref` parameters,
+  // modify the corresponding call arguments to be `ref` tagged.
+  for (auto index = thunk_function.call_param_ranges.explicit_begin().index;
+       index < thunk_function.call_param_ranges.explicit_end().index; index++) {
+    if (context.insts().Is<SemIR::RefParamPattern>(callee_param_ids[index])) {
+      auto& call_param_id = call_param_ids[index];
+      auto type = context.insts().Get(call_param_id).type_id();
+      SemIR::LocId loc_id(thunk_id);
+      call_param_id =
+          AddInst(context, SemIR::LocIdAndInst::RuntimeVerified(
+                               context.sem_ir(), SemIR::LocId(call_param_id),
+                               SemIR::RefTagExpr{
+                                   .type_id = type,
+                                   .expr_id = call_param_id,
+                               }));
+    }
   }
 
   auto call_id = BuildThunkCall(context, thunk_function_id, callee_id,