Browse Source

Extend CARBON_KIND_SWITCH to support ArgAndKind (#5216)

This builds on #5212 which is adding ArgAndKind. This further modifies
CARBON_KIND_SWITCH support so that we can use it with ArgAndKind in
addition to Inst. That creates a quirk where it's easier if ArgAndKind
provides `kind` as an accessor instead of a data member, so I'm just
switching it to a class.
Jon Ross-Perkins 1 year ago
parent
commit
9134e36ec0

+ 1 - 0
.codespell_ignore

@@ -19,3 +19,4 @@ pullrequest
 rightt
 rouge
 statics
+switcht

+ 11 - 6
toolchain/base/kind_switch.h

@@ -5,6 +5,8 @@
 #ifndef CARBON_TOOLCHAIN_BASE_KIND_SWITCH_H_
 #define CARBON_TOOLCHAIN_BASE_KIND_SWITCH_H_
 
+#include <type_traits>
+
 #include "llvm/ADT/STLExtras.h"
 
 // This library provides switch-like behaviors for Carbon's kind-based types.
@@ -37,19 +39,21 @@
 // requirements should change.
 namespace Carbon::Internal::Kind {
 
-// Given `CARBON_KIND_SWITCH(value)` this handles calling `value.kind()`.
-template <typename T>
-auto SwitchOn(T&& switch_value) -> auto {
+// Given `CARBON_KIND_SWITCH(value)` this returns `value.kind()` to switch on.
+template <typename SwitchT>
+auto SwitchOn(SwitchT&& switch_value) -> auto {
   return switch_value.kind();
 }
 
 // Given `CARBON_KIND(CaseT name)` this generates `CaseT::Kind`. It explicitly
 // returns `KindT` because that may differ from `CaseT::Kind`, and may not be
 // copyable.
-template <typename FnT>
+template <typename SwitchT, typename CaseFnT>
 consteval auto ForCase() -> auto {
-  using ArgT = llvm::function_traits<FnT>::template arg_t<0>;
-  return static_cast<decltype(ArgT::Kind)::RawEnumType>(ArgT::Kind);
+  using KindT = llvm::function_traits<
+      decltype(&std::remove_cvref_t<SwitchT>::kind)>::result_t;
+  using CaseT = llvm::function_traits<CaseFnT>::template arg_t<0>;
+  return static_cast<KindT::RawEnumType>(KindT::template For<CaseT>);
 }
 
 // Given `CARBON_KIND_SWITCH(value)` and `CARBON_KIND(CaseT name)` this
@@ -80,6 +84,7 @@ auto Cast(ValueT&& kind_switch_value) -> auto {
 // name, making it look more like a typical `case`.
 #define CARBON_KIND(typed_variable_decl)                                \
   ::Carbon::Internal::Kind::ForCase<                                    \
+      decltype(carbon_internal_kind_switch_value),                      \
       decltype([]([[maybe_unused]] typed_variable_decl) {})>()          \
       : if (typed_variable_decl = ::Carbon::Internal::Kind::Cast<       \
                 decltype([]([[maybe_unused]] typed_variable_decl) {})>( \

+ 10 - 7
toolchain/check/action.cpp

@@ -4,6 +4,7 @@
 
 #include "toolchain/check/action.h"
 
+#include "toolchain/base/kind_switch.h"
 #include "toolchain/check/generic_region_stack.h"
 #include "toolchain/check/inst.h"
 #include "toolchain/sem_ir/constant.h"
@@ -46,12 +47,14 @@ auto OperandIsDependent(Context& context, SemIR::MetaInstId inst_id) -> bool {
 
 static auto OperandIsDependent(Context& context, SemIR::Inst::ArgAndKind arg)
     -> bool {
-  switch (arg.kind) {
-    case SemIR::IdKind::For<SemIR::MetaInstId>:
-      return OperandIsDependent(context, arg.As<SemIR::MetaInstId>());
+  CARBON_KIND_SWITCH(arg) {
+    case CARBON_KIND(SemIR::MetaInstId inst_id): {
+      return OperandIsDependent(context, inst_id);
+    }
 
-    case SemIR::IdKind::For<SemIR::TypeId>:
-      return OperandIsDependent(context, arg.As<SemIR::TypeId>());
+    case CARBON_KIND(SemIR::TypeId type_id): {
+      return OperandIsDependent(context, type_id);
+    }
 
     case SemIR::IdKind::None:
     case SemIR::IdKind::For<SemIR::AbsoluteInstId>:
@@ -106,7 +109,7 @@ static auto RefineOperand(Context& context, SemIR::LocId loc_id,
     if (inst.Is<SemIR::SpliceInst>()) {
       // The argument will evaluate to the spliced instruction, which is already
       // refined.
-      return arg.value;
+      return arg.value();
     }
 
     // If the type of the action argument is dependent, refine to an instruction
@@ -128,7 +131,7 @@ static auto RefineOperand(Context& context, SemIR::LocId loc_id,
     return inst_id->index;
   }
 
-  return arg.value;
+  return arg.value();
 }
 
 // Refine the operands of an action, ensuring that they will refer to concrete

+ 22 - 20
toolchain/check/deduce.cpp

@@ -140,37 +140,39 @@ class DeductionWorklist {
   // Adds a (param, arg) pair for an instruction argument, given its kind.
   auto AddInstArg(SemIR::Inst::ArgAndKind param, int32_t arg,
                   bool needs_substitution) -> void {
-    switch (param.kind) {
+    CARBON_KIND_SWITCH(param) {
       case SemIR::IdKind::None:
       case SemIR::IdKind::For<SemIR::ClassId>:
       case SemIR::IdKind::For<SemIR::IntKind>:
         break;
-      case SemIR::IdKind::For<SemIR::InstId>:
-        Add(param.As<SemIR::InstId>(), SemIR::InstId(arg), needs_substitution);
+      case CARBON_KIND(SemIR::InstId inst_id): {
+        Add(inst_id, SemIR::InstId(arg), needs_substitution);
         break;
-      case SemIR::IdKind::For<SemIR::TypeId>:
-        Add(param.As<SemIR::TypeId>(), SemIR::TypeId(arg), needs_substitution);
+      }
+      case CARBON_KIND(SemIR::TypeId type_id): {
+        Add(type_id, SemIR::TypeId(arg), needs_substitution);
         break;
-      case SemIR::IdKind::For<SemIR::StructTypeFieldsId>:
-        AddAll(param.As<SemIR::StructTypeFieldsId>(),
-               SemIR::StructTypeFieldsId(arg), needs_substitution);
+      }
+      case CARBON_KIND(SemIR::StructTypeFieldsId fields_id): {
+        AddAll(fields_id, SemIR::StructTypeFieldsId(arg), needs_substitution);
         break;
-      case SemIR::IdKind::For<SemIR::InstBlockId>:
-        AddAll(param.As<SemIR::InstBlockId>(), SemIR::InstBlockId(arg),
-               needs_substitution);
+      }
+      case CARBON_KIND(SemIR::InstBlockId inst_block_id): {
+        AddAll(inst_block_id, SemIR::InstBlockId(arg), needs_substitution);
         break;
-      case SemIR::IdKind::For<SemIR::TypeBlockId>:
-        AddAll(param.As<SemIR::TypeBlockId>(), SemIR::TypeBlockId(arg),
-               needs_substitution);
+      }
+      case CARBON_KIND(SemIR::TypeBlockId type_block_id): {
+        AddAll(type_block_id, SemIR::TypeBlockId(arg), needs_substitution);
         break;
-      case SemIR::IdKind::For<SemIR::SpecificId>:
-        Add(param.As<SemIR::SpecificId>(), SemIR::SpecificId(arg),
-            needs_substitution);
+      }
+      case CARBON_KIND(SemIR::SpecificId specific_id): {
+        Add(specific_id, SemIR::SpecificId(arg), needs_substitution);
         break;
-      case SemIR::IdKind::For<SemIR::FacetTypeId>:
-        AddAll(param.As<SemIR::FacetTypeId>(), SemIR::FacetTypeId(arg),
-               needs_substitution);
+      }
+      case CARBON_KIND(SemIR::FacetTypeId facet_type_id): {
+        AddAll(facet_type_id, SemIR::FacetTypeId(arg), needs_substitution);
         break;
+      }
       default:
         CARBON_FATAL("unexpected argument kind");
     }

+ 2 - 2
toolchain/check/eval.cpp

@@ -684,8 +684,8 @@ static auto GetConstantValueForArg(EvalContext& eval_context,
                                    Phase* phase) -> int32_t {
   static constexpr auto Table =
       MakeArgHandlerTable(static_cast<SemIR::IdKind*>(nullptr));
-  return Table[arg_and_kind.kind.ToIndex()](eval_context, arg_and_kind.value,
-                                            phase);
+  return Table[arg_and_kind.kind().ToIndex()](eval_context,
+                                              arg_and_kind.value(), phase);
 }
 
 // Given an instruction, replaces its type and operands with their constant

+ 16 - 16
toolchain/check/impl_lookup.cpp

@@ -72,40 +72,40 @@ static auto FindAssociatedImportIRs(Context& context,
     // Visit the operands of the constant.
     auto inst = context.insts().Get(inst_id);
     for (auto arg : {inst.arg0_and_kind(), inst.arg1_and_kind()}) {
-      switch (arg.kind) {
-        case SemIR::IdKind::For<SemIR::InstId>: {
-          if (auto id = arg.As<SemIR::InstId>(); id.has_value()) {
-            worklist.push_back(id);
+      CARBON_KIND_SWITCH(arg) {
+        case CARBON_KIND(SemIR::InstId inst_id): {
+          if (inst_id.has_value()) {
+            worklist.push_back(inst_id);
           }
           break;
         }
-        case SemIR::IdKind::For<SemIR::InstBlockId>: {
-          push_block(arg.As<SemIR::InstBlockId>());
+        case CARBON_KIND(SemIR::InstBlockId inst_block_id): {
+          push_block(inst_block_id);
           break;
         }
-        case SemIR::IdKind::For<SemIR::ClassId>: {
-          add_entity(context.classes().Get(arg.As<SemIR::ClassId>()));
+        case CARBON_KIND(SemIR::ClassId class_id): {
+          add_entity(context.classes().Get(class_id));
           break;
         }
-        case SemIR::IdKind::For<SemIR::InterfaceId>: {
-          add_entity(context.interfaces().Get(arg.As<SemIR::InterfaceId>()));
+        case CARBON_KIND(SemIR::InterfaceId interface_id): {
+          add_entity(context.interfaces().Get(interface_id));
           break;
         }
-        case SemIR::IdKind::For<SemIR::FacetTypeId>: {
+        case CARBON_KIND(SemIR::FacetTypeId facet_type_id): {
           const auto& facet_type_info =
-              context.facet_types().Get(arg.As<SemIR::FacetTypeId>());
+              context.facet_types().Get(facet_type_id);
           for (const auto& impl : facet_type_info.impls_constraints) {
             add_entity(context.interfaces().Get(impl.interface_id));
             push_args(impl.specific_id);
           }
           break;
         }
-        case SemIR::IdKind::For<SemIR::FunctionId>: {
-          add_entity(context.functions().Get(arg.As<SemIR::FunctionId>()));
+        case CARBON_KIND(SemIR::FunctionId function_id): {
+          add_entity(context.functions().Get(function_id));
           break;
         }
-        case SemIR::IdKind::For<SemIR::SpecificId>: {
-          push_args(arg.As<SemIR::SpecificId>());
+        case CARBON_KIND(SemIR::SpecificId specific_id): {
+          push_args(specific_id);
           break;
         }
         default: {

+ 44 - 47
toolchain/check/subst.cpp

@@ -4,6 +4,7 @@
 
 #include "toolchain/check/subst.h"
 
+#include "toolchain/base/kind_switch.h"
 #include "toolchain/check/eval.h"
 #include "toolchain/check/generic.h"
 #include "toolchain/sem_ir/copy_on_write_block.h"
@@ -74,50 +75,52 @@ static auto PushOperand(Context& context, Worklist& worklist,
     }
   };
 
-  switch (arg.kind) {
-    case SemIR::IdKind::For<SemIR::InstId>:
-      if (auto inst_id = arg.As<SemIR::InstId>(); inst_id.has_value()) {
+  CARBON_KIND_SWITCH(arg) {
+    case CARBON_KIND(SemIR::InstId inst_id): {
+      if (inst_id.has_value()) {
         worklist.Push(inst_id);
       }
       break;
-    case SemIR::IdKind::For<SemIR::MetaInstId>:
-      if (auto inst_id = arg.As<SemIR::MetaInstId>(); inst_id.has_value()) {
+    }
+    case CARBON_KIND(SemIR::MetaInstId inst_id): {
+      if (inst_id.has_value()) {
         worklist.Push(inst_id);
       }
       break;
-    case SemIR::IdKind::For<SemIR::TypeId>:
-      if (auto type_id = arg.As<SemIR::TypeId>(); type_id.has_value()) {
+    }
+    case CARBON_KIND(SemIR::TypeId type_id): {
+      if (type_id.has_value()) {
         worklist.Push(context.types().GetInstId(type_id));
       }
       break;
-    case SemIR::IdKind::For<SemIR::InstBlockId>:
-      push_block(arg.As<SemIR::InstBlockId>());
+    }
+    case CARBON_KIND(SemIR::InstBlockId inst_block_id): {
+      push_block(inst_block_id);
       break;
-    case SemIR::IdKind::For<SemIR::StructTypeFieldsId>: {
-      for (auto field : context.struct_type_fields().Get(
-               arg.As<SemIR::StructTypeFieldsId>())) {
+    }
+    case CARBON_KIND(SemIR::StructTypeFieldsId fields_id): {
+      for (auto field : context.struct_type_fields().Get(fields_id)) {
         worklist.Push(context.types().GetInstId(field.type_id));
       }
       break;
     }
-    case SemIR::IdKind::For<SemIR::TypeBlockId>:
-      for (auto type_id :
-           context.type_blocks().Get(arg.As<SemIR::TypeBlockId>())) {
+    case CARBON_KIND(SemIR::TypeBlockId type_block_id): {
+      for (auto type_id : context.type_blocks().Get(type_block_id)) {
         worklist.Push(context.types().GetInstId(type_id));
       }
       break;
-    case SemIR::IdKind::For<SemIR::SpecificId>:
-      push_specific(arg.As<SemIR::SpecificId>());
+    }
+    case CARBON_KIND(SemIR::SpecificId specific_id): {
+      push_specific(specific_id);
       break;
-    case SemIR::IdKind::For<SemIR::SpecificInterfaceId>: {
-      auto interface = context.specific_interfaces().Get(
-          arg.As<SemIR::SpecificInterfaceId>());
+    }
+    case CARBON_KIND(SemIR::SpecificInterfaceId interface_id): {
+      auto interface = context.specific_interfaces().Get(interface_id);
       push_specific(interface.specific_id);
       break;
     }
-    case SemIR::IdKind::For<SemIR::FacetTypeId>: {
-      const auto& facet_type_info =
-          context.facet_types().Get(arg.As<SemIR::FacetTypeId>());
+    case CARBON_KIND(SemIR::FacetTypeId facet_type_id): {
+      const auto& facet_type_info = context.facet_types().Get(facet_type_id);
       for (auto interface : facet_type_info.impls_constraints) {
         push_specific(interface.specific_id);
       }
@@ -169,33 +172,29 @@ static auto PopOperand(Context& context, Worklist& worklist,
     return context.specifics().GetOrAdd(specific.generic_id, args_id);
   };
 
-  switch (arg.kind) {
-    case SemIR::IdKind::For<SemIR::InstId>: {
-      auto inst_id = arg.As<SemIR::InstId>();
+  CARBON_KIND_SWITCH(arg) {
+    case CARBON_KIND(SemIR::InstId inst_id): {
       if (!inst_id.has_value()) {
-        return arg.value;
+        return arg.value();
       }
       return worklist.Pop().index;
     }
-    case SemIR::IdKind::For<SemIR::MetaInstId>: {
-      auto inst_id = arg.As<SemIR::MetaInstId>();
+    case CARBON_KIND(SemIR::MetaInstId inst_id): {
       if (!inst_id.has_value()) {
-        return arg.value;
+        return arg.value();
       }
       return worklist.Pop().index;
     }
-    case SemIR::IdKind::For<SemIR::TypeId>: {
-      auto type_id = arg.As<SemIR::TypeId>();
+    case CARBON_KIND(SemIR::TypeId type_id): {
       if (!type_id.has_value()) {
-        return arg.value;
+        return arg.value();
       }
       return context.types().GetTypeIdForTypeInstId(worklist.Pop()).index;
     }
-    case SemIR::IdKind::For<SemIR::InstBlockId>: {
-      return pop_block_id(arg.As<SemIR::InstBlockId>()).index;
+    case CARBON_KIND(SemIR::InstBlockId inst_block_id): {
+      return pop_block_id(inst_block_id).index;
     }
-    case SemIR::IdKind::For<SemIR::StructTypeFieldsId>: {
-      auto old_fields_id = arg.As<SemIR::StructTypeFieldsId>();
+    case CARBON_KIND(SemIR::StructTypeFieldsId old_fields_id): {
       auto old_fields = context.struct_type_fields().Get(old_fields_id);
       SemIR::CopyOnWriteStructTypeFieldsBlock new_fields(context.sem_ir(),
                                                          old_fields_id);
@@ -206,8 +205,7 @@ static auto PopOperand(Context& context, Worklist& worklist,
       }
       return new_fields.GetCanonical().index;
     }
-    case SemIR::IdKind::For<SemIR::TypeBlockId>: {
-      auto old_type_block_id = arg.As<SemIR::TypeBlockId>();
+    case CARBON_KIND(SemIR::TypeBlockId old_type_block_id): {
       auto size = context.type_blocks().Get(old_type_block_id).size();
       SemIR::CopyOnWriteTypeBlock new_type_block(context.sem_ir(),
                                                  old_type_block_id);
@@ -217,12 +215,11 @@ static auto PopOperand(Context& context, Worklist& worklist,
       }
       return new_type_block.GetCanonical().index;
     }
-    case SemIR::IdKind::For<SemIR::SpecificId>: {
-      return pop_specific(arg.As<SemIR::SpecificId>()).index;
+    case CARBON_KIND(SemIR::SpecificId specific_id): {
+      return pop_specific(specific_id).index;
     }
-    case SemIR::IdKind::For<SemIR::SpecificInterfaceId>: {
-      auto interface = context.specific_interfaces().Get(
-          arg.As<SemIR::SpecificInterfaceId>());
+    case CARBON_KIND(SemIR::SpecificInterfaceId interface_id): {
+      auto interface = context.specific_interfaces().Get(interface_id);
       auto specific_id = pop_specific(interface.specific_id);
       return context.specific_interfaces()
           .Add({
@@ -231,9 +228,9 @@ static auto PopOperand(Context& context, Worklist& worklist,
           })
           .index;
     }
-    case SemIR::IdKind::For<SemIR::FacetTypeId>: {
+    case CARBON_KIND(SemIR::FacetTypeId facet_type_id): {
       const auto& old_facet_type_info =
-          context.facet_types().Get(arg.As<SemIR::FacetTypeId>());
+          context.facet_types().Get(facet_type_id);
       SemIR::FacetTypeInfo new_facet_type_info;
       // Since these were added to a stack, we get them back in reverse order.
       new_facet_type_info.rewrite_constraints.resize(
@@ -261,7 +258,7 @@ static auto PopOperand(Context& context, Worklist& worklist,
       return context.facet_types().Add(new_facet_type_info).index;
     }
     default:
-      return arg.value;
+      return arg.value();
   }
 }
 

+ 19 - 11
toolchain/sem_ir/inst.h

@@ -127,27 +127,35 @@ concept InstLikeType = requires { sizeof(InstLikeTypeInfo<T>); };
 //   data where the instruction's kind is not known.
 class Inst : public Printable<Inst> {
  public:
-  // Associated an argument (usually arg0 or arg1, potentially type_id) with its
+  // Associates an argument (usually arg0 or arg1, potentially type_id) with its
   // IdKind.
-  struct ArgAndKind {
+  class ArgAndKind {
+   public:
+    explicit ArgAndKind(IdKind kind, int32_t value)
+        : kind_(kind), value_(value) {}
+
     // Converts to `IdT`, validating the `kind` matches.
     template <typename IdT>
     auto As() const -> IdT {
-      CARBON_DCHECK(kind == SemIR::IdKind::For<IdT>);
-      return IdT(value);
+      CARBON_DCHECK(kind_ == SemIR::IdKind::For<IdT>);
+      return IdT(value_);
     }
 
     // Converts to `IdT`, returning nullopt if the kind is incorrect.
     template <typename IdT>
     auto TryAs() const -> std::optional<IdT> {
-      if (kind != SemIR::IdKind::For<IdT>) {
+      if (kind_ != SemIR::IdKind::For<IdT>) {
         return std::nullopt;
       }
-      return IdT(value);
+      return IdT(value_);
     }
 
-    IdKind kind;
-    int32_t value;
+    auto kind() const -> IdKind { return kind_; }
+    auto value() const -> int32_t { return value_; }
+
+   private:
+    IdKind kind_;
+    int32_t value_;
   };
 
   // Makes an instruction for a singleton. This exists to support simple
@@ -258,13 +266,13 @@ class Inst : public Printable<Inst> {
 
   // Returns arguments with their IdKind.
   auto type_id_and_kind() const -> ArgAndKind {
-    return {.kind = SemIR::IdKind::For<SemIR::TypeId>, .value = type_id_.index};
+    return ArgAndKind(SemIR::IdKind::For<SemIR::TypeId>, type_id_.index);
   }
   auto arg0_and_kind() const -> ArgAndKind {
-    return {.kind = ArgKindTable[kind_].first, .value = arg0_};
+    return ArgAndKind(ArgKindTable[kind_].first, arg0_);
   }
   auto arg1_and_kind() const -> ArgAndKind {
-    return {.kind = ArgKindTable[kind_].second, .value = arg1_};
+    return ArgAndKind(ArgKindTable[kind_].second, arg1_);
   }
 
   // Sets the type of this instruction.

+ 1 - 1
toolchain/sem_ir/inst_fingerprinter.cpp

@@ -341,7 +341,7 @@ struct Worklist {
   auto AddWithKind(Inst::ArgAndKind arg) -> void {
     static constexpr auto Table = MakeAddTable(static_cast<IdKind*>(nullptr));
 
-    Table[arg.kind.ToIndex()](*this, arg.value);
+    Table[arg.kind().ToIndex()](*this, arg.value());
   }
 
   // Ensure all the instructions on the todo list have fingerprints. To avoid a

+ 4 - 0
toolchain/sem_ir/inst_kind.h

@@ -110,6 +110,10 @@ class InstKind : public CARBON_ENUM_BASE(InstKind) {
 #define CARBON_SEM_IR_INST_KIND(Name) CARBON_ENUM_CONSTANT_DECL(Name)
 #include "toolchain/sem_ir/inst_kind.def"
 
+  // Returns the `InstKind` for an instruction, for `CARBON_KIND_SWITCH`.
+  template <typename InstT>
+  static constexpr auto& For = InstT::Kind;
+
   template <typename TypedNodeId>
   class Definition;