Jelajahi Sumber

Factor out some parts of builtin handling and add declaration checking. (#3815)

In preparation for adding more builtins, factor out the handling of
builtin function kinds into separate files.

Add checking for builtin function signatures. The mechanism used here is
intended to provide a lot of flexibility for declaring generic builtin
functions and pretty arbitrary constraints on the types of parameters of
builtin functions. For now, these constraints are checked when the
builtin function is declared. The hope is that this will suffice, but if
not, it should be straightforward to switch to doing some of the
checking on call and share logic between the checks.

---------

Co-authored-by: Jon Ross-Perkins <jperkins@google.com>
Co-authored-by: Carbon Infra Bot <carbon-external-infra@google.com>
Richard Smith 2 tahun lalu
induk
melakukan
a3d77d9b74

+ 12 - 34
toolchain/check/eval.cpp

@@ -5,6 +5,7 @@
 #include "toolchain/check/eval.h"
 
 #include "toolchain/diagnostics/diagnostic_emitter.h"
+#include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/function.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/typed_insts.h"
@@ -286,7 +287,7 @@ static auto PerformAggregateIndex(Context& context, SemIR::Inst inst)
 }
 
 static auto PerformBuiltinCall(Context& context, SemIRLocation loc,
-                               SemIR::Call call,
+                               SemIR::Call /*call*/,
                                SemIR::BuiltinFunctionKind builtin_kind,
                                llvm::ArrayRef<SemIR::InstId> arg_ids,
                                Phase phase) -> SemIR::ConstantId {
@@ -298,20 +299,12 @@ static auto PerformBuiltinCall(Context& context, SemIRLocation loc,
       if (phase != Phase::Template) {
         break;
       }
-      if (arg_ids.size() != 2) {
-        break;
-      }
-      auto lhs = context.insts().TryGetAs<SemIR::IntLiteral>(arg_ids[0]);
-      auto rhs = context.insts().TryGetAs<SemIR::IntLiteral>(arg_ids[1]);
-      // TODO: Move type checking to the point where we make the call.
-      if (!lhs || !rhs || lhs->type_id != rhs->type_id ||
-          call.type_id != lhs->type_id) {
-        break;
-      }
+      auto lhs = context.insts().GetAs<SemIR::IntLiteral>(arg_ids[0]);
+      auto rhs = context.insts().GetAs<SemIR::IntLiteral>(arg_ids[1]);
       // TODO: Integer values should be stored in the correct bit width for
       // their types. For now we assume i32.
-      auto lhs_val = context.ints().Get(lhs->int_id).sextOrTrunc(32);
-      auto rhs_val = context.ints().Get(rhs->int_id).sextOrTrunc(32);
+      auto lhs_val = context.ints().Get(lhs.int_id).sextOrTrunc(32);
+      auto rhs_val = context.ints().Get(rhs.int_id).sextOrTrunc(32);
       bool overflow = false;
       auto result = context.ints().Add(lhs_val.sadd_ov(rhs_val, overflow));
       if (overflow) {
@@ -322,28 +315,14 @@ static auto PerformBuiltinCall(Context& context, SemIRLocation loc,
                                llvm::APSInt(lhs_val, false),
                                llvm::APSInt(rhs_val, false));
       }
-      return MakeConstantResult(context,
-                                SemIR::IntLiteral{lhs->type_id, result}, phase);
+      return MakeConstantResult(context, SemIR::IntLiteral{lhs.type_id, result},
+                                phase);
     }
   }
 
   return SemIR::ConstantId::NotConstant;
 }
 
-// Extracts the callee function from a callee constant. Returns
-// FunctionId::Invalid if the callee is not known.
-static auto GetCalleeFunctionId(Context& context, SemIR::InstId callee_id)
-    -> SemIR::FunctionId {
-  if (auto bound_method =
-          context.insts().TryGetAs<SemIR::BoundMethod>(callee_id)) {
-    callee_id = bound_method->function_id;
-  }
-  if (auto callee = context.insts().TryGetAs<SemIR::FunctionDecl>(callee_id)) {
-    return {callee->function_id};
-  }
-  return {SemIR::FunctionId::Invalid};
-}
-
 static auto PerformCall(Context& context, SemIRLocation loc, SemIR::Call call)
     -> SemIR::ConstantId {
   Phase phase = Phase::Template;
@@ -362,11 +341,10 @@ static auto PerformCall(Context& context, SemIRLocation loc, SemIR::Call call)
     return SemIR::ConstantId::NotConstant;
   }
 
-  auto function_id = GetCalleeFunctionId(context, call.callee_id);
-
   // Handle calls to builtins.
-  auto& function = context.functions().Get(function_id);
-  if (function.builtin_kind != SemIR::BuiltinFunctionKind::None) {
+  if (auto builtin_function_kind = SemIR::BuiltinFunctionKind::ForCallee(
+          context.sem_ir(), call.callee_id);
+      builtin_function_kind != SemIR::BuiltinFunctionKind::None) {
     if (!ReplaceFieldWithConstantValue(context, &call, &SemIR::Call::args_id,
                                        &phase)) {
       return SemIR::ConstantId::NotConstant;
@@ -374,7 +352,7 @@ static auto PerformCall(Context& context, SemIRLocation loc, SemIR::Call call)
     if (phase == Phase::UnknownDueToError) {
       return SemIR::ConstantId::Error;
     }
-    return PerformBuiltinCall(context, loc, call, function.builtin_kind,
+    return PerformBuiltinCall(context, loc, call, builtin_function_kind,
                               context.inst_blocks().Get(call.args_id), phase);
   }
   return SemIR::ConstantId::NotConstant;

+ 46 - 8
toolchain/check/handle_function.cpp

@@ -10,6 +10,7 @@
 #include "toolchain/check/interface.h"
 #include "toolchain/check/modifiers.h"
 #include "toolchain/parse/tree_node_diagnostic_converter.h"
+#include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/entry_point.h"
 #include "toolchain/sem_ir/function.h"
 #include "toolchain/sem_ir/ids.h"
@@ -306,9 +307,7 @@ static auto LookupBuiltinFunctionKind(Context& context,
   auto builtin_name = context.string_literal_values().Get(
       context.tokens().GetStringLiteralValue(
           context.parse_tree().node_token(name_id)));
-  auto kind = llvm::StringSwitch<SemIR::BuiltinFunctionKind>(builtin_name)
-                  .Case("int.add", SemIR::BuiltinFunctionKind::IntAdd)
-                  .Default(SemIR::BuiltinFunctionKind::None);
+  auto kind = SemIR::BuiltinFunctionKind::ForBuiltinName(builtin_name);
   if (kind == SemIR::BuiltinFunctionKind::None) {
     CARBON_DIAGNOSTIC(UnknownBuiltinFunctionName, Error,
                       "Unknown builtin function name \"{0}\".", std::string);
@@ -318,17 +317,56 @@ static auto LookupBuiltinFunctionKind(Context& context,
   return kind;
 }
 
+// Returns whether `function` is a valid declaration of the builtin
+// `builtin_kind`.
+static auto IsValidBuiltinDeclaration(Context& context,
+                                      const SemIR::Function& function,
+                                      SemIR::BuiltinFunctionKind builtin_kind)
+    -> bool {
+  // Form the list of parameter types for the declaration.
+  llvm::SmallVector<SemIR::TypeId> param_type_ids;
+  auto implicit_param_refs =
+      context.inst_blocks().Get(function.implicit_param_refs_id);
+  auto param_refs = context.inst_blocks().Get(function.param_refs_id);
+  param_type_ids.reserve(implicit_param_refs.size() + param_refs.size());
+  for (auto param_id :
+       llvm::concat<SemIR::InstId>(implicit_param_refs, param_refs)) {
+    // TODO: We also need to track whether the parameter is declared with
+    // `var`.
+    param_type_ids.push_back(context.insts().Get(param_id).type_id());
+  }
+
+  // Get the return type. This is `()` if none was specified.
+  auto return_type_id = function.return_type_id;
+  if (!return_type_id.is_valid()) {
+    return_type_id = context.GetTupleType({});
+  }
+
+  return builtin_kind.IsValidType(context.sem_ir(), param_type_ids,
+                                  return_type_id);
+}
+
 auto HandleBuiltinFunctionDefinition(
     Context& context, Parse::BuiltinFunctionDefinitionId /*node_id*/) -> bool {
   auto name_id =
       context.node_stack().PopForSoloNodeId<Parse::NodeKind::BuiltinName>();
-  auto function_id =
+  auto [fn_node_id, function_id] =
       context.node_stack()
-          .Pop<Parse::NodeKind::BuiltinFunctionDefinitionStart>();
-
-  auto& function = context.functions().Get(function_id);
-  function.builtin_kind = LookupBuiltinFunctionKind(context, name_id);
+          .PopWithNodeId<Parse::NodeKind::BuiltinFunctionDefinitionStart>();
 
+  auto builtin_kind = LookupBuiltinFunctionKind(context, name_id);
+  if (builtin_kind != SemIR::BuiltinFunctionKind::None) {
+    auto& function = context.functions().Get(function_id);
+    if (IsValidBuiltinDeclaration(context, function, builtin_kind)) {
+      function.builtin_kind = builtin_kind;
+    } else {
+      CARBON_DIAGNOSTIC(InvalidBuiltinSignature, Error,
+                        "Invalid signature for builtin function \"{0}\".",
+                        std::string);
+      context.emitter().Emit(fn_node_id, InvalidBuiltinSignature,
+                             builtin_kind.name().str());
+    }
+  }
   context.decl_name_stack().PopScope();
   return true;
 }

+ 76 - 65
toolchain/check/testdata/builtins/int_add.carbon

@@ -18,8 +18,20 @@ fn RuntimeCall(a: i32, b: i32) -> i32 {
 
 package FailBadDecl api;
 
+// CHECK:STDERR: fail_bad_decl.carbon:[[@LINE+4]]:1: ERROR: Invalid signature for builtin function "int.add".
+// CHECK:STDERR: fn TooFew(a: i32) -> i32 = "int.add";
+// CHECK:STDERR: ^~~~~~~~~~~~~~~~~~~~~~~~~~
+// CHECK:STDERR:
 fn TooFew(a: i32) -> i32 = "int.add";
+// CHECK:STDERR: fail_bad_decl.carbon:[[@LINE+4]]:1: ERROR: Invalid signature for builtin function "int.add".
+// CHECK:STDERR: fn TooMany(a: i32, b: i32, c: i32) -> i32 = "int.add";
+// CHECK:STDERR: ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// CHECK:STDERR:
 fn TooMany(a: i32, b: i32, c: i32) -> i32 = "int.add";
+// CHECK:STDERR: fail_bad_decl.carbon:[[@LINE+4]]:1: ERROR: Invalid signature for builtin function "int.add".
+// CHECK:STDERR: fn BadReturnType(a: i32, b: i32) -> bool = "int.add";
+// CHECK:STDERR: ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// CHECK:STDERR:
 fn BadReturnType(a: i32, b: i32) -> bool = "int.add";
 fn JustRight(a: i32, b: i32) -> i32 = "int.add";
 
@@ -48,7 +60,6 @@ var bad_return_type: [i32; BadReturnType(1, 2)];
 // CHECK:STDERR:
 var bad_call: [i32; JustRight(1, 2, 3)];
 
-// TODO: We should diagnose these in check rather than failing in lower.
 fn RuntimeCallTooFew(a: i32) -> i32 {
   return TooFew(a);
 }
@@ -148,87 +159,87 @@ let b: i32 = Add(0x7FFFFFFF, 1);
 // CHECK:STDOUT:     .RuntimeCallBadReturnType = %RuntimeCallBadReturnType
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %TooFew: <function> = fn_decl @TooFew [template] {
-// CHECK:STDOUT:     %a.loc4_11.1: i32 = param a
-// CHECK:STDOUT:     @TooFew.%a: i32 = bind_name a, %a.loc4_11.1
-// CHECK:STDOUT:     %return.var.loc4: ref i32 = var <return slot>
+// CHECK:STDOUT:     %a.loc8_11.1: i32 = param a
+// CHECK:STDOUT:     @TooFew.%a: i32 = bind_name a, %a.loc8_11.1
+// CHECK:STDOUT:     %return.var.loc8: ref i32 = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %TooMany: <function> = fn_decl @TooMany [template] {
-// CHECK:STDOUT:     %a.loc5_12.1: i32 = param a
-// CHECK:STDOUT:     @TooMany.%a: i32 = bind_name a, %a.loc5_12.1
-// CHECK:STDOUT:     %b.loc5_20.1: i32 = param b
-// CHECK:STDOUT:     @TooMany.%b: i32 = bind_name b, %b.loc5_20.1
-// CHECK:STDOUT:     %c.loc5_28.1: i32 = param c
-// CHECK:STDOUT:     @TooMany.%c: i32 = bind_name c, %c.loc5_28.1
-// CHECK:STDOUT:     %return.var.loc5: ref i32 = var <return slot>
+// CHECK:STDOUT:     %a.loc13_12.1: i32 = param a
+// CHECK:STDOUT:     @TooMany.%a: i32 = bind_name a, %a.loc13_12.1
+// CHECK:STDOUT:     %b.loc13_20.1: i32 = param b
+// CHECK:STDOUT:     @TooMany.%b: i32 = bind_name b, %b.loc13_20.1
+// CHECK:STDOUT:     %c.loc13_28.1: i32 = param c
+// CHECK:STDOUT:     @TooMany.%c: i32 = bind_name c, %c.loc13_28.1
+// CHECK:STDOUT:     %return.var.loc13: ref i32 = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %BadReturnType: <function> = fn_decl @BadReturnType [template] {
-// CHECK:STDOUT:     %a.loc6_18.1: i32 = param a
-// CHECK:STDOUT:     @BadReturnType.%a: i32 = bind_name a, %a.loc6_18.1
-// CHECK:STDOUT:     %b.loc6_26.1: i32 = param b
-// CHECK:STDOUT:     @BadReturnType.%b: i32 = bind_name b, %b.loc6_26.1
-// CHECK:STDOUT:     %return.var.loc6: ref bool = var <return slot>
+// CHECK:STDOUT:     %a.loc18_18.1: i32 = param a
+// CHECK:STDOUT:     @BadReturnType.%a: i32 = bind_name a, %a.loc18_18.1
+// CHECK:STDOUT:     %b.loc18_26.1: i32 = param b
+// CHECK:STDOUT:     @BadReturnType.%b: i32 = bind_name b, %b.loc18_26.1
+// CHECK:STDOUT:     %return.var.loc18: ref bool = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %JustRight: <function> = fn_decl @JustRight [template] {
-// CHECK:STDOUT:     %a.loc7_14.1: i32 = param a
-// CHECK:STDOUT:     @JustRight.%a: i32 = bind_name a, %a.loc7_14.1
-// CHECK:STDOUT:     %b.loc7_22.1: i32 = param b
-// CHECK:STDOUT:     @JustRight.%b: i32 = bind_name b, %b.loc7_22.1
-// CHECK:STDOUT:     %return.var.loc7: ref i32 = var <return slot>
+// CHECK:STDOUT:     %a.loc19_14.1: i32 = param a
+// CHECK:STDOUT:     @JustRight.%a: i32 = bind_name a, %a.loc19_14.1
+// CHECK:STDOUT:     %b.loc19_22.1: i32 = param b
+// CHECK:STDOUT:     @JustRight.%b: i32 = bind_name b, %b.loc19_22.1
+// CHECK:STDOUT:     %return.var.loc19: ref i32 = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %TooFew.ref: <function> = name_ref TooFew, %TooFew [template = %TooFew]
-// CHECK:STDOUT:   %.loc13_27: i32 = int_literal 1 [template = constants.%.1]
-// CHECK:STDOUT:   %.loc13_26: init i32 = call %TooFew.ref(%.loc13_27)
+// CHECK:STDOUT:   %.loc25_27: i32 = int_literal 1 [template = constants.%.1]
+// CHECK:STDOUT:   %.loc25_26: init i32 = call %TooFew.ref(%.loc25_27)
 // CHECK:STDOUT:   %too_few.var: ref <error> = var too_few
 // CHECK:STDOUT:   %too_few: ref <error> = bind_name too_few, %too_few.var
 // CHECK:STDOUT:   %TooMany.ref: <function> = name_ref TooMany, %TooMany [template = %TooMany]
-// CHECK:STDOUT:   %.loc18_29: i32 = int_literal 1 [template = constants.%.1]
-// CHECK:STDOUT:   %.loc18_32: i32 = int_literal 2 [template = constants.%.2]
-// CHECK:STDOUT:   %.loc18_35: i32 = int_literal 3 [template = constants.%.3]
-// CHECK:STDOUT:   %.loc18_28: init i32 = call %TooMany.ref(%.loc18_29, %.loc18_32, %.loc18_35)
+// CHECK:STDOUT:   %.loc30_29: i32 = int_literal 1 [template = constants.%.1]
+// CHECK:STDOUT:   %.loc30_32: i32 = int_literal 2 [template = constants.%.2]
+// CHECK:STDOUT:   %.loc30_35: i32 = int_literal 3 [template = constants.%.3]
+// CHECK:STDOUT:   %.loc30_28: init i32 = call %TooMany.ref(%.loc30_29, %.loc30_32, %.loc30_35)
 // CHECK:STDOUT:   %too_many.var: ref <error> = var too_many
 // CHECK:STDOUT:   %too_many: ref <error> = bind_name too_many, %too_many.var
 // CHECK:STDOUT:   %BadReturnType.ref: <function> = name_ref BadReturnType, %BadReturnType [template = %BadReturnType]
-// CHECK:STDOUT:   %.loc23_42: i32 = int_literal 1 [template = constants.%.1]
-// CHECK:STDOUT:   %.loc23_45: i32 = int_literal 2 [template = constants.%.2]
-// CHECK:STDOUT:   %.loc23_41: init bool = call %BadReturnType.ref(%.loc23_42, %.loc23_45)
+// CHECK:STDOUT:   %.loc35_42: i32 = int_literal 1 [template = constants.%.1]
+// CHECK:STDOUT:   %.loc35_45: i32 = int_literal 2 [template = constants.%.2]
+// CHECK:STDOUT:   %.loc35_41: init bool = call %BadReturnType.ref(%.loc35_42, %.loc35_45)
 // CHECK:STDOUT:   %bad_return_type.var: ref <error> = var bad_return_type
 // CHECK:STDOUT:   %bad_return_type: ref <error> = bind_name bad_return_type, %bad_return_type.var
 // CHECK:STDOUT:   %JustRight.ref: <function> = name_ref JustRight, %JustRight [template = %JustRight]
-// CHECK:STDOUT:   %.loc32_31: i32 = int_literal 1 [template = constants.%.1]
-// CHECK:STDOUT:   %.loc32_34: i32 = int_literal 2 [template = constants.%.2]
-// CHECK:STDOUT:   %.loc32_37: i32 = int_literal 3 [template = constants.%.3]
-// CHECK:STDOUT:   %.loc32_30: init i32 = call %JustRight.ref(<invalid>) [template = <error>]
-// CHECK:STDOUT:   %.loc32_39: type = array_type %.loc32_30, i32 [template = <error>]
+// CHECK:STDOUT:   %.loc44_31: i32 = int_literal 1 [template = constants.%.1]
+// CHECK:STDOUT:   %.loc44_34: i32 = int_literal 2 [template = constants.%.2]
+// CHECK:STDOUT:   %.loc44_37: i32 = int_literal 3 [template = constants.%.3]
+// CHECK:STDOUT:   %.loc44_30: init i32 = call %JustRight.ref(<invalid>) [template = <error>]
+// CHECK:STDOUT:   %.loc44_39: type = array_type %.loc44_30, i32 [template = <error>]
 // CHECK:STDOUT:   %bad_call.var: ref <error> = var bad_call
 // CHECK:STDOUT:   %bad_call: ref <error> = bind_name bad_call, %bad_call.var
 // CHECK:STDOUT:   %RuntimeCallTooFew: <function> = fn_decl @RuntimeCallTooFew [template] {
-// CHECK:STDOUT:     %a.loc35_22.1: i32 = param a
-// CHECK:STDOUT:     @RuntimeCallTooFew.%a: i32 = bind_name a, %a.loc35_22.1
-// CHECK:STDOUT:     %return.var.loc35: ref i32 = var <return slot>
+// CHECK:STDOUT:     %a.loc46_22.1: i32 = param a
+// CHECK:STDOUT:     @RuntimeCallTooFew.%a: i32 = bind_name a, %a.loc46_22.1
+// CHECK:STDOUT:     %return.var.loc46: ref i32 = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %RuntimeCallTooMany: <function> = fn_decl @RuntimeCallTooMany [template] {
-// CHECK:STDOUT:     %a.loc39_23.1: i32 = param a
-// CHECK:STDOUT:     @RuntimeCallTooMany.%a: i32 = bind_name a, %a.loc39_23.1
-// CHECK:STDOUT:     %b.loc39_31.1: i32 = param b
-// CHECK:STDOUT:     @RuntimeCallTooMany.%b: i32 = bind_name b, %b.loc39_31.1
-// CHECK:STDOUT:     %c.loc39_39.1: i32 = param c
-// CHECK:STDOUT:     @RuntimeCallTooMany.%c: i32 = bind_name c, %c.loc39_39.1
-// CHECK:STDOUT:     %return.var.loc39: ref i32 = var <return slot>
+// CHECK:STDOUT:     %a.loc50_23.1: i32 = param a
+// CHECK:STDOUT:     @RuntimeCallTooMany.%a: i32 = bind_name a, %a.loc50_23.1
+// CHECK:STDOUT:     %b.loc50_31.1: i32 = param b
+// CHECK:STDOUT:     @RuntimeCallTooMany.%b: i32 = bind_name b, %b.loc50_31.1
+// CHECK:STDOUT:     %c.loc50_39.1: i32 = param c
+// CHECK:STDOUT:     @RuntimeCallTooMany.%c: i32 = bind_name c, %c.loc50_39.1
+// CHECK:STDOUT:     %return.var.loc50: ref i32 = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %RuntimeCallBadReturnType: <function> = fn_decl @RuntimeCallBadReturnType [template] {
-// CHECK:STDOUT:     %a.loc43_29.1: i32 = param a
-// CHECK:STDOUT:     @RuntimeCallBadReturnType.%a: i32 = bind_name a, %a.loc43_29.1
-// CHECK:STDOUT:     %b.loc43_37.1: i32 = param b
-// CHECK:STDOUT:     @RuntimeCallBadReturnType.%b: i32 = bind_name b, %b.loc43_37.1
-// CHECK:STDOUT:     %return.var.loc43: ref bool = var <return slot>
+// CHECK:STDOUT:     %a.loc54_29.1: i32 = param a
+// CHECK:STDOUT:     @RuntimeCallBadReturnType.%a: i32 = bind_name a, %a.loc54_29.1
+// CHECK:STDOUT:     %b.loc54_37.1: i32 = param b
+// CHECK:STDOUT:     @RuntimeCallBadReturnType.%b: i32 = bind_name b, %b.loc54_37.1
+// CHECK:STDOUT:     %return.var.loc54: ref bool = var <return slot>
 // CHECK:STDOUT:   }
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: fn @TooFew(%a: i32) -> i32 = "int.add";
+// CHECK:STDOUT: fn @TooFew(%a: i32) -> i32;
 // CHECK:STDOUT:
-// CHECK:STDOUT: fn @TooMany(%a: i32, %b: i32, %c: i32) -> i32 = "int.add";
+// CHECK:STDOUT: fn @TooMany(%a: i32, %b: i32, %c: i32) -> i32;
 // CHECK:STDOUT:
-// CHECK:STDOUT: fn @BadReturnType(%a: i32, %b: i32) -> bool = "int.add";
+// CHECK:STDOUT: fn @BadReturnType(%a: i32, %b: i32) -> bool;
 // CHECK:STDOUT:
 // CHECK:STDOUT: fn @JustRight(%a: i32, %b: i32) -> i32 = "int.add";
 // CHECK:STDOUT:
@@ -236,10 +247,10 @@ let b: i32 = Add(0x7FFFFFFF, 1);
 // CHECK:STDOUT: !entry:
 // CHECK:STDOUT:   %TooFew.ref: <function> = name_ref TooFew, file.%TooFew [template = file.%TooFew]
 // CHECK:STDOUT:   %a.ref: i32 = name_ref a, %a
-// CHECK:STDOUT:   %.loc36_16.1: init i32 = call %TooFew.ref(%a.ref)
-// CHECK:STDOUT:   %.loc36_19: i32 = value_of_initializer %.loc36_16.1
-// CHECK:STDOUT:   %.loc36_16.2: i32 = converted %.loc36_16.1, %.loc36_19
-// CHECK:STDOUT:   return %.loc36_16.2
+// CHECK:STDOUT:   %.loc47_16.1: init i32 = call %TooFew.ref(%a.ref)
+// CHECK:STDOUT:   %.loc47_19: i32 = value_of_initializer %.loc47_16.1
+// CHECK:STDOUT:   %.loc47_16.2: i32 = converted %.loc47_16.1, %.loc47_19
+// CHECK:STDOUT:   return %.loc47_16.2
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: fn @RuntimeCallTooMany(%a: i32, %b: i32, %c: i32) -> i32 {
@@ -248,10 +259,10 @@ let b: i32 = Add(0x7FFFFFFF, 1);
 // CHECK:STDOUT:   %a.ref: i32 = name_ref a, %a
 // CHECK:STDOUT:   %b.ref: i32 = name_ref b, %b
 // CHECK:STDOUT:   %c.ref: i32 = name_ref c, %c
-// CHECK:STDOUT:   %.loc40_17.1: init i32 = call %TooMany.ref(%a.ref, %b.ref, %c.ref)
-// CHECK:STDOUT:   %.loc40_26: i32 = value_of_initializer %.loc40_17.1
-// CHECK:STDOUT:   %.loc40_17.2: i32 = converted %.loc40_17.1, %.loc40_26
-// CHECK:STDOUT:   return %.loc40_17.2
+// CHECK:STDOUT:   %.loc51_17.1: init i32 = call %TooMany.ref(%a.ref, %b.ref, %c.ref)
+// CHECK:STDOUT:   %.loc51_26: i32 = value_of_initializer %.loc51_17.1
+// CHECK:STDOUT:   %.loc51_17.2: i32 = converted %.loc51_17.1, %.loc51_26
+// CHECK:STDOUT:   return %.loc51_17.2
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: fn @RuntimeCallBadReturnType(%a: i32, %b: i32) -> bool {
@@ -259,10 +270,10 @@ let b: i32 = Add(0x7FFFFFFF, 1);
 // CHECK:STDOUT:   %BadReturnType.ref: <function> = name_ref BadReturnType, file.%BadReturnType [template = file.%BadReturnType]
 // CHECK:STDOUT:   %a.ref: i32 = name_ref a, %a
 // CHECK:STDOUT:   %b.ref: i32 = name_ref b, %b
-// CHECK:STDOUT:   %.loc44_23.1: init bool = call %BadReturnType.ref(%a.ref, %b.ref)
-// CHECK:STDOUT:   %.loc44_29: bool = value_of_initializer %.loc44_23.1
-// CHECK:STDOUT:   %.loc44_23.2: bool = converted %.loc44_23.1, %.loc44_29
-// CHECK:STDOUT:   return %.loc44_23.2
+// CHECK:STDOUT:   %.loc55_23.1: init bool = call %BadReturnType.ref(%a.ref, %b.ref)
+// CHECK:STDOUT:   %.loc55_29: bool = value_of_initializer %.loc55_23.1
+// CHECK:STDOUT:   %.loc55_23.2: bool = converted %.loc55_23.1, %.loc55_29
+// CHECK:STDOUT:   return %.loc55_23.2
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: --- fail_overflow.carbon

+ 1 - 0
toolchain/diagnostics/diagnostic_kind.def

@@ -172,6 +172,7 @@ CARBON_DIAGNOSTIC_KIND(FunctionRedeclReturnTypePreviousNoReturn)
 CARBON_DIAGNOSTIC_KIND(InvalidMainRunSignature)
 CARBON_DIAGNOSTIC_KIND(MissingReturnStatement)
 CARBON_DIAGNOSTIC_KIND(UnknownBuiltinFunctionName)
+CARBON_DIAGNOSTIC_KIND(InvalidBuiltinSignature)
 
 // Class checking.
 CARBON_DIAGNOSTIC_KIND(BaseIsFinal)

+ 2 - 33
toolchain/lower/handle.cpp

@@ -11,6 +11,7 @@
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "toolchain/lower/function_context.h"
+#include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/function.h"
 #include "toolchain/sem_ir/inst.h"
 #include "toolchain/sem_ir/typed_insts.h"
@@ -167,26 +168,6 @@ auto HandleBuiltin(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
   CARBON_FATAL() << "TODO: Add support: " << inst;
 }
 
-// Returns the builtin function kind of the callee in a function call, or None
-// if the call is not to a builtin.
-static auto GetCalleeBuiltinFunctionKind(const SemIR::File& sem_ir,
-                                         SemIR::InstId callee_id)
-    -> SemIR::BuiltinFunctionKind {
-  if (auto bound_method =
-          sem_ir.insts().TryGetAs<SemIR::BoundMethod>(callee_id)) {
-    callee_id = bound_method->function_id;
-  }
-  callee_id = sem_ir.constant_values().Get(callee_id).inst_id();
-  if (!callee_id.is_valid()) {
-    return SemIR::BuiltinFunctionKind::None;
-  }
-  if (auto callee = sem_ir.insts().TryGetAs<SemIR::FunctionDecl>(callee_id)) {
-    const auto& function = sem_ir.functions().Get(callee->function_id);
-    return function.builtin_kind;
-  }
-  return SemIR::BuiltinFunctionKind::None;
-}
-
 // Handles a call to a builtin function.
 static auto HandleBuiltinCall(FunctionContext& context, SemIR::InstId inst_id,
                               SemIR::BuiltinFunctionKind builtin_kind,
@@ -196,18 +177,6 @@ static auto HandleBuiltinCall(FunctionContext& context, SemIR::InstId inst_id,
       CARBON_FATAL() << "No callee in function call.";
 
     case SemIR::BuiltinFunctionKind::IntAdd: {
-      // TODO: Move type checking to the point where we make the call.
-      if (arg_ids.size() != 2) {
-        break;
-      }
-      auto lhs_type = context.sem_ir().insts().Get(arg_ids[0]).type_id();
-      auto rhs_type = context.sem_ir().insts().Get(arg_ids[1]).type_id();
-      auto result_type = context.sem_ir().insts().Get(inst_id).type_id();
-      if (lhs_type != rhs_type || lhs_type != result_type ||
-          context.sem_ir().types().GetInstId(lhs_type) !=
-              SemIR::InstId::BuiltinIntType) {
-        break;
-      }
       constexpr bool SignedOverflowIsUB = false;
       context.SetLocal(inst_id, context.builder().CreateAdd(
                                     context.GetValue(arg_ids[0]),
@@ -231,7 +200,7 @@ auto HandleCall(FunctionContext& context, SemIR::InstId inst_id,
   // A null callee pointer value indicates this isn't a real function.
   if (!callee_value) {
     auto builtin_kind =
-        GetCalleeBuiltinFunctionKind(context.sem_ir(), inst.callee_id);
+        SemIR::BuiltinFunctionKind::ForCallee(context.sem_ir(), inst.callee_id);
     HandleBuiltinCall(context, inst_id, builtin_kind, arg_ids);
     return;
   }

+ 6 - 0
toolchain/sem_ir/BUILD

@@ -77,6 +77,7 @@ cc_library(
 cc_library(
     name = "file",
     srcs = [
+        "builtin_function_kind.cpp",
         "constant.cpp",
         "file.cpp",
         "inst_profile.cpp",
@@ -84,6 +85,7 @@ cc_library(
         "name.cpp",
     ],
     hdrs = [
+        "builtin_function_kind.h",
         "class.h",
         "constant.h",
         "copy_on_write_block.h",
@@ -95,6 +97,9 @@ cc_library(
         "name_scope.h",
         "type.h",
     ],
+    textual_hdrs = [
+        "builtin_function_kind.def",
+    ],
     deps = [
         ":block_value_store",
         ":builtin_kind",
@@ -103,6 +108,7 @@ cc_library(
         ":inst_kind",
         ":type_info",
         "//common:check",
+        "//common:enum_base",
         "//common:error",
         "//toolchain/base:value_store",
         "//toolchain/base:yaml",

+ 179 - 0
toolchain/sem_ir/builtin_function_kind.cpp

@@ -0,0 +1,179 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "toolchain/sem_ir/builtin_function_kind.h"
+
+#include <utility>
+
+#include "toolchain/sem_ir/file.h"
+#include "toolchain/sem_ir/ids.h"
+#include "toolchain/sem_ir/typed_insts.h"
+
+namespace Carbon::SemIR {
+
+// A function that validates that a builtin was declared properly.
+using ValidateFn = auto(const File& sem_ir, llvm::ArrayRef<TypeId> arg_types,
+                        TypeId return_type) -> bool;
+
+namespace {
+// Information about a builtin function.
+struct BuiltinInfo {
+  llvm::StringLiteral name;
+  ValidateFn* validate;
+};
+
+// The maximum number of type parameters any builtin needs.
+constexpr int MaxTypeParams = 1;
+
+// State used when validating a builtin signature that persists between
+// individual checks.
+struct ValidateState {
+  // The type values of type parameters in the builtin signature. Invalid if
+  // either no value has been deduced yet or the parameter is not used.
+  TypeId type_params[MaxTypeParams] = {TypeId::Invalid};
+};
+
+// Constraint that a type is generic type parameter `I` of the builtin,
+// satisfying `TypeConstraint`. See ValidateSignature for details.
+template <int I, typename TypeConstraint>
+struct TypeParam {
+  static_assert(I >= 0 && I < MaxTypeParams);
+
+  static auto Check(const File& sem_ir, ValidateState& state, TypeId type_id)
+      -> bool {
+    if (state.type_params[I].is_valid() && type_id != state.type_params[I]) {
+      return false;
+    }
+    state.type_params[I] = type_id;
+    return TypeConstraint::Check(sem_ir, state, type_id);
+  }
+};
+
+// Constraint that requires the type to be an integer type. See
+// ValidateSignature for details.
+struct AnyInt {
+  static auto Check(const File& sem_ir, ValidateState& /*state*/,
+                    TypeId type_id) -> bool {
+    if (sem_ir.types().GetInstId(type_id) == InstId::BuiltinIntType) {
+      return true;
+    }
+    // TODO: Support iN for all N, and the Core.BigInt type we use to implement
+    // for integer literals.
+    return false;
+  }
+};
+}  // namespace
+
+// Validates that this builtin has a signature matching the specified signature.
+//
+// `SignatureFnType` is a C++ function type that describes the signature that is
+// expected for this builtin. For example, `auto (AnyInt, AnyInt) -> AnyInt`
+// specifies that the builtin takes values of two integer types and returns a
+// value of a third integer type. Types used within the signature should provide
+// a `Check` function that validates that the Carbon type is expected:
+//
+//   auto Check(const File&, ValidateState&, TypeId) -> bool;
+//
+// To constrain that the same type is used in multiple places in the signature,
+// `TypeParam<I, T>` can be used. For example:
+//
+//   auto (TypeParam<0, AnyInt>, AnyInt) -> TypeParam<0, AnyInt>
+//
+// describes a builtin that takes two integers, and whose return type matches
+// its first parameter type. For convenience, typedefs for `TypeParam<I, T>`
+// are used in the descriptions of the builtins.
+template <typename SignatureFnType>
+static auto ValidateSignature(const File& sem_ir,
+                              llvm::ArrayRef<TypeId> arg_types,
+                              TypeId return_type) -> bool {
+  using SignatureTraits = llvm::function_traits<SignatureFnType*>;
+  ValidateState state;
+
+  // Must have expected number of arguments.
+  if (arg_types.size() != SignatureTraits::num_args) {
+    return false;
+  }
+
+  // Argument types must match.
+  if (![&]<std::size_t... Indexes>(std::index_sequence<Indexes...>) {
+        return ((SignatureTraits::template arg_t<Indexes>::Check(
+                    sem_ir, state, arg_types[Indexes])) &&
+                ...);
+      }(std::make_index_sequence<SignatureTraits::num_args>())) {
+    return false;
+  }
+
+  // Result type must match.
+  if (!SignatureTraits::result_t::Check(sem_ir, state, return_type)) {
+    return false;
+  }
+
+  return true;
+}
+
+// Descriptions of builtin functions follow. For each builtin, a corresponding
+// `BuiltinInfo` constant is declared describing properties of that builtin.
+namespace BuiltinFunctionInfo {
+
+// Convenience name used in the builtin type signatures below for a first
+// generic type parameter that is constrained to be an integer type.
+using IntT = TypeParam<0, AnyInt>;
+
+// Not a builtin function.
+constexpr BuiltinInfo None = {"", nullptr};
+
+// "int.add": integer addition.
+constexpr BuiltinInfo IntAdd = {"int.add",
+                                ValidateSignature<auto(IntT, IntT)->IntT>};
+
+}  // namespace BuiltinFunctionInfo
+
+CARBON_DEFINE_ENUM_CLASS_NAMES(BuiltinFunctionKind) = {
+#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
+  BuiltinFunctionInfo::Name.name,
+#include "toolchain/sem_ir/builtin_function_kind.def"
+};
+
+// Returns the builtin function kind with the given name, or None if the name
+// is unknown.
+auto BuiltinFunctionKind::ForBuiltinName(llvm::StringRef name)
+    -> BuiltinFunctionKind {
+#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
+  if (name == BuiltinFunctionInfo::Name.name) {   \
+    return BuiltinFunctionKind::Name;             \
+  }
+#include "toolchain/sem_ir/builtin_function_kind.def"
+  return BuiltinFunctionKind::None;
+}
+
+// Returns the builtin function kind corresponding to the given function
+// callee, or None if the callee is not known to be a builtin.
+auto BuiltinFunctionKind::ForCallee(const File& sem_ir, InstId callee_id)
+    -> BuiltinFunctionKind {
+  if (auto bound_method =
+          sem_ir.insts().TryGetAs<SemIR::BoundMethod>(callee_id)) {
+    callee_id = bound_method->function_id;
+  }
+  callee_id = sem_ir.constant_values().Get(callee_id).inst_id();
+  if (!callee_id.is_valid()) {
+    return SemIR::BuiltinFunctionKind::None;
+  }
+  if (auto callee = sem_ir.insts().TryGetAs<SemIR::FunctionDecl>(callee_id)) {
+    return sem_ir.functions().Get(callee->function_id).builtin_kind;
+  }
+  return SemIR::BuiltinFunctionKind::None;
+}
+
+auto BuiltinFunctionKind::IsValidType(const File& sem_ir,
+                                      llvm::ArrayRef<TypeId> arg_types,
+                                      TypeId return_type) const -> bool {
+  static constexpr ValidateFn* ValidateFns[] = {
+#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
+  BuiltinFunctionInfo::Name.validate,
+#include "toolchain/sem_ir/builtin_function_kind.def"
+  };
+  return ValidateFns[AsInt()](sem_ir, arg_types, return_type);
+}
+
+}  // namespace Carbon::SemIR

+ 22 - 0
toolchain/sem_ir/builtin_function_kind.def

@@ -0,0 +1,22 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// This is an X-macro header. It does not use `#include` guards, and instead is
+// designed to be `#include`ed after the x-macro is defined in order for its
+// inclusion to expand to the desired output. Macro definitions are cleaned up
+// at the end of this file.
+//
+// Supported x-macro is:
+// - CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name)
+//   Defines a builtin function type.
+
+#if !defined(CARBON_SEM_IR_BUILTIN_FUNCTION_KIND)
+#error \
+    "Must define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND x-macro to use this file."
+#endif
+
+CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(None)
+CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(IntAdd)
+
+#undef CARBON_SEM_IR_BUILTIN_FUNCTION_KIND

+ 51 - 0
toolchain/sem_ir/builtin_function_kind.h

@@ -0,0 +1,51 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_SEM_IR_BUILTIN_FUNCTION_KIND_H_
+#define CARBON_TOOLCHAIN_SEM_IR_BUILTIN_FUNCTION_KIND_H_
+
+#include <cstdint>
+
+#include "common/enum_base.h"
+#include "toolchain/sem_ir/ids.h"
+
+namespace Carbon::SemIR {
+
+class File;
+
+CARBON_DEFINE_RAW_ENUM_CLASS(BuiltinFunctionKind, std::uint8_t) {
+#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
+  CARBON_RAW_ENUM_ENUMERATOR(Name)
+#include "toolchain/sem_ir/builtin_function_kind.def"
+};
+
+// A kind of builtin function.
+class BuiltinFunctionKind : public CARBON_ENUM_BASE(BuiltinFunctionKind) {
+ public:
+#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
+  CARBON_ENUM_CONSTANT_DECL(Name)
+#include "toolchain/sem_ir/builtin_function_kind.def"
+
+  // Returns the builtin function kind with the given name, or None if the name
+  // is unknown.
+  static auto ForBuiltinName(llvm::StringRef name) -> BuiltinFunctionKind;
+
+  // Returns the builtin function kind corresponding to the given function
+  // callee, or None if the callee is not known to be a builtin.
+  static auto ForCallee(const File& sem_ir, InstId callee_id)
+      -> BuiltinFunctionKind;
+
+  // Determines whether this builtin function kind can have the specified
+  // function type.
+  auto IsValidType(const File& sem_ir, llvm::ArrayRef<TypeId> arg_types,
+                   TypeId return_type) const -> bool;
+};
+
+#define CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(Name) \
+  CARBON_ENUM_CONSTANT_DEFINITION(BuiltinFunctionKind, Name)
+#include "toolchain/sem_ir/builtin_function_kind.def"
+
+}  // namespace Carbon::SemIR
+
+#endif  // CARBON_TOOLCHAIN_SEM_IR_BUILTIN_FUNCTION_KIND_H_

+ 6 - 7
toolchain/sem_ir/formatter.cpp

@@ -12,6 +12,7 @@
 #include "toolchain/base/value_store.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/function.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/typed_insts.h"
@@ -796,13 +797,11 @@ class Formatter {
       FormatType(fn.return_type_id);
     }
 
-    // TODO: Move this conversion of kind to string elsewhere.
-    switch (fn.builtin_kind) {
-      case BuiltinFunctionKind::None:
-        break;
-      case BuiltinFunctionKind::IntAdd:
-        out_ << " = \"int.add\"";
-        break;
+    if (fn.builtin_kind != BuiltinFunctionKind::None) {
+      out_ << " = \"";
+      out_.write_escaped(fn.builtin_kind.name(),
+                         /*UseHexEscapes=*/true);
+      out_ << "\"";
     }
 
     if (!fn.body_block_ids.empty()) {

+ 1 - 9
toolchain/sem_ir/function.h

@@ -5,20 +5,12 @@
 #ifndef CARBON_TOOLCHAIN_SEM_IR_FUNCTION_H_
 #define CARBON_TOOLCHAIN_SEM_IR_FUNCTION_H_
 
+#include "toolchain/sem_ir/builtin_function_kind.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/typed_insts.h"
 
 namespace Carbon::SemIR {
 
-// A builtin function.
-// TODO: Move out to another file.
-enum class BuiltinFunctionKind : std::uint8_t {
-  // Not a builtin function.
-  None,
-  // "int.add", integer addition.
-  IntAdd,
-};
-
 // A function.
 struct Function : public Printable<Function> {
   auto Print(llvm::raw_ostream& out) const -> void {