Răsfoiți Sursa

Support const eval when calling a C++ thunk (#6947)

This makes it possible to do const eval when calling a constexpr C++
function with params and return types other than 32/64-bit integers.

Most of the new logic is in `MaybeModifyCppThunkCallForConstEval`, which
is called by `MakeConstantForCall`. This checks if the callee is a C++
thunk (using a new `SpecialFunctionKind::CppThunk` variant), and if so
it:
* Changes the callee from the C++ thunk to the thunk's callee
* Remaps parameters that are passed by pointer to the thunk to the
underlying value
* Drops the return value parameter, if present
Nicholas Bishop 1 lună în urmă
părinte
comite
0075d530b9

+ 2 - 1
toolchain/check/call.cpp

@@ -295,7 +295,8 @@ auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
 
     case SemIR::Function::SpecialFunctionKind::None:
     case SemIR::Function::SpecialFunctionKind::Builtin:
-    case SemIR::Function::SpecialFunctionKind::CoreWitness: {
+    case SemIR::Function::SpecialFunctionKind::CoreWitness:
+    case SemIR::Function::SpecialFunctionKind::CppThunk: {
       return GetOrAddInst<SemIR::Call>(context, loc_id,
                                        {.type_id = return_type_id,
                                         .callee_id = callee_id,

+ 68 - 0
toolchain/check/cpp/constant.cpp

@@ -196,6 +196,11 @@ auto MapConstantToAPValue(Context& context, SemIR::InstId const_inst_id,
 
 static auto ConvertArgToExpr(Context& context, SemIR::InstId arg_inst_id,
                              clang::QualType param_type) -> clang::Expr* {
+  if (auto temporary =
+          context.insts().TryGetAs<SemIR::Temporary>(arg_inst_id)) {
+    arg_inst_id = temporary->init_id;
+  }
+
   auto const_inst_id = context.constant_values().GetConstantInstId(arg_inst_id);
   if (!const_inst_id.has_value()) {
     return nullptr;
@@ -275,4 +280,67 @@ auto EvalCppCall(Context& context, SemIR::LocId loc_id,
                                           function_decl->getCallResultType());
 }
 
+auto MaybeModifyCppThunkCallForConstEval(Context& context, SemIR::Call* call)
+    -> void {
+  clang::FunctionDecl* function_decl = nullptr;
+  SemIR::InstId thunk_callee_inst_id = SemIR::InstId::None;
+
+  // Check if the callee is a C++ thunk for a constexpr function. If so,
+  // fill in `function_decl` and `thunk_callee_inst_id`.
+  auto callee = SemIR::GetCallee(context.sem_ir(), call->callee_id);
+  if (auto* callee_function = std::get_if<SemIR::CalleeFunction>(&callee)) {
+    auto function = context.functions().Get(callee_function->function_id);
+
+    thunk_callee_inst_id = function.cpp_thunk_callee();
+    if (!thunk_callee_inst_id.has_value()) {
+      return;
+    }
+    auto thunk_callee_function = context.functions().Get(
+        context.insts()
+            .GetAs<SemIR::FunctionDecl>(thunk_callee_inst_id)
+            .function_id);
+
+    function_decl =
+        cast<clang::FunctionDecl>(context.clang_decls()
+                                      .Get(thunk_callee_function.clang_decl_id)
+                                      .GetAsKey()
+                                      .decl);
+
+    if (!(function_decl->isConstexpr() || function_decl->isConsteval())) {
+      return;
+    }
+
+    if (function_decl->isDefaulted()) {
+      return;
+    }
+  } else {
+    return;
+  }
+
+  auto thunk_args = context.inst_blocks().Get(call->args_id);
+
+  // Get the new call arguments. This drops the return slot arg, if
+  // present. It also remaps arguments that are a pointer in the thunk,
+  // but a non-pointer in the callee.
+  llvm::SmallVector<SemIR::InstId> new_args;
+  for (auto [arg_inst_id, parm_var_decl] :
+       llvm::zip(thunk_args, function_decl->parameters())) {
+    auto parm_type = parm_var_decl->getType();
+    auto new_arg_inst_id = arg_inst_id;
+
+    // TODO: reuse the logic in `check/cpp/thunk.cpp` to determine
+    // whether to dereference the argument.
+    if (!parm_type->isPointerType()) {
+      if (auto addr_of = context.insts().TryGetAs<SemIR::AddrOf>(arg_inst_id)) {
+        new_arg_inst_id = addr_of->lvalue_id;
+      }
+    }
+
+    new_args.push_back(new_arg_inst_id);
+  }
+
+  call->callee_id = thunk_callee_inst_id;
+  call->args_id = context.inst_blocks().AddCanonical(new_args);
+}
+
 }  // namespace Carbon::Check

+ 6 - 0
toolchain/check/cpp/constant.h

@@ -34,6 +34,12 @@ auto EvalCppCall(Context& context, SemIR::LocId loc_id,
                  SemIR::ClangDeclId clang_decl_id, SemIR::InstBlockId args_id)
     -> SemIR::ConstantId;
 
+// If the callee is a C++ thunk, modify `call` to directly call the
+// callee of the C++ thunk. Otherwise, does nothing and leaves `call`
+// unmodified.
+auto MaybeModifyCppThunkCallForConstEval(Context& context, SemIR::Call* call)
+    -> void;
+
 }  // namespace Carbon::Check
 
 #endif  // CARBON_TOOLCHAIN_CHECK_CPP_CONSTANT_H_

+ 3 - 1
toolchain/check/cpp/import.cpp

@@ -1587,8 +1587,10 @@ static auto ImportFunctionDecl(Context& context, SemIR::LocId loc_id,
               context, loc_id, import_ir_inst_id, thunk_clang_decl,
               {.num_params =
                    static_cast<int32_t>(thunk_clang_decl->getNumParams())})) {
+        auto& thunk_function = context.functions().Get(*thunk_function_id);
+        thunk_function.SetCppThunk(function_info.first_owning_decl_id);
         SemIR::InstId thunk_function_decl_id =
-            context.functions().Get(*thunk_function_id).first_owning_decl_id;
+            thunk_function.first_owning_decl_id;
         function_info.SetHasCppThunk(thunk_function_decl_id);
       }
     }

+ 4 - 0
toolchain/check/eval.cpp

@@ -2149,6 +2149,10 @@ static auto MakeConstantForCall(EvalContext& eval_context,
     return SemIR::ErrorInst::ConstantId;
   }
 
+  // If the callee is a C++ thunk, modify the `call` to directly call
+  // the thunk's callee.
+  MaybeModifyCppThunkCallForConstEval(eval_context.context(), &call);
+
   // Find the constant value of the callee.
   bool has_constant_callee = ReplaceFieldWithConstantValue(
       eval_context, &call, &SemIR::Call::callee_id, &phase);

+ 6 - 0
toolchain/check/eval_inst.cpp

@@ -476,6 +476,12 @@ auto EvalConstantInst(Context& context, SemIR::InterfaceDecl inst)
       context.generics().GetSelfSpecific(interface_info.generic_id)));
 }
 
+auto EvalConstantInst(Context& context, SemIR::MarkInPlaceInit inst)
+    -> ConstantEvalResult {
+  return ConstantEvalResult::Existing(
+      context.constant_values().Get(inst.src_id));
+}
+
 auto EvalConstantInst(Context& context, SemIR::NamedConstraintDecl inst)
     -> ConstantEvalResult {
   const auto& named_constraint_info =

+ 1 - 0
toolchain/check/import_ref.cpp

@@ -2340,6 +2340,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
   }
 
   switch (import_function.special_function_kind) {
+    case SemIR::Function::SpecialFunctionKind::CppThunk:
     case SemIR::Function::SpecialFunctionKind::None: {
       break;
     }

+ 40 - 0
toolchain/check/testdata/interop/cpp/constexpr.carbon

@@ -69,6 +69,46 @@ constexpr int f(int a, int b) { return a + b; }
 
 let a: array(i32, Cpp.f(1, 2)) = (1, 2, 3);
 
+// --- function_bool_param.carbon
+
+library "[[@TEST_NAME]]";
+
+import Cpp inline '''
+constexpr int f(bool b) {
+  return b ? 3 : 0;
+}
+''';
+
+let a: array(i32, Cpp.f(true)) = (1, 2, 3);
+
+// --- function_float_param.carbon
+
+library "[[@TEST_NAME]]";
+
+import Cpp inline '''
+constexpr int f(float b) {
+  return static_cast<int>(b);
+}
+''';
+
+let a: array(i32, Cpp.f(3.0)) = (1, 2, 3);
+
+// --- function_return_bool.carbon
+
+library "[[@TEST_NAME]]";
+
+import Cpp inline '''
+constexpr bool f() {
+  return true;
+}
+''';
+
+musteval fn F(b: bool) -> i32 {
+  return if b then 3 else 0;
+}
+
+let a: array(i32, F(Cpp.f())) = (1, 2, 3);
+
 // --- fail_invalid_constant_eval.carbon
 
 library "[[@TEST_NAME]]";

+ 15 - 0
toolchain/sem_ir/function.h

@@ -24,6 +24,7 @@ struct FunctionFields {
     Builtin,
     CoreWitness,
     Thunk,
+    CppThunk,
     HasCppThunk,
   };
 
@@ -231,6 +232,13 @@ struct Function : public EntityWithParamsBase,
                : InstId::None;
   }
 
+  // Gets the `InstId` of the C++ function called by this thunk.
+  auto cpp_thunk_callee() const -> InstId {
+    return special_function_kind == SpecialFunctionKind::CppThunk
+               ? InstId(special_function_kind_data.index)
+               : InstId::None;
+  }
+
   // Gets the declared return type for a specific version of this function, or
   // the canonical return type for the original declaration no specific is
   // specified.  Returns `None` if no return type was specified, in which
@@ -271,6 +279,13 @@ struct Function : public EntityWithParamsBase,
     special_function_kind_data = AnyRawId(decl_id.index);
   }
 
+  // Sets that this function is a C++ thunk.
+  auto SetCppThunk(InstId decl_id) -> void {
+    CARBON_CHECK(special_function_kind == SpecialFunctionKind::None);
+    special_function_kind = SpecialFunctionKind::CppThunk;
+    special_function_kind_data = AnyRawId(decl_id.index);
+  }
+
   // Sets that this function is a C++ function that should be called using a C++
   // thunk.
   auto SetHasCppThunk(InstId decl_id) -> void {

+ 1 - 0
toolchain/sem_ir/mangler.cpp

@@ -194,6 +194,7 @@ auto Mangler::Mangle(SemIR::FunctionId function_id,
   // For a special function, add a marker to disambiguate.
   switch (function.special_function_kind) {
     case SemIR::Function::SpecialFunctionKind::None:
+    case SemIR::Function::SpecialFunctionKind::CppThunk:
       break;
 
     case SemIR::Function::SpecialFunctionKind::CoreWitness:

+ 1 - 2
toolchain/sem_ir/typed_insts.h

@@ -1239,8 +1239,7 @@ struct LookupImplWitness {
 struct MarkInPlaceInit {
   static constexpr auto Kind = InstKind::MarkInPlaceInit.Define<Parse::NodeId>(
       {.ir_name = "mark_in_place_init",
-       .expr_category = ExprCategory::InPlaceInitializing,
-       .constant_kind = InstConstantKind::Never});
+       .expr_category = ExprCategory::InPlaceInitializing});
 
   TypeId type_id;
   // Used only to track the source of the initialization; this has no semantic