Explorar el Código

Limited support for indirect import of template specializations. (#7121)

When a class template specialization is indirectly imported, map the
template arguments into the importing File and find the corresponding
local class template specialization. This is a short-term fix:
eventually we should import the C++ AST from the imported file into the
C++ AST for the current file, but we're not ready to do that yet.

So far we only support very simple template arguments: just classes and
builtin types. Unfortunately we can't just map the C++ template
arguments to Carbon types, then import the Carbon types, then map them
back, because mapping from C++ template arguments to Carbon types would
require a `Check::Context` for the imported code, which we don't have.
As this is only a temporary workaround, directly mapping from one C++
AST to another will do for now.

Assisted-by: Gemini via Antigravity
Richard Smith hace 4 días
padre
commit
73adc479e3

+ 92 - 7
toolchain/check/cpp/import.cpp

@@ -150,22 +150,62 @@ auto ImportCpp(Context& context,
   }
 }
 
+// NOLINTNEXTLINE(misc-no-recursion)
+static auto FindCorrespondingType(Context& context, SemIR::LocId loc_id,
+                                  clang::QualType type) -> clang::QualType;
+
+// Given a class template specialization in some C++ AST which is *not* expected
+// to be `context`, find the corresponding declaration in `context`, if there is
+// one.
+// NOLINTNEXTLINE(misc-no-recursion)
+static auto FindCorrespondingTemplateSpecialization(
+    Context& context, SemIR::LocId loc_id,
+    const clang::ClassTemplateSpecializationDecl* source_spec,
+    clang::ClassTemplateDecl* target_template) -> clang::Decl* {
+  const auto& args = source_spec->getTemplateArgs();
+  auto loc = GetCppLocation(context, loc_id);
+  clang::TemplateArgumentListInfo arg_list(loc, loc);
+  for (unsigned i = 0; i < args.size(); ++i) {
+    const auto& arg = args[i];
+    if (arg.getKind() == clang::TemplateArgument::Type) {
+      auto type = FindCorrespondingType(context, loc_id, arg.getAsType());
+      if (type.isNull()) {
+        return nullptr;
+      }
+      arg_list.addArgument(clang::TemplateArgumentLoc(
+          clang::TemplateArgument(type),
+          context.ast_context().getTrivialTypeSourceInfo(type, loc)));
+    } else {
+      return nullptr;
+    }
+  }
+
+  clang::TemplateName template_name(target_template);
+  auto clang_type = context.clang_sema().CheckTemplateIdType(
+      clang::ElaboratedTypeKeyword::None, template_name, loc, arg_list,
+      /*Scope=*/nullptr, /*ForNestedNameSpecifier=*/false);
+  if (!clang_type.isNull()) {
+    return clang_type->getAsCXXRecordDecl();
+  }
+  return nullptr;
+}
+
 // Given a declaration in some C++ AST which is *not* expected to be `context`,
 // find the corresponding declaration in `context`, if there is one.
 // TODO: Make this non-recursive, or remove it once we support importing C++
 // ASTs for cross file imports.
 // NOLINTNEXTLINE(misc-no-recursion)
-static auto FindCorrespondingDecl(clang::ASTContext& context,
+static auto FindCorrespondingDecl(Context& context, SemIR::LocId loc_id,
                                   const clang::Decl* decl) -> clang::Decl* {
   if (const auto* named_decl = dyn_cast<clang::NamedDecl>(decl)) {
     auto* parent = dyn_cast_or_null<clang::DeclContext>(FindCorrespondingDecl(
-        context, cast<clang::Decl>(named_decl->getDeclContext())));
+        context, loc_id, cast<clang::Decl>(named_decl->getDeclContext())));
     if (!parent) {
       return nullptr;
     }
     clang::DeclarationName name;
     if (auto* identifier = named_decl->getDeclName().getAsIdentifierInfo()) {
-      name = &context.Idents.get(identifier->getName());
+      name = &context.ast_context().Idents.get(identifier->getName());
     } else {
       // TODO: Handle more name kinds.
       return nullptr;
@@ -174,20 +214,65 @@ static auto FindCorrespondingDecl(clang::ASTContext& context,
     // TODO: If there are multiple results, try to pick the right one.
     if (!decls.isSingleResult() ||
         decls.front()->getKind() != named_decl->getKind()) {
-      // TODO: If we were looking for a non-template and found a template, try
-      // to form a matching template specialization.
+      if (const auto* source_spec =
+              dyn_cast<clang::ClassTemplateSpecializationDecl>(named_decl)) {
+        if (auto* target_template =
+                dyn_cast<clang::ClassTemplateDecl>(decls.front())) {
+          if (auto* result = FindCorrespondingTemplateSpecialization(
+                  context, loc_id, source_spec, target_template)) {
+            return result;
+          }
+        }
+      }
       return nullptr;
     }
     return decls.front();
   }
 
   if (isa<clang::TranslationUnitDecl>(decl)) {
-    return context.getTranslationUnitDecl();
+    return context.ast_context().getTranslationUnitDecl();
   }
 
   return nullptr;
 }
 
+// Given a type in some C++ AST which is *not* expected to be `context`,
+// find the corresponding type in `context`, if there is one.
+// NOLINTNEXTLINE(misc-no-recursion)
+static auto FindCorrespondingType(Context& context, SemIR::LocId loc_id,
+                                  clang::QualType type) -> clang::QualType {
+  if (type.isNull()) {
+    return clang::QualType();
+  }
+
+  if (const auto* builtin = type->getAs<clang::BuiltinType>()) {
+    switch (builtin->getKind()) {
+#define BUILTIN_TYPE(Id, SingletonId) \
+  case clang::BuiltinType::Id:        \
+    return context.ast_context().SingletonId;
+#include "clang/AST/BuiltinTypes.def"
+#undef BUILTIN_TYPE
+      default:
+        return clang::QualType();
+    }
+  }
+
+  if (const auto* record = type->getAs<clang::RecordType>()) {
+    const auto* decl = record->getDecl();
+    auto* corresponding_decl = FindCorrespondingDecl(context, loc_id, decl);
+    if (!corresponding_decl) {
+      return clang::QualType();
+    }
+    if (const auto* tag_decl = dyn_cast<clang::TagDecl>(corresponding_decl)) {
+      return context.ast_context().getTypeDeclType(
+          cast<clang::TypeDecl>(tag_decl));
+    }
+    return clang::QualType();
+  }
+
+  return clang::QualType();
+}
+
 auto ImportCppDeclFromFile(Context& context, SemIR::LocId loc_id,
                            const SemIR::File& file,
                            SemIR::ClangDeclId clang_decl_id)
@@ -195,7 +280,7 @@ auto ImportCppDeclFromFile(Context& context, SemIR::LocId loc_id,
   CARBON_CHECK(clang_decl_id.has_value());
   auto key = file.clang_decls().Get(clang_decl_id).key;
   const auto* decl = key.decl;
-  auto* corresponding = FindCorrespondingDecl(context.ast_context(), decl);
+  auto* corresponding = FindCorrespondingDecl(context, loc_id, decl);
   if (!corresponding) {
     // TODO: This needs a proper diagnostic.
     context.TODO(

+ 105 - 12
toolchain/check/testdata/interop/cpp/basics/import/indirect.carbon

@@ -23,6 +23,10 @@ namespace A {
   struct Y {
     T y;
   };
+
+  struct Z {
+    int z;
+  };
 }
 
 // --- direct_import.carbon
@@ -35,6 +39,8 @@ fn F() -> Cpp.A.X { return 1; }
 
 fn G() -> Cpp.A.Y(i32);
 
+fn H() -> Cpp.A.Y(Cpp.A.Z);
+
 alias AX = Cpp.A.X;
 
 // --- indirect_import_via_function.carbon
@@ -108,17 +114,11 @@ fn Use() -> (i32, i32) {
   //@dump-sem-ir-end
 }
 
-// --- fail_todo_import_template.carbon
+// --- import_template.carbon
 
 library "[[@TEST_NAME]]";
 
-import Cpp;
-// CHECK:STDERR: fail_todo_import_template.carbon:[[@LINE+6]]:1: in import [InImport]
-// CHECK:STDERR: direct_import.carbon:4:10: in file included here [InCppInclude]
-// CHECK:STDERR: ./shared.h:10:10: error: semantics TODO: `use of imported C++ declaration with no corresponding local import` [SemanticsTodo]
-// CHECK:STDERR:   struct Y {
-// CHECK:STDERR:          ^
-// CHECK:STDERR:
+import Cpp library "shared.h";
 import library "direct_import";
 
 fn Use() -> i32 {
@@ -127,6 +127,12 @@ fn Use() -> i32 {
   //@dump-sem-ir-end
 }
 
+fn UseZ() -> Cpp.A.Z {
+  //@dump-sem-ir-begin
+  return H().y;
+  //@dump-sem-ir-end
+}
+
 // CHECK:STDOUT: --- indirect_import_via_function.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {
@@ -326,24 +332,111 @@ fn Use() -> i32 {
 // CHECK:STDOUT:   return %.loc10_24 to %return.param
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: --- fail_todo_import_template.carbon
+// CHECK:STDOUT: --- import_template.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {
 // CHECK:STDOUT:   %int_32: Core.IntLiteral = int_value 32 [concrete]
+// CHECK:STDOUT:   %empty_tuple.type: type = tuple_type () [concrete]
+// CHECK:STDOUT:   %N: Core.IntLiteral = symbolic_binding N, 0 [symbolic]
 // CHECK:STDOUT:   %i32: type = class_type @Int, @Int(%int_32) [concrete]
 // CHECK:STDOUT:   %G.type: type = fn_type @G [concrete]
 // CHECK:STDOUT:   %G: %G.type = struct_value () [concrete]
+// CHECK:STDOUT:   %Y.b7f625.1: type = class_type @Y.1 [concrete]
+// CHECK:STDOUT:   %Y.elem.080: type = unbound_element_type %Y.b7f625.1, %i32 [concrete]
+// CHECK:STDOUT:   %Copy.type: type = facet_type <@Copy> [concrete]
+// CHECK:STDOUT:   %Int.as.Copy.impl.Op.type.824: type = fn_type @Int.as.Copy.impl.Op, @Int.as.Copy.impl(%N) [symbolic]
+// CHECK:STDOUT:   %Int.as.Copy.impl.Op.9b9: %Int.as.Copy.impl.Op.type.824 = struct_value () [symbolic]
+// CHECK:STDOUT:   %Copy.impl_witness.f17: <witness> = impl_witness imports.%Copy.impl_witness_table.e76, @Int.as.Copy.impl(%int_32) [concrete]
+// CHECK:STDOUT:   %Int.as.Copy.impl.Op.type.546: type = fn_type @Int.as.Copy.impl.Op, @Int.as.Copy.impl(%int_32) [concrete]
+// CHECK:STDOUT:   %Int.as.Copy.impl.Op.664: %Int.as.Copy.impl.Op.type.546 = struct_value () [concrete]
+// CHECK:STDOUT:   %Copy.facet.de4: %Copy.type = facet_value %i32, (%Copy.impl_witness.f17) [concrete]
+// CHECK:STDOUT:   %Copy.WithSelf.Op.type.081: type = fn_type @Copy.WithSelf.Op, @Copy.WithSelf(%Copy.facet.de4) [concrete]
+// CHECK:STDOUT:   %.8e2: type = fn_type_with_self_type %Copy.WithSelf.Op.type.081, %Copy.facet.de4 [concrete]
+// CHECK:STDOUT:   %Int.as.Copy.impl.Op.specific_fn: <specific function> = specific_function %Int.as.Copy.impl.Op.664, @Int.as.Copy.impl.Op(%int_32) [concrete]
+// CHECK:STDOUT:   %Y.cpp_destructor.type.edb0f6.1: type = fn_type @Y.cpp_destructor.1 [concrete]
+// CHECK:STDOUT:   %Y.cpp_destructor.4b6d21.1: %Y.cpp_destructor.type.edb0f6.1 = struct_value () [concrete]
+// CHECK:STDOUT:   %Z: type = class_type @Z [concrete]
+// CHECK:STDOUT:   %H.type: type = fn_type @H [concrete]
+// CHECK:STDOUT:   %H: %H.type = struct_value () [concrete]
+// CHECK:STDOUT:   %Y.b7f625.2: type = class_type @Y.2 [concrete]
+// CHECK:STDOUT:   %Y.elem.f4e: type = unbound_element_type %Y.b7f625.2, %Z [concrete]
+// CHECK:STDOUT:   %Z.Z.type: type = fn_type @Z.Z [concrete]
+// CHECK:STDOUT:   %Z.Z: %Z.Z.type = struct_value () [concrete]
+// CHECK:STDOUT:   %const.bf3: type = const_type %Z [concrete]
+// CHECK:STDOUT:   %ptr.061: type = ptr_type %const.bf3 [concrete]
+// CHECK:STDOUT:   %ptr.81c: type = ptr_type %Z [concrete]
+// CHECK:STDOUT:   %Z__carbon_thunk.type: type = fn_type @Z__carbon_thunk [concrete]
+// CHECK:STDOUT:   %Z__carbon_thunk: %Z__carbon_thunk.type = struct_value () [concrete]
+// CHECK:STDOUT:   %Z.Op.type: type = fn_type @Z.Op [concrete]
+// CHECK:STDOUT:   %Z.Op: %Z.Op.type = struct_value () [concrete]
+// CHECK:STDOUT:   %custom_witness.8e5: <witness> = custom_witness (%Z.Op), @Copy [concrete]
+// CHECK:STDOUT:   %Copy.facet.a4a: %Copy.type = facet_value %Z, (%custom_witness.8e5) [concrete]
+// CHECK:STDOUT:   %Copy.WithSelf.Op.type.465: type = fn_type @Copy.WithSelf.Op, @Copy.WithSelf(%Copy.facet.a4a) [concrete]
+// CHECK:STDOUT:   %.9a1: type = fn_type_with_self_type %Copy.WithSelf.Op.type.465, %Copy.facet.a4a [concrete]
+// CHECK:STDOUT:   %Y.cpp_destructor.type.edb0f6.2: type = fn_type @Y.cpp_destructor.2 [concrete]
+// CHECK:STDOUT:   %Y.cpp_destructor.4b6d21.2: %Y.cpp_destructor.type.edb0f6.2 = struct_value () [concrete]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: imports {
 // CHECK:STDOUT:   %Main.G: %G.type = import_ref Main//direct_import, G, loaded [concrete = constants.%G]
+// CHECK:STDOUT:   %Main.H: %H.type = import_ref Main//direct_import, H, loaded [concrete = constants.%H]
+// CHECK:STDOUT:   %Core.import_ref.18d: @Int.as.Copy.impl.%Int.as.Copy.impl.Op.type (%Int.as.Copy.impl.Op.type.824) = import_ref Core//prelude/parts/int, loc{{\d+_\d+}}, loaded [symbolic = @Int.as.Copy.impl.%Int.as.Copy.impl.Op (constants.%Int.as.Copy.impl.Op.9b9)]
+// CHECK:STDOUT:   %Copy.impl_witness_table.e76 = impl_witness_table (%Core.import_ref.18d), @Int.as.Copy.impl [concrete]
+// CHECK:STDOUT:   %Z.Z.decl: %Z.Z.type = fn_decl @Z.Z [concrete = constants.%Z.Z] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %Z__carbon_thunk.decl: %Z__carbon_thunk.type = fn_decl @Z__carbon_thunk [concrete = constants.%Z__carbon_thunk] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: fn @Use() -> out %return.param: %i32 {
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %G.ref: %G.type = name_ref G, imports.%Main.G [concrete = constants.%G]
-// CHECK:STDOUT:   %G.call: <error> = call %G.ref()
-// CHECK:STDOUT:   %y.ref: <error> = name_ref y, <error> [concrete = <error>]
-// CHECK:STDOUT:   return <error>
+// CHECK:STDOUT:   %.loc9_12.1: ref %Y.b7f625.1 = temporary_storage
+// CHECK:STDOUT:   %G.call: init %Y.b7f625.1 to %.loc9_12.1 = call %G.ref()
+// CHECK:STDOUT:   %.loc9_12.2: ref %Y.b7f625.1 = temporary %.loc9_12.1, %G.call
+// CHECK:STDOUT:   %y.ref: %Y.elem.080 = name_ref y, @Y.1.%.1 [concrete = @Y.1.%.1]
+// CHECK:STDOUT:   %.loc9_13.1: ref %i32 = class_element_access %.loc9_12.2, element0
+// CHECK:STDOUT:   %.loc9_13.2: %i32 = acquire_value %.loc9_13.1
+// CHECK:STDOUT:   %impl.elem0: %.8e2 = impl_witness_access constants.%Copy.impl_witness.f17, element0 [concrete = constants.%Int.as.Copy.impl.Op.664]
+// CHECK:STDOUT:   %bound_method.loc9_13.1: <bound method> = bound_method %.loc9_13.2, %impl.elem0
+// CHECK:STDOUT:   %specific_fn: <specific function> = specific_function %impl.elem0, @Int.as.Copy.impl.Op(constants.%int_32) [concrete = constants.%Int.as.Copy.impl.Op.specific_fn]
+// CHECK:STDOUT:   %bound_method.loc9_13.2: <bound method> = bound_method %.loc9_13.2, %specific_fn
+// CHECK:STDOUT:   %Int.as.Copy.impl.Op.call: init %i32 = call %bound_method.loc9_13.2(%.loc9_13.2)
+// CHECK:STDOUT:   %Y.cpp_destructor.bound: <bound method> = bound_method %.loc9_12.2, constants.%Y.cpp_destructor.4b6d21.1
+// CHECK:STDOUT:   %Y.cpp_destructor.call: init %empty_tuple.type = call %Y.cpp_destructor.bound(%.loc9_12.2)
+// CHECK:STDOUT:   return %Int.as.Copy.impl.Op.call
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @UseZ() -> out %return.param: %Z {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   %H.ref: %H.type = name_ref H, imports.%Main.H [concrete = constants.%H]
+// CHECK:STDOUT:   %.loc15_12.1: ref %Y.b7f625.2 = temporary_storage
+// CHECK:STDOUT:   %H.call: init %Y.b7f625.2 to %.loc15_12.1 = call %H.ref()
+// CHECK:STDOUT:   %.loc15_12.2: ref %Y.b7f625.2 = temporary %.loc15_12.1, %H.call
+// CHECK:STDOUT:   %y.ref: %Y.elem.f4e = name_ref y, @Y.2.%.1 [concrete = @Y.2.%.1]
+// CHECK:STDOUT:   %.loc15_13.1: ref %Z = class_element_access %.loc15_12.2, element0
+// CHECK:STDOUT:   %.loc15_13.2: %Z = acquire_value %.loc15_13.1
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %impl.elem0: %.9a1 = impl_witness_access constants.%custom_witness.8e5, element0 [concrete = constants.%Z.Op]
+// CHECK:STDOUT:   %bound_method: <bound method> = bound_method %.loc15_13.2, %impl.elem0
+// CHECK:STDOUT:   %.loc15_13.3: ref %Z = temporary_storage
+// CHECK:STDOUT:   %Op.ref: %Z.Z.type = name_ref Op, imports.%Z.Z.decl [concrete = constants.%Z.Z]
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %.loc15_13.4: ref %Z = value_as_ref %.loc15_13.2
+// CHECK:STDOUT:   %addr.loc15_13.1: %ptr.81c = addr_of %.loc15_13.4
+// CHECK:STDOUT:   %.loc15_13.5: %ptr.061 = as_compatible %addr.loc15_13.1
+// CHECK:STDOUT:   %.loc15_13.6: %ptr.061 = converted %addr.loc15_13.1, %.loc15_13.5
+// CHECK:STDOUT:   %addr.loc15_13.2: %ptr.81c = addr_of %.loc13_19.1
+// CHECK:STDOUT:   %Z__carbon_thunk.call: init %empty_tuple.type = call imports.%Z__carbon_thunk.decl(%.loc15_13.6, %addr.loc15_13.2)
+// CHECK:STDOUT:   %.loc15_13.7: init %Z to %.loc13_19.1 = mark_in_place_init %Z__carbon_thunk.call
+// CHECK:STDOUT:   %Y.cpp_destructor.bound: <bound method> = bound_method %.loc15_12.2, constants.%Y.cpp_destructor.4b6d21.2
+// CHECK:STDOUT:   %Y.cpp_destructor.call: init %empty_tuple.type = call %Y.cpp_destructor.bound(%.loc15_12.2)
+// CHECK:STDOUT:   return %.loc15_13.7 to %return.param
 // CHECK:STDOUT: }
 // CHECK:STDOUT: