Przeglądaj źródła

Store the CppOverloadSetId on CalleeFunction. (#6067)

Richard Smith 7 miesięcy temu
rodzic
commit
0cafb8f0e4

+ 4 - 4
toolchain/check/call.cpp

@@ -299,9 +299,9 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
   if (callee_function.is_error) {
     return SemIR::ErrorInst::InstId;
   }
-  if (callee_function.is_cpp_overload_set) {
-    auto resolved_fn_id =
-        PerformCppOverloadResolution(context, loc_id, callee_id, arg_ids);
+  if (callee_function.cpp_overload_set_id.has_value()) {
+    auto resolved_fn_id = PerformCppOverloadResolution(
+        context, loc_id, callee_function.cpp_overload_set_id, arg_ids);
     if (!resolved_fn_id) {
       return SemIR::ErrorInst::InstId;
     }
@@ -310,7 +310,7 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
     if (callee_function.is_error) {
       return SemIR::ErrorInst::InstId;
     }
-    CARBON_CHECK(!callee_function.is_cpp_overload_set);
+    CARBON_CHECK(!callee_function.cpp_overload_set_id.has_value());
   }
   if (callee_function.function_id.has_value()) {
     return PerformCallToFunction(context, loc_id, callee_id, callee_function,

+ 2 - 11
toolchain/check/cpp_overload_resolution.cpp

@@ -12,7 +12,7 @@
 namespace Carbon::Check {
 
 auto PerformCppOverloadResolution(Context& context, SemIR::LocId loc_id,
-                                  SemIR::InstId callee_id,
+                                  SemIR::CppOverloadSetId overload_set_id,
                                   llvm::ArrayRef<SemIR::InstId> arg_ids)
     -> std::optional<SemIR::InstId> {
   Diagnostics::AnnotationScope annotate_diagnostics(
@@ -41,17 +41,8 @@ auto PerformCppOverloadResolution(Context& context, SemIR::LocId loc_id,
         clang::ExprValueKind::VK_LValue));
   }
 
-  auto overload_set_type =
-      context.types()
-          .GetAsInst(context.insts().Get(callee_id).type_id())
-          .TryAs<SemIR::CppOverloadSetType>();
-  // TODO: CHECK-fail or store CppOverloadSetId in the CalleeFunction and pass
-  // it in here.
-  if (!overload_set_type) {
-    return std::nullopt;
-  }
   const SemIR::CppOverloadSet& overload_set =
-      context.cpp_overload_sets().Get(overload_set_type->overload_set_id);
+      context.cpp_overload_sets().Get(overload_set_id);
 
   // Add candidate functions from the name lookup.
   clang::OverloadCandidateSet candidate_set(

+ 2 - 1
toolchain/check/cpp_overload_resolution.h

@@ -6,6 +6,7 @@
 #define CARBON_TOOLCHAIN_CHECK_CPP_OVERLOAD_RESOLUTION_H_
 
 #include "toolchain/check/context.h"
+#include "toolchain/sem_ir/ids.h"
 
 namespace Carbon::Check {
 
@@ -23,7 +24,7 @@ namespace Carbon::Check {
 // consistency and supporting migrations so that the migrated callers from C++
 // remain valid.
 auto PerformCppOverloadResolution(Context& context, SemIR::LocId loc_id,
-                                  SemIR::InstId callee_id,
+                                  SemIR::CppOverloadSetId overload_set_id,
                                   llvm::ArrayRef<SemIR::InstId> arg_ids)
     -> std::optional<SemIR::InstId>;
 

+ 4 - 5
toolchain/sem_ir/function.cpp

@@ -16,12 +16,12 @@ namespace Carbon::SemIR {
 auto GetCalleeFunction(const File& sem_ir, InstId callee_id,
                        SpecificId specific_id) -> CalleeFunction {
   CalleeFunction result = {.function_id = FunctionId::None,
+                           .cpp_overload_set_id = CppOverloadSetId::None,
                            .enclosing_specific_id = SpecificId::None,
                            .resolved_specific_id = SpecificId::None,
                            .self_type_id = InstId::None,
                            .self_id = InstId::None,
-                           .is_error = false,
-                           .is_cpp_overload_set = false};
+                           .is_error = false};
   if (auto bound_method = sem_ir.insts().TryGetAs<BoundMethod>(callee_id)) {
     result.self_id = bound_method->object_id;
     callee_id = bound_method->function_decl_id;
@@ -48,9 +48,8 @@ auto GetCalleeFunction(const File& sem_ir, InstId callee_id,
   auto fn_type_inst =
       sem_ir.types().GetAsInst(sem_ir.insts().Get(val_id).type_id());
 
-  if (fn_type_inst.TryAs<CppOverloadSetType>()) {
-    // TODO: Consider evaluating this at runtime instead of having a field.
-    result.is_cpp_overload_set = true;
+  if (auto cpp_overload_set_type = fn_type_inst.TryAs<CppOverloadSetType>()) {
+    result.cpp_overload_set_id = cpp_overload_set_type->overload_set_id;
     return result;
   }
 

+ 9 - 5
toolchain/sem_ir/function.h

@@ -188,6 +188,8 @@ class File;
 struct CalleeFunction : public Printable<CalleeFunction> {
   // The function. `None` if not a function.
   FunctionId function_id;
+  // The overload set, if this is a C++ overload set rather than a function.
+  CppOverloadSetId cpp_overload_set_id;
   // The specific that contains the function.
   SpecificId enclosing_specific_id;
   // The specific for the callee itself, in a resolved call.
@@ -199,13 +201,15 @@ struct CalleeFunction : public Printable<CalleeFunction> {
   InstId self_id;
   // True if an error instruction was found.
   bool is_error;
-  // True if the function is a C++ overload set.
-  // TODO: Store the CppOverloadSetId instead of a bool.
-  bool is_cpp_overload_set;
 
   auto Print(llvm::raw_ostream& out) const -> void {
-    out << "{function_id: " << function_id
-        << ", enclosing_specific_id: " << enclosing_specific_id
+    out << "{";
+    if (cpp_overload_set_id.has_value()) {
+      out << "cpp_overload_set_id: " << cpp_overload_set_id;
+    } else {
+      out << "function_id: " << function_id;
+    }
+    out << ", enclosing_specific_id: " << enclosing_specific_id
         << ", resolved_specific_id: " << resolved_specific_id
         << ", self_type_id: " << self_type_id << ", self_id: " << self_id
         << ", is_error: " << is_error << "}";