Преглед на файлове

Remove ArgKinds to encourage safer coding patterns (#5212)

#5171 ran into an issue where the wrong kind was associated with an arg
(`auto arg1 = RefineOperand(context, loc_id, arg0_kind,
action.arg1());`). This PR is trying to reduce risk of similar errors by
replaced `ArgKinds()` with instead an `ArgAndKind` structure and
corresponding accessors.

A couple things I considered and discarded were:

- Adding `CARBON_KIND_SWITCH` support (in this PR -- see #5216).
- The particular way that `ForCase` works would need to change, and I
was hesitant to do that here.
- But this is why I did add `As` to `ArgAndKind`, because it had me
thinking in that direction.
- Trying to make wrapper functions like `MutateArgs(callback_fn);`. This
kind of approach gets a little messy due to some of the conditional
passes, and in particular the reverse-iteration done for `PopOperand` in
subst.cpp
- Making something like `args_and_kinds() -> std::array<ArgAndKind, 2>`.
There's one spot where iteration is already set up as a loop, but for
others it felt a little convoluted with less gain than
`MutateArgs`-style things.

I'm not sure if there's a better way to set up the table generators, I
might keep tinkering with those for ideas.
Jon Ross-Perkins преди 1 година
родител
ревизия
4cb61ae4e1

+ 28 - 28
toolchain/check/action.cpp

@@ -44,21 +44,24 @@ auto OperandIsDependent(Context& context, SemIR::MetaInstId inst_id) -> bool {
          OperandIsDependent(context, context.constant_values().Get(inst_id));
 }
 
-static auto OperandIsDependent(Context& context, SemIR::IdKind kind,
-                               int32_t arg) -> bool {
-  if (kind == SemIR::IdKind::For<SemIR::MetaInstId>) {
-    return OperandIsDependent(context, SemIR::MetaInstId(arg));
-  }
-  if (kind == SemIR::IdKind::For<SemIR::TypeId>) {
-    return OperandIsDependent(context, SemIR::TypeId(arg));
-  }
-  if (kind == SemIR::IdKind::None ||
-      kind == SemIR::IdKind::For<SemIR::AbsoluteInstId> ||
-      kind == SemIR::IdKind::For<SemIR::NameId>) {
-    return false;
+static auto OperandIsDependent(Context& context, SemIR::Inst::ArgAndKind arg)
+    -> bool {
+  switch (arg.kind) {
+    case SemIR::IdKind::For<SemIR::MetaInstId>:
+      return OperandIsDependent(context, arg.As<SemIR::MetaInstId>());
+
+    case SemIR::IdKind::For<SemIR::TypeId>:
+      return OperandIsDependent(context, arg.As<SemIR::TypeId>());
+
+    case SemIR::IdKind::None:
+    case SemIR::IdKind::For<SemIR::AbsoluteInstId>:
+    case SemIR::IdKind::For<SemIR::NameId>:
+      return false;
+
+    default:
+      // TODO: Properly handle different argument kinds.
+      CARBON_FATAL("Unexpected argument kind for action");
   }
-  // TODO: Properly handle different argument kinds.
-  CARBON_FATAL("Unexpected argument kind for action");
 }
 
 auto ActionIsDependent(Context& context, SemIR::Inst action_inst) -> bool {
@@ -71,9 +74,8 @@ auto ActionIsDependent(Context& context, SemIR::Inst action_inst) -> bool {
   if (OperandIsDependent(context, action_inst.type_id())) {
     return true;
   }
-  auto [arg0_kind, arg1_kind] = action_inst.ArgKinds();
-  return OperandIsDependent(context, arg0_kind, action_inst.arg0()) ||
-         OperandIsDependent(context, arg1_kind, action_inst.arg1());
+  return OperandIsDependent(context, action_inst.arg0_and_kind()) ||
+         OperandIsDependent(context, action_inst.arg1_and_kind());
 }
 
 static auto AddDependentActionSpliceImpl(Context& context,
@@ -98,14 +100,13 @@ static auto AddDependentActionSpliceImpl(Context& context,
 // their concrete values, so that the action doesn't need to know which specific
 // it is operating on.
 static auto RefineOperand(Context& context, SemIR::LocId loc_id,
-                          SemIR::IdKind kind, int32_t arg) -> int32_t {
-  if (kind == SemIR::IdKind::For<SemIR::MetaInstId>) {
-    auto inst_id = SemIR::MetaInstId(arg);
-    auto inst = context.insts().Get(inst_id);
+                          SemIR::Inst::ArgAndKind arg) -> int32_t {
+  if (auto inst_id = arg.TryAs<SemIR::MetaInstId>()) {
+    auto inst = context.insts().Get(*inst_id);
     if (inst.Is<SemIR::SpliceInst>()) {
       // The argument will evaluate to the spliced instruction, which is already
       // refined.
-      return arg;
+      return arg.value;
     }
 
     // If the type of the action argument is dependent, refine to an instruction
@@ -116,7 +117,7 @@ static auto RefineOperand(Context& context, SemIR::LocId loc_id,
           SemIR::LocIdAndInst(loc_id,
                               SemIR::RefineTypeAction{
                                   .type_id = SemIR::InstType::SingletonTypeId,
-                                  .inst_id = inst_id,
+                                  .inst_id = *inst_id,
                                   .inst_type_id = inst.type_id()}),
           inst.type_id());
     }
@@ -124,19 +125,18 @@ static auto RefineOperand(Context& context, SemIR::LocId loc_id,
     // TODO: Handle the case where the constant value of the instruction is
     // template-dependent.
 
-    return inst_id.index;
+    return inst_id->index;
   }
 
-  return arg;
+  return arg.value;
 }
 
 // Refine the operands of an action, ensuring that they will refer to concrete
 // instructions that don't have template-dependent types.
 static auto RefineOperands(Context& context, SemIR::LocId loc_id,
                            SemIR::Inst action) -> SemIR::Inst {
-  auto [arg0_kind, arg1_kind] = action.ArgKinds();
-  auto arg0 = RefineOperand(context, loc_id, arg0_kind, action.arg0());
-  auto arg1 = RefineOperand(context, loc_id, arg1_kind, action.arg1());
+  auto arg0 = RefineOperand(context, loc_id, action.arg0_and_kind());
+  auto arg1 = RefineOperand(context, loc_id, action.arg1_and_kind());
   action.SetArgs(arg0, arg1);
   return action;
 }

+ 12 - 13
toolchain/check/deduce.cpp

@@ -138,37 +138,37 @@ class DeductionWorklist {
   }
 
   // Adds a (param, arg) pair for an instruction argument, given its kind.
-  auto AddInstArg(SemIR::IdKind kind, int32_t param, int32_t arg,
+  auto AddInstArg(SemIR::Inst::ArgAndKind param, int32_t arg,
                   bool needs_substitution) -> void {
-    switch (kind) {
+    switch (param.kind) {
       case SemIR::IdKind::None:
       case SemIR::IdKind::For<SemIR::ClassId>:
       case SemIR::IdKind::For<SemIR::IntKind>:
         break;
       case SemIR::IdKind::For<SemIR::InstId>:
-        Add(SemIR::InstId(param), SemIR::InstId(arg), needs_substitution);
+        Add(param.As<SemIR::InstId>(), SemIR::InstId(arg), needs_substitution);
         break;
       case SemIR::IdKind::For<SemIR::TypeId>:
-        Add(SemIR::TypeId(param), SemIR::TypeId(arg), needs_substitution);
+        Add(param.As<SemIR::TypeId>(), SemIR::TypeId(arg), needs_substitution);
         break;
       case SemIR::IdKind::For<SemIR::StructTypeFieldsId>:
-        AddAll(SemIR::StructTypeFieldsId(param), SemIR::StructTypeFieldsId(arg),
-               needs_substitution);
+        AddAll(param.As<SemIR::StructTypeFieldsId>(),
+               SemIR::StructTypeFieldsId(arg), needs_substitution);
         break;
       case SemIR::IdKind::For<SemIR::InstBlockId>:
-        AddAll(SemIR::InstBlockId(param), SemIR::InstBlockId(arg),
+        AddAll(param.As<SemIR::InstBlockId>(), SemIR::InstBlockId(arg),
                needs_substitution);
         break;
       case SemIR::IdKind::For<SemIR::TypeBlockId>:
-        AddAll(SemIR::TypeBlockId(param), SemIR::TypeBlockId(arg),
+        AddAll(param.As<SemIR::TypeBlockId>(), SemIR::TypeBlockId(arg),
                needs_substitution);
         break;
       case SemIR::IdKind::For<SemIR::SpecificId>:
-        Add(SemIR::SpecificId(param), SemIR::SpecificId(arg),
+        Add(param.As<SemIR::SpecificId>(), SemIR::SpecificId(arg),
             needs_substitution);
         break;
       case SemIR::IdKind::For<SemIR::FacetTypeId>:
-        AddAll(SemIR::FacetTypeId(param), SemIR::FacetTypeId(arg),
+        AddAll(param.As<SemIR::FacetTypeId>(), SemIR::FacetTypeId(arg),
                needs_substitution);
         break;
       default:
@@ -469,10 +469,9 @@ auto DeductionContext::Deduce() -> bool {
           if (arg_inst.kind() != param_inst.kind()) {
             break;
           }
-          auto [kind0, kind1] = param_inst.ArgKinds();
-          worklist_.AddInstArg(kind0, param_inst.arg0(), arg_inst.arg0(),
+          worklist_.AddInstArg(param_inst.arg0_and_kind(), arg_inst.arg0(),
                                needs_substitution);
-          worklist_.AddInstArg(kind1, param_inst.arg1(), arg_inst.arg1(),
+          worklist_.AddInstArg(param_inst.arg1_and_kind(), arg_inst.arg1(),
                                needs_substitution);
           continue;
         }

+ 44 - 31
toolchain/check/eval.cpp

@@ -640,36 +640,52 @@ static constexpr bool HasGetConstantValueOverload = requires {
   Accept<auto (*)(EvalContext&, IdT, Phase*)->IdT>(GetConstantValue);
 };
 
+using ArgHandlerFnT = auto(EvalContext& context, int32_t arg, Phase* phase)
+    -> int32_t;
+
+// Returns a lookup table to get constants by Id::Kind. Requires a null IdKind
+// as a parameter in order to get the type pack.
+template <typename... Types>
+static constexpr auto MakeArgHandlerTable(
+    SemIR::TypeEnum<Types...>* /*id_kind*/)
+    -> std::array<ArgHandlerFnT*, SemIR::IdKind::NumValues> {
+  std::array<ArgHandlerFnT*, SemIR::IdKind::NumValues> table = {};
+  ((table[SemIR::IdKind::template For<Types>.ToIndex()] =
+        [](EvalContext& eval_context, int32_t arg, Phase* phase) -> int32_t {
+     auto id = SemIR::Inst::FromRaw<Types>(arg);
+     if constexpr (HasGetConstantValueOverload<Types>) {
+       // If we have a custom `GetConstantValue` overload, call it.
+       return SemIR::Inst::ToRaw(GetConstantValue(eval_context, id, phase));
+     } else {
+       // Otherwise, we assume the value is already constant.
+       return arg;
+     }
+   }),
+   ...);
+  table[SemIR::IdKind::Invalid.ToIndex()] = [](EvalContext& /*context*/,
+                                               int32_t /*arg*/,
+                                               Phase* /*phase*/) -> int32_t {
+    CARBON_FATAL("Instruction has argument with invalid IdKind");
+  };
+  table[SemIR::IdKind::None.ToIndex()] =
+      [](EvalContext& /*context*/, int32_t arg, Phase* /*phase*/) -> int32_t {
+    return arg;
+  };
+  return table;
+}
+
 // Given the stored value `arg` of an instruction field and its corresponding
 // kind `kind`, returns the constant value to use for that field, if it has a
 // constant phase. `*phase` is updated to include the new constant value. If
 // the resulting phase is not constant, the returned value is not useful and
 // will typically be `NoneIndex`.
-template <typename... Type>
 static auto GetConstantValueForArg(EvalContext& eval_context,
-                                   SemIR::TypeEnum<Type...> kind, int32_t arg,
+                                   SemIR::Inst::ArgAndKind arg_and_kind,
                                    Phase* phase) -> int32_t {
-  using Handler = auto(EvalContext&, int32_t arg, Phase * phase)->int32_t;
-  static constexpr Handler* Handlers[] = {
-      [](EvalContext& eval_context, int32_t arg, Phase* phase) -> int32_t {
-        auto id = SemIR::Inst::FromRaw<Type>(arg);
-        if constexpr (HasGetConstantValueOverload<Type>) {
-          // If we have a custom `GetConstantValue` overload, call it.
-          return SemIR::Inst::ToRaw(GetConstantValue(eval_context, id, phase));
-        } else {
-          // Otherwise, we assume the value is already constant.
-          return arg;
-        }
-      }...,
-      [](EvalContext&, int32_t, Phase*) -> int32_t {
-        // Handler for IdKind::Invalid is next.
-        CARBON_FATAL("Instruction has argument with invalid IdKind");
-      },
-      [](EvalContext&, int32_t arg, Phase*) -> int32_t {
-        // Handler for IdKind::None is last.
-        return arg;
-      }};
-  return Handlers[kind.ToIndex()](eval_context, arg, phase);
+  static constexpr auto Table =
+      MakeArgHandlerTable(static_cast<SemIR::IdKind*>(nullptr));
+  return Table[arg_and_kind.kind.ToIndex()](eval_context, arg_and_kind.value,
+                                            phase);
 }
 
 // Given an instruction, replaces its type and operands with their constant
@@ -680,22 +696,20 @@ static auto ReplaceAllFieldsWithConstantValues(EvalContext& eval_context,
                                                SemIR::Inst* inst, Phase* phase)
     -> bool {
   auto type_id = SemIR::TypeId(
-      GetConstantValueForArg(eval_context, SemIR::IdKind::For<SemIR::TypeId>,
-                             inst->type_id().index, phase));
+      GetConstantValueForArg(eval_context, inst->type_id_and_kind(), phase));
   inst->SetType(type_id);
   if (!IsConstant(*phase)) {
     return false;
   }
 
-  auto kinds = inst->ArgKinds();
   auto arg0 =
-      GetConstantValueForArg(eval_context, kinds.first, inst->arg0(), phase);
+      GetConstantValueForArg(eval_context, inst->arg0_and_kind(), phase);
   if (!IsConstant(*phase)) {
     return false;
   }
 
   auto arg1 =
-      GetConstantValueForArg(eval_context, kinds.second, inst->arg1(), phase);
+      GetConstantValueForArg(eval_context, inst->arg1_and_kind(), phase);
   if (!IsConstant(*phase)) {
     return false;
   }
@@ -1601,9 +1615,8 @@ static auto ComputeInstPhase(Context& context, SemIR::Inst inst) -> Phase {
 
   auto phase = GetPhase(context.constant_values(),
                         context.types().GetConstantId(inst.type_id()));
-  auto kinds = inst.ArgKinds();
-  GetConstantValueForArg(eval_context, kinds.first, inst.arg0(), &phase);
-  GetConstantValueForArg(eval_context, kinds.second, inst.arg1(), &phase);
+  GetConstantValueForArg(eval_context, inst.arg0_and_kind(), &phase);
+  GetConstantValueForArg(eval_context, inst.arg1_and_kind(), &phase);
   CARBON_CHECK(IsConstant(phase));
   return phase;
 }

+ 9 - 11
toolchain/check/impl_lookup.cpp

@@ -71,31 +71,29 @@ static auto FindAssociatedImportIRs(Context& context,
 
     // Visit the operands of the constant.
     auto inst = context.insts().Get(inst_id);
-    auto [arg0_kind, arg1_kind] = inst.ArgKinds();
-    for (auto [arg, kind] :
-         {std::pair{inst.arg0(), arg0_kind}, {inst.arg1(), arg1_kind}}) {
-      switch (kind) {
+    for (auto arg : {inst.arg0_and_kind(), inst.arg1_and_kind()}) {
+      switch (arg.kind) {
         case SemIR::IdKind::For<SemIR::InstId>: {
-          if (auto id = SemIR::InstId(arg); id.has_value()) {
+          if (auto id = arg.As<SemIR::InstId>(); id.has_value()) {
             worklist.push_back(id);
           }
           break;
         }
         case SemIR::IdKind::For<SemIR::InstBlockId>: {
-          push_block(SemIR::InstBlockId(arg));
+          push_block(arg.As<SemIR::InstBlockId>());
           break;
         }
         case SemIR::IdKind::For<SemIR::ClassId>: {
-          add_entity(context.classes().Get(SemIR::ClassId(arg)));
+          add_entity(context.classes().Get(arg.As<SemIR::ClassId>()));
           break;
         }
         case SemIR::IdKind::For<SemIR::InterfaceId>: {
-          add_entity(context.interfaces().Get(SemIR::InterfaceId(arg)));
+          add_entity(context.interfaces().Get(arg.As<SemIR::InterfaceId>()));
           break;
         }
         case SemIR::IdKind::For<SemIR::FacetTypeId>: {
           const auto& facet_type_info =
-              context.facet_types().Get(SemIR::FacetTypeId(arg));
+              context.facet_types().Get(arg.As<SemIR::FacetTypeId>());
           for (const auto& impl : facet_type_info.impls_constraints) {
             add_entity(context.interfaces().Get(impl.interface_id));
             push_args(impl.specific_id);
@@ -103,11 +101,11 @@ static auto FindAssociatedImportIRs(Context& context,
           break;
         }
         case SemIR::IdKind::For<SemIR::FunctionId>: {
-          add_entity(context.functions().Get(SemIR::FunctionId(arg)));
+          add_entity(context.functions().Get(arg.As<SemIR::FunctionId>()));
           break;
         }
         case SemIR::IdKind::For<SemIR::SpecificId>: {
-          push_args(SemIR::SpecificId(arg));
+          push_args(arg.As<SemIR::SpecificId>());
           break;
         }
         default: {

+ 45 - 39
toolchain/check/subst.cpp

@@ -60,7 +60,7 @@ class Worklist {
 
 // Pushes the specified operand onto the worklist.
 static auto PushOperand(Context& context, Worklist& worklist,
-                        SemIR::IdKind kind, int32_t arg) -> void {
+                        SemIR::Inst::ArgAndKind arg) -> void {
   auto push_block = [&](SemIR::InstBlockId block_id) {
     for (auto inst_id :
          context.inst_blocks().Get(SemIR::InstBlockId(block_id))) {
@@ -74,45 +74,50 @@ static auto PushOperand(Context& context, Worklist& worklist,
     }
   };
 
-  switch (kind) {
+  switch (arg.kind) {
     case SemIR::IdKind::For<SemIR::InstId>:
+      if (auto inst_id = arg.As<SemIR::InstId>(); inst_id.has_value()) {
+        worklist.Push(inst_id);
+      }
+      break;
     case SemIR::IdKind::For<SemIR::MetaInstId>:
-      if (SemIR::InstId inst_id(arg); inst_id.has_value()) {
+      if (auto inst_id = arg.As<SemIR::MetaInstId>(); inst_id.has_value()) {
         worklist.Push(inst_id);
       }
       break;
     case SemIR::IdKind::For<SemIR::TypeId>:
-      if (SemIR::TypeId type_id(arg); type_id.has_value()) {
+      if (auto type_id = arg.As<SemIR::TypeId>(); type_id.has_value()) {
         worklist.Push(context.types().GetInstId(type_id));
       }
       break;
     case SemIR::IdKind::For<SemIR::InstBlockId>:
-      push_block(SemIR::InstBlockId(arg));
+      push_block(arg.As<SemIR::InstBlockId>());
       break;
     case SemIR::IdKind::For<SemIR::StructTypeFieldsId>: {
-      for (auto field :
-           context.struct_type_fields().Get(SemIR::StructTypeFieldsId(arg))) {
+      for (auto field : context.struct_type_fields().Get(
+               arg.As<SemIR::StructTypeFieldsId>())) {
         worklist.Push(context.types().GetInstId(field.type_id));
       }
       break;
     }
     case SemIR::IdKind::For<SemIR::TypeBlockId>:
-      for (auto type_id : context.type_blocks().Get(SemIR::TypeBlockId(arg))) {
+      for (auto type_id :
+           context.type_blocks().Get(arg.As<SemIR::TypeBlockId>())) {
         worklist.Push(context.types().GetInstId(type_id));
       }
       break;
     case SemIR::IdKind::For<SemIR::SpecificId>:
-      push_specific(SemIR::SpecificId(arg));
+      push_specific(arg.As<SemIR::SpecificId>());
       break;
     case SemIR::IdKind::For<SemIR::SpecificInterfaceId>: {
-      auto interface =
-          context.specific_interfaces().Get(SemIR::SpecificInterfaceId(arg));
+      auto interface = context.specific_interfaces().Get(
+          arg.As<SemIR::SpecificInterfaceId>());
       push_specific(interface.specific_id);
       break;
     }
     case SemIR::IdKind::For<SemIR::FacetTypeId>: {
       const auto& facet_type_info =
-          context.facet_types().Get(SemIR::FacetTypeId(arg));
+          context.facet_types().Get(arg.As<SemIR::FacetTypeId>());
       for (auto interface : facet_type_info.impls_constraints) {
         push_specific(interface.specific_id);
       }
@@ -137,16 +142,14 @@ static auto PushOperand(Context& context, Worklist& worklist,
 static auto ExpandOperands(Context& context, Worklist& worklist,
                            SemIR::InstId inst_id) -> void {
   auto inst = context.insts().Get(inst_id);
-  auto kinds = inst.ArgKinds();
-  PushOperand(context, worklist, SemIR::IdKind::For<SemIR::TypeId>,
-              inst.type_id().index);
-  PushOperand(context, worklist, kinds.first, inst.arg0());
-  PushOperand(context, worklist, kinds.second, inst.arg1());
+  PushOperand(context, worklist, inst.type_id_and_kind());
+  PushOperand(context, worklist, inst.arg0_and_kind());
+  PushOperand(context, worklist, inst.arg1_and_kind());
 }
 
 // Pops the specified operand from the worklist and returns it.
-static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
-                       int32_t arg) -> int32_t {
+static auto PopOperand(Context& context, Worklist& worklist,
+                       SemIR::Inst::ArgAndKind arg) -> int32_t {
   auto pop_block_id = [&](SemIR::InstBlockId old_inst_block_id) {
     auto size = context.inst_blocks().Get(old_inst_block_id).size();
     SemIR::CopyOnWriteInstBlock new_inst_block(context.sem_ir(),
@@ -166,27 +169,33 @@ static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
     return context.specifics().GetOrAdd(specific.generic_id, args_id);
   };
 
-  switch (kind) {
-    case SemIR::IdKind::For<SemIR::InstId>:
+  switch (arg.kind) {
+    case SemIR::IdKind::For<SemIR::InstId>: {
+      auto inst_id = arg.As<SemIR::InstId>();
+      if (!inst_id.has_value()) {
+        return arg.value;
+      }
+      return worklist.Pop().index;
+    }
     case SemIR::IdKind::For<SemIR::MetaInstId>: {
-      SemIR::InstId inst_id(arg);
+      auto inst_id = arg.As<SemIR::MetaInstId>();
       if (!inst_id.has_value()) {
-        return arg;
+        return arg.value;
       }
       return worklist.Pop().index;
     }
     case SemIR::IdKind::For<SemIR::TypeId>: {
-      SemIR::TypeId type_id(arg);
+      auto type_id = arg.As<SemIR::TypeId>();
       if (!type_id.has_value()) {
-        return arg;
+        return arg.value;
       }
       return context.types().GetTypeIdForTypeInstId(worklist.Pop()).index;
     }
     case SemIR::IdKind::For<SemIR::InstBlockId>: {
-      return pop_block_id(SemIR::InstBlockId(arg)).index;
+      return pop_block_id(arg.As<SemIR::InstBlockId>()).index;
     }
     case SemIR::IdKind::For<SemIR::StructTypeFieldsId>: {
-      SemIR::StructTypeFieldsId old_fields_id(arg);
+      auto old_fields_id = arg.As<SemIR::StructTypeFieldsId>();
       auto old_fields = context.struct_type_fields().Get(old_fields_id);
       SemIR::CopyOnWriteStructTypeFieldsBlock new_fields(context.sem_ir(),
                                                          old_fields_id);
@@ -198,7 +207,7 @@ static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
       return new_fields.GetCanonical().index;
     }
     case SemIR::IdKind::For<SemIR::TypeBlockId>: {
-      SemIR::TypeBlockId old_type_block_id(arg);
+      auto old_type_block_id = arg.As<SemIR::TypeBlockId>();
       auto size = context.type_blocks().Get(old_type_block_id).size();
       SemIR::CopyOnWriteTypeBlock new_type_block(context.sem_ir(),
                                                  old_type_block_id);
@@ -209,11 +218,11 @@ static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
       return new_type_block.GetCanonical().index;
     }
     case SemIR::IdKind::For<SemIR::SpecificId>: {
-      return pop_specific(SemIR::SpecificId(arg)).index;
+      return pop_specific(arg.As<SemIR::SpecificId>()).index;
     }
     case SemIR::IdKind::For<SemIR::SpecificInterfaceId>: {
-      auto interface =
-          context.specific_interfaces().Get(SemIR::SpecificInterfaceId(arg));
+      auto interface = context.specific_interfaces().Get(
+          arg.As<SemIR::SpecificInterfaceId>());
       auto specific_id = pop_specific(interface.specific_id);
       return context.specific_interfaces()
           .Add({
@@ -224,7 +233,7 @@ static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
     }
     case SemIR::IdKind::For<SemIR::FacetTypeId>: {
       const auto& old_facet_type_info =
-          context.facet_types().Get(SemIR::FacetTypeId(arg));
+          context.facet_types().Get(arg.As<SemIR::FacetTypeId>());
       SemIR::FacetTypeInfo new_facet_type_info;
       // Since these were added to a stack, we get them back in reverse order.
       new_facet_type_info.rewrite_constraints.resize(
@@ -252,7 +261,7 @@ static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
       return context.facet_types().Add(new_facet_type_info).index;
     }
     default:
-      return arg;
+      return arg.value;
   }
 }
 
@@ -261,14 +270,11 @@ static auto PopOperand(Context& context, Worklist& worklist, SemIR::IdKind kind,
 static auto Rebuild(Context& context, Worklist& worklist, SemIR::InstId inst_id,
                     const SubstInstCallbacks& callbacks) -> SemIR::InstId {
   auto inst = context.insts().Get(inst_id);
-  auto kinds = inst.ArgKinds();
 
   // Note that we pop in reverse order because we pushed them in forwards order.
-  int32_t arg1 = PopOperand(context, worklist, kinds.second, inst.arg1());
-  int32_t arg0 = PopOperand(context, worklist, kinds.first, inst.arg0());
-  int32_t type_id =
-      PopOperand(context, worklist, SemIR::IdKind::For<SemIR::TypeId>,
-                 inst.type_id().index);
+  int32_t arg1 = PopOperand(context, worklist, inst.arg1_and_kind());
+  int32_t arg0 = PopOperand(context, worklist, inst.arg0_and_kind());
+  int32_t type_id = PopOperand(context, worklist, inst.type_id_and_kind());
   if (type_id == inst.type_id().index && arg0 == inst.arg0() &&
       arg1 == inst.arg1()) {
     return callbacks.ReuseUnchanged(inst_id);

+ 2 - 0
toolchain/sem_ir/id_kind.h

@@ -17,6 +17,8 @@ namespace Carbon::SemIR {
 template <typename... Types>
 class TypeEnum : public Printable<TypeEnum<Types...>> {
  public:
+  using TypeTuple = std::tuple<Types...>;
+
   static constexpr size_t NumTypes = sizeof...(Types);
   static constexpr size_t NumValues = NumTypes + 2;
 

+ 37 - 14
toolchain/sem_ir/inst.h

@@ -127,6 +127,29 @@ concept InstLikeType = requires { sizeof(InstLikeTypeInfo<T>); };
 //   data where the instruction's kind is not known.
 class Inst : public Printable<Inst> {
  public:
+  // Associated an argument (usually arg0 or arg1, potentially type_id) with its
+  // IdKind.
+  struct ArgAndKind {
+    // Converts to `IdT`, validating the `kind` matches.
+    template <typename IdT>
+    auto As() const -> IdT {
+      CARBON_DCHECK(kind == SemIR::IdKind::For<IdT>);
+      return IdT(value);
+    }
+
+    // Converts to `IdT`, returning nullopt if the kind is incorrect.
+    template <typename IdT>
+    auto TryAs() const -> std::optional<IdT> {
+      if (kind != SemIR::IdKind::For<IdT>) {
+        return std::nullopt;
+      }
+      return IdT(value);
+    }
+
+    IdKind kind;
+    int32_t value;
+  };
+
   // Makes an instruction for a singleton. This exists to support simple
   // construction of all singletons by File.
   static auto MakeSingleton(InstKind kind) -> Inst {
@@ -225,20 +248,6 @@ class Inst : public Printable<Inst> {
   // Gets the type of the value produced by evaluating this instruction.
   auto type_id() const -> TypeId { return type_id_; }
 
-  // Gets the kinds of IDs used for arg0 and arg1 of the specified kind of
-  // instruction.
-  //
-  // TODO: This would ideally live on InstKind, but can't be there for layering
-  // reasons.
-  static auto ArgKinds(InstKind kind) -> std::pair<IdKind, IdKind> {
-    return ArgKindTable[kind.AsInt()];
-  }
-
-  // Gets the kinds of IDs used for arg0 and arg1 of this instruction.
-  auto ArgKinds() const -> std::pair<IdKind, IdKind> {
-    return ArgKinds(kind());
-  }
-
   // Gets the first argument of the instruction. NoneIndex if there is no such
   // argument.
   auto arg0() const -> int32_t { return arg0_; }
@@ -247,6 +256,17 @@ class Inst : public Printable<Inst> {
   // argument.
   auto arg1() const -> int32_t { return arg1_; }
 
+  // Returns arguments with their IdKind.
+  auto type_id_and_kind() const -> ArgAndKind {
+    return {.kind = SemIR::IdKind::For<SemIR::TypeId>, .value = type_id_.index};
+  }
+  auto arg0_and_kind() const -> ArgAndKind {
+    return {.kind = ArgKindTable[kind_].first, .value = arg0_};
+  }
+  auto arg1_and_kind() const -> ArgAndKind {
+    return {.kind = ArgKindTable[kind_].second, .value = arg1_};
+  }
+
   // Sets the type of this instruction.
   auto SetType(TypeId type_id) -> void { type_id_ = type_id; }
 
@@ -281,6 +301,9 @@ class Inst : public Printable<Inst> {
   friend class InstTestHelper;
 
   // Table mapping instruction kinds to their argument kinds.
+  //
+  // TODO: ArgKindTable would ideally live on InstKind, but can't be there for
+  // layering reasons.
   static const std::pair<IdKind, IdKind> ArgKindTable[];
 
   // Raw constructor, used for testing.

+ 28 - 26
toolchain/sem_ir/inst_fingerprinter.cpp

@@ -315,30 +315,33 @@ struct Worklist {
     CARBON_FATAL("Unexpected instruction operand kind {0}", typeid(T).name());
   }
 
-  // Add an instruction argument to the contents of the current instruction.
+  using AddFnT = auto(Worklist& worklist, int32_t arg) -> void;
+
+  // Returns a lookup table to add an argument of the given kind. Requires a
+  // null IdKind as a parameter in order to get the type pack.
   template <typename... Types>
-  auto AddWithKind(uint64_t arg, TypeEnum<Types...> kind) -> void {
-    using AddFunction = void (*)(Worklist& worklist, uint64_t arg);
-    using Kind = decltype(kind);
-
-    // Build a lookup table to add an argument of the given kind.
-    static constexpr std::array<AddFunction, Kind::NumTypes + 2> Table = [] {
-      std::array<AddFunction, Kind::NumTypes + 2> table;
-      table[Kind::None.ToIndex()] = [](Worklist& /*worklist*/,
-                                       uint64_t /*arg*/) {};
-      table[Kind::Invalid.ToIndex()] = [](Worklist& /*worklist*/,
-                                          uint64_t /*arg*/) {
-        CARBON_FATAL("Unexpected invalid argument kind");
-      };
-      ((table[Kind::template For<Types>.ToIndex()] =
-            [](Worklist& worklist, uint64_t arg) {
-              return worklist.Add(Inst::FromRaw<Types>(arg));
-            }),
-       ...);
-      return table;
-    }();
-
-    Table[kind.ToIndex()](*this, arg);
+  static constexpr auto MakeAddTable(TypeEnum<Types...>* /*id_kind*/)
+      -> std::array<AddFnT*, SemIR::IdKind::NumValues> {
+    std::array<AddFnT*, SemIR::IdKind::NumValues> table = {};
+    ((table[SemIR::IdKind::template For<Types>.ToIndex()] =
+          [](Worklist& worklist, int32_t arg) {
+            worklist.Add(Inst::FromRaw<Types>(arg));
+          }),
+     ...);
+    table[SemIR::IdKind::Invalid.ToIndex()] = [](Worklist& /*worklist*/,
+                                                 int32_t /*arg*/) {
+      CARBON_FATAL("Unexpected invalid argument kind");
+    };
+    table[SemIR::IdKind::None.ToIndex()] = [](Worklist& /*worklist*/,
+                                              int32_t /*arg*/) {};
+    return table;
+  }
+
+  // Add an instruction argument to the contents of the current instruction.
+  auto AddWithKind(Inst::ArgAndKind arg) -> void {
+    static constexpr auto Table = MakeAddTable(static_cast<IdKind*>(nullptr));
+
+    Table[arg.kind.ToIndex()](*this, arg.value);
   }
 
   // Ensure all the instructions on the todo list have fingerprints. To avoid a
@@ -399,7 +402,6 @@ struct Worklist {
       // finish that work and process this instruction again, and if not, we'll
       // pop the instruction at the end of the loop.
       auto inst = next_sem_ir->insts().Get(next_inst_id);
-      auto [arg0_kind, arg1_kind] = inst.ArgKinds();
 
       // Add the instruction's fields to the contents.
       Add(inst.kind());
@@ -411,8 +413,8 @@ struct Worklist {
         Add(inst.type_id());
       }
 
-      AddWithKind(inst.arg0(), arg0_kind);
-      AddWithKind(inst.arg1(), arg1_kind);
+      AddWithKind(inst.arg0_and_kind());
+      AddWithKind(inst.arg1_and_kind());
 
       // If we didn't add any work, we have a fingerprint for this instruction;
       // pop it from the todo list. Otherwise, we leave it on the todo list so