Эх сурвалжийг харах

Make C++ enum types impl Core.Copy. (#7013)

Remove special-case handling in conversion logic for C++ enum types,
synthesize a custom witness of `Core.Copy` using the `primitive_copy`
builtin function.
Richard Smith 1 сар өмнө
parent
commit
6f0ec37a8b

+ 5 - 52
toolchain/check/convert.cpp

@@ -1544,32 +1544,10 @@ static auto PerformBuiltinConversion(Context& context, SemIR::LocId loc_id,
   return value_id;
 }
 
-// Determine whether this is a C++ enum type.
-// TODO: This should be removed once we can properly add a `Copy` impl for C++
-// enum types.
-static auto IsCppEnum(Context& context, SemIR::TypeId type_id) -> bool {
-  auto class_type = context.types().TryGetAs<SemIR::ClassType>(type_id);
-  if (!class_type) {
-    return false;
-  }
-
-  // A C++-imported class type that is an adapter is an enum.
-  auto& class_info = context.classes().Get(class_type->class_id);
-  return class_info.adapt_id.has_value() &&
-         context.name_scopes().Get(class_info.scope_id).is_cpp_scope();
-}
-
 // Given a value expression, form a corresponding initializer that copies from
 // that value to the specified target, if it is possible to do so.
 static auto PerformCopy(Context& context, SemIR::InstId expr_id,
                         const ConversionTarget& target) -> SemIR::InstId {
-  // TODO: We don't have a mechanism yet to generate `Copy` impls for each enum
-  // type imported from C++. For now we fake it by providing a direct copy.
-  auto type_id = context.insts().Get(expr_id).type_id();
-  if (IsCppEnum(context, type_id)) {
-    return expr_id;
-  }
-
   auto copy_id = BuildUnaryOperator(
       context, SemIR::LocId(expr_id), {.interface_name = CoreIdentifier::Copy},
       expr_id, target.diagnose, [&](auto& builder) {
@@ -1695,10 +1673,7 @@ class CategoryConverter {
 auto CategoryConverter::DoStep(const SemIR::InstId expr_id,
                                const SemIR::ExprCategory category) const
     -> State {
-  CARBON_DCHECK(SemIR::GetExprCategory(sem_ir_, expr_id) == category ||
-                // TODO: Drop this special case once PerformCopy on C++ enums
-                // produces an initializing expression.
-                IsCppEnum(context_, target_.type_id));
+  CARBON_DCHECK(SemIR::GetExprCategory(sem_ir_, expr_id) == category);
   switch (category) {
     case SemIR::ExprCategory::NotExpr:
     case SemIR::ExprCategory::Mixed:
@@ -1721,11 +1696,8 @@ auto CategoryConverter::DoStep(const SemIR::InstId expr_id,
         // hasn't already been set. However, we skip this if the type is a C++
         // enum: in that case, we don't actually have an initializing
         // expression, we're just pretending we do.
-        auto new_storage_id = target_.storage_id;
-        if (!IsCppEnum(context_, target_.type_id)) {
-          new_storage_id =
-              OverwriteTemporaryStorageArg(sem_ir_, expr_id, target_);
-        }
+        auto new_storage_id =
+            OverwriteTemporaryStorageArg(sem_ir_, expr_id, target_);
 
         // If in-place initialization was requested, and it hasn't already
         // happened, ensure it happens now.
@@ -1842,27 +1814,8 @@ auto CategoryConverter::DoStep(const SemIR::InstId expr_id,
         if (copy_id == SemIR::ErrorInst::InstId) {
           return Done{SemIR::ErrorInst::InstId};
         }
-        // Deal with special-case category behavior of PerformCopy.
-        switch (SemIR::GetExprCategory(sem_ir_, copy_id)) {
-          case SemIR::ExprCategory::Value:
-            // As a temporary workaround, PerformCopy on a C++ enum currently
-            // returns the unchanged value, but we treat it as an initializing
-            // expression.
-            // TODO: Drop this case once it's no longer applicable.
-            CARBON_CHECK(IsCppEnum(context_, target_.type_id));
-            [[fallthrough]];
-          case SemIR::ExprCategory::ReprInitializing:
-            // The common case: PerformCopy produces an initializing expression.
-            return NextStep{.expr_id = copy_id,
-                            .category = SemIR::ExprCategory::ReprInitializing};
-          case SemIR::ExprCategory::InPlaceInitializing:
-            // A C++ copy operation produces an ephemeral entire reference.
-            return NextStep{
-                .expr_id = copy_id,
-                .category = SemIR::ExprCategory::InPlaceInitializing};
-          default:
-            CARBON_FATAL("Unexpected category of copy operation {0}", category);
-        }
+        return NextStep{.expr_id = copy_id,
+                        .category = SemIR::GetExprCategory(sem_ir_, copy_id)};
       }
 
       // When initializing a C++ thunk parameter, form a reference, creating a

+ 43 - 22
toolchain/check/cpp/impl_lookup.cpp

@@ -15,26 +15,34 @@
 #include "toolchain/check/import_ref.h"
 #include "toolchain/check/inst.h"
 #include "toolchain/check/type.h"
+#include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/typed_insts.h"
 
 namespace Carbon::Check {
 
-// If the given type is a C++ class type, returns the corresponding class
-// declaration. Otherwise returns nullptr.
-// TODO: Handle qualified types.
-static auto TypeAsClassDecl(Context& context,
-                            SemIR::ConstantId query_self_const_id)
-    -> clang::CXXRecordDecl* {
+// Given a type constant, return the corresponding class scope if there is one.
+static auto GetClassScope(Context& context,
+                          SemIR::ConstantId query_self_const_id)
+    -> SemIR::NameScopeId {
   auto class_type = context.constant_values().TryGetInstAs<SemIR::ClassType>(
       query_self_const_id);
   if (!class_type) {
     // Not a class.
-    return nullptr;
+    return SemIR::NameScopeId::None;
   }
 
+  return context.classes().Get(class_type->class_id).scope_id;
+}
+
+// If the given type is a C++ tag (class or enumeration) type, returns the
+// corresponding tag declaration. Otherwise returns nullptr.
+// TODO: Handle qualified types.
+static auto TypeAsTagDecl(Context& context,
+                          SemIR::ConstantId query_self_const_id)
+    -> clang::TagDecl* {
   SemIR::NameScopeId class_scope_id =
-      context.classes().Get(class_type->class_id).scope_id;
+      GetClassScope(context, query_self_const_id);
   if (!class_scope_id.has_value()) {
     return nullptr;
   }
@@ -45,8 +53,16 @@ static auto TypeAsClassDecl(Context& context,
     return nullptr;
   }
 
-  return dyn_cast<clang::CXXRecordDecl>(
-      context.clang_decls().Get(decl_id).key.decl);
+  return dyn_cast<clang::TagDecl>(context.clang_decls().Get(decl_id).key.decl);
+}
+
+// If the given type is a C++ class type, returns the corresponding class
+// declaration. Otherwise returns nullptr.
+static auto TypeAsClassDecl(Context& context,
+                            SemIR::ConstantId query_self_const_id)
+    -> clang::CXXRecordDecl* {
+  return dyn_cast_or_null<clang::CXXRecordDecl>(
+      TypeAsTagDecl(context, query_self_const_id));
 }
 
 namespace {
@@ -92,21 +108,26 @@ static auto BuildCopyWitness(
     SemIR::SpecificInterfaceId query_specific_interface_id) -> SemIR::InstId {
   auto& clang_sema = context.clang_sema();
 
-  // TODO: This should provide `Copy` for enums and other trivially copyable
-  // types.
-  auto* class_decl = TypeAsClassDecl(context, query_self_const_id);
-  if (!class_decl) {
+  auto* tag_decl = TypeAsTagDecl(context, query_self_const_id);
+  if (!tag_decl) {
     return SemIR::InstId::None;
   }
-  auto decl_info = DeclInfo{.decl = clang_sema.LookupCopyingConstructor(
-                                class_decl, clang::Qualifiers::Const),
-                            .signature = {.num_params = 1}};
-  auto fn_id = GetFunctionId(context, loc_id, decl_info);
-  if (fn_id == SemIR::ErrorInst::InstId || fn_id == SemIR::InstId::None) {
-    return fn_id;
+  if (auto* class_decl = dyn_cast<clang::CXXRecordDecl>(tag_decl)) {
+    auto decl_info = DeclInfo{.decl = clang_sema.LookupCopyingConstructor(
+                                  class_decl, clang::Qualifiers::Const),
+                              .signature = {.num_params = 1}};
+    auto fn_id = GetFunctionId(context, loc_id, decl_info);
+    if (fn_id == SemIR::ErrorInst::InstId || fn_id == SemIR::InstId::None) {
+      return fn_id;
+    }
+    return BuildCustomWitness(context, loc_id, query_self_const_id,
+                              query_specific_interface_id, {fn_id});
   }
-  return BuildCustomWitness(context, loc_id, query_self_const_id,
-                            query_specific_interface_id, {fn_id});
+  // Otherwise it's an enum (or eventually a C struct type). Perform a primitive
+  // copy.
+  return BuildPrimitiveCopyWitness(
+      context, loc_id, GetClassScope(context, query_self_const_id),
+      query_self_const_id, query_specific_interface_id);
 }
 
 static auto BuildCppUnsafeDerefWitness(

+ 33 - 0
toolchain/check/custom_witness.cpp

@@ -40,6 +40,28 @@ static auto GetFacetAsType(Context& context,
   return context.types().GetTypeIdForTypeInstId(facet_or_type_id);
 }
 
+// Returns a manufactured `Copy.Op` function with the `self` parameter typed
+// to `self_type_id`.
+static auto MakeCopyOpFunction(Context& context, SemIR::LocId loc_id,
+                               SemIR::TypeId self_type_id,
+                               SemIR::NameScopeId parent_scope_id)
+    -> SemIR::InstId {
+  auto name_id = context.core_identifiers().AddNameId(CoreIdentifier::Op);
+
+  auto [decl_id, function_id] =
+      MakeGeneratedFunctionDecl(context, loc_id,
+                                {.parent_scope_id = parent_scope_id,
+                                 .name_id = name_id,
+                                 .self_type_id = self_type_id,
+                                 .self_is_ref = false,
+                                 .return_type_id = self_type_id});
+
+  auto& function = context.functions().Get(function_id);
+  function.SetCoreWitness(SemIR::BuiltinFunctionKind::PrimitiveCopy);
+
+  return decl_id;
+}
+
 // Returns the body for `Destroy.Op`. This will return `None` if using the
 // builtin `NoOp` is appropriate.
 //
@@ -318,6 +340,17 @@ auto GetCoreInterface(Context& context, SemIR::InterfaceId interface_id)
   return CoreInterface::Unknown;
 }
 
+auto BuildPrimitiveCopyWitness(
+    Context& context, SemIR::LocId loc_id, SemIR::NameScopeId parent_scope_id,
+    SemIR::ConstantId query_self_const_id,
+    SemIR::SpecificInterfaceId query_specific_interface_id) -> SemIR::InstId {
+  auto self_type_id = GetFacetAsType(context, query_self_const_id);
+  auto op_id =
+      MakeCopyOpFunction(context, loc_id, self_type_id, parent_scope_id);
+  return BuildCustomWitness(context, loc_id, query_self_const_id,
+                            query_specific_interface_id, {op_id});
+}
+
 // Returns true if the `Self` should impl `Destroy`.
 static auto TypeCanDestroy(Context& context,
                            SemIR::ConstantId query_self_const_id,

+ 6 - 0
toolchain/check/custom_witness.h

@@ -19,6 +19,12 @@ auto BuildCustomWitness(Context& context, SemIR::LocId loc_id,
                         SemIR::SpecificInterfaceId query_specific_interface_id,
                         llvm::ArrayRef<SemIR::InstId> values) -> SemIR::InstId;
 
+// Builds a witness that the given type is copyable via a primitive copy.
+auto BuildPrimitiveCopyWitness(
+    Context& context, SemIR::LocId loc_id, SemIR::NameScopeId parent_scope_id,
+    SemIR::ConstantId query_self_const_id,
+    SemIR::SpecificInterfaceId query_specific_interface_id) -> SemIR::InstId;
+
 // Significant interfaces in `Core` which correspond to language features and
 // can have custom witnesses.
 enum class CoreInterface {

+ 1 - 1
toolchain/check/function.cpp

@@ -121,7 +121,7 @@ static auto MakeFunctionSignature(Context& context, SemIR::LocId loc_id,
 
   StartFunctionSignature(context);
 
-  // Build and add a `[ref self: Self]` parameter if needed.
+  // Build and add a `self: Self` or `ref self: Self` parameter if needed.
   if (args.self_type_id.has_value()) {
     context.full_pattern_stack().StartImplicitParamList();
 

+ 19 - 2
toolchain/check/testdata/interop/cpp/enum/copy.carbon

@@ -40,7 +40,16 @@ fn F() {
 // CHECK:STDOUT:   %Enum: type = class_type @Enum [concrete]
 // CHECK:STDOUT:   %pattern_type.ebf: type = pattern_type %Enum [concrete]
 // CHECK:STDOUT:   %int_0: %Enum = int_value 0 [concrete]
+// CHECK:STDOUT:   %Copy.type: type = facet_type <@Copy> [concrete]
+// CHECK:STDOUT:   %Enum.Op.type: type = fn_type @Enum.Op [concrete]
+// CHECK:STDOUT:   %Enum.Op: %Enum.Op.type = struct_value () [concrete]
+// CHECK:STDOUT:   %custom_witness.0a8: <witness> = custom_witness (%Enum.Op), @Copy [concrete]
+// CHECK:STDOUT:   %Copy.facet.897: %Copy.type = facet_value %Enum, (%custom_witness.0a8) [concrete]
+// CHECK:STDOUT:   %Copy.WithSelf.Op.type.67c: type = fn_type @Copy.WithSelf.Op, @Copy.WithSelf(%Copy.facet.897) [concrete]
+// CHECK:STDOUT:   %.76d: type = fn_type_with_self_type %Copy.WithSelf.Op.type.67c, %Copy.facet.897 [concrete]
+// CHECK:STDOUT:   %Enum.Op.bound.3a0: <bound method> = bound_method %int_0, %Enum.Op [concrete]
 // CHECK:STDOUT:   %int_1: %Enum = int_value 1 [concrete]
+// CHECK:STDOUT:   %Enum.Op.bound.87b: <bound method> = bound_method %int_1, %Enum.Op [concrete]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: imports {
@@ -67,7 +76,10 @@ fn F() {
 // CHECK:STDOUT:   %Cpp.ref.loc12_21: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
 // CHECK:STDOUT:   %Enum.ref.loc12_24: type = name_ref Enum, imports.%Enum.decl [concrete = constants.%Enum]
 // CHECK:STDOUT:   %a.ref.loc12: %Enum = name_ref a, imports.%int_0 [concrete = constants.%int_0]
-// CHECK:STDOUT:   assign %a.var, %a.ref.loc12
+// CHECK:STDOUT:   %impl.elem0.loc12: %.76d = impl_witness_access constants.%custom_witness.0a8, element0 [concrete = constants.%Enum.Op]
+// CHECK:STDOUT:   %bound_method.loc12: <bound method> = bound_method %a.ref.loc12, %impl.elem0.loc12 [concrete = constants.%Enum.Op.bound.3a0]
+// CHECK:STDOUT:   %Enum.Op.call.loc12: init %Enum = call %bound_method.loc12(%a.ref.loc12) [concrete = constants.%int_0]
+// CHECK:STDOUT:   assign %a.var, %Enum.Op.call.loc12
 // CHECK:STDOUT:   %.loc12: type = splice_block %Enum.ref.loc12_13 [concrete = constants.%Enum] {
 // CHECK:STDOUT:     %Cpp.ref.loc12_10: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
 // CHECK:STDOUT:     %Enum.ref.loc12_13: type = name_ref Enum, imports.%Enum.decl [concrete = constants.%Enum]
@@ -77,7 +89,12 @@ fn F() {
 // CHECK:STDOUT:   %Cpp.ref.loc14: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
 // CHECK:STDOUT:   %Enum.ref.loc14: type = name_ref Enum, imports.%Enum.decl [concrete = constants.%Enum]
 // CHECK:STDOUT:   %b.ref: %Enum = name_ref b, imports.%int_1 [concrete = constants.%int_1]
-// CHECK:STDOUT:   assign %a.ref.loc14, %b.ref
+// CHECK:STDOUT:   %impl.elem0.loc14: %.76d = impl_witness_access constants.%custom_witness.0a8, element0 [concrete = constants.%Enum.Op]
+// CHECK:STDOUT:   %bound_method.loc14: <bound method> = bound_method %b.ref, %impl.elem0.loc14 [concrete = constants.%Enum.Op.bound.87b]
+// CHECK:STDOUT:   %Enum.Op.call.loc14: init %Enum = call %bound_method.loc14(%b.ref) [concrete = constants.%int_1]
+// CHECK:STDOUT:   assign %a.ref.loc14, %Enum.Op.call.loc14
 // CHECK:STDOUT:   return
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
+// CHECK:STDOUT: fn @Enum.Op(%self.param: %Enum) -> out %return.param: %Enum = "primitive_copy";
+// CHECK:STDOUT: