Просмотр исходного кода

Factor `IdKind` enum out of node stack. (#3787)

Provide a general mechanism for determining the kind of the args of an
instruction. Use this to simplify instruction profiling a little. The
intent is to also use this mechanism as the basis of a substitution
mechanism, which will be part of a future patch.

Note that this causes us to do three table lookups and two indirect
calls in inst_profile per instruction, instead of one table lookup and
one indirect call. We can revisit this if it shows up in profiles.
Richard Smith 2 лет назад
Родитель
Сommit
2584399673

+ 9 - 34
toolchain/check/node_stack.h

@@ -5,70 +5,47 @@
 #ifndef CARBON_TOOLCHAIN_CHECK_NODE_STACK_H_
 #define CARBON_TOOLCHAIN_CHECK_NODE_STACK_H_
 
-#include <type_traits>
-
 #include "common/vlog.h"
 #include "llvm/ADT/SmallVector.h"
 #include "toolchain/parse/node_ids.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/parse/typed_nodes.h"
+#include "toolchain/sem_ir/id_kind.h"
 #include "toolchain/sem_ir/ids.h"
 
 namespace Carbon::Check {
 
 // A non-discriminated union of ID types.
-template <typename... IdTypes>
 class IdUnion {
  public:
   // The default constructor forms an invalid ID.
   explicit constexpr IdUnion() : index(IdBase::InvalidIndex) {}
 
   template <typename IdT>
-    requires(std::same_as<IdT, IdTypes> || ...)
+    requires SemIR::IdKind::Contains<IdT>
   explicit constexpr IdUnion(IdT id) : index(id.index) {}
 
-  static constexpr std::size_t NumValidKinds = sizeof...(IdTypes);
-
-  // A numbering for the associated ID types.
-  enum class Kind : int8_t {
-    // The first `sizeof...(IdTypes)` indexes correspond to the types in
-    // `IdTypes`.
-
-    // An explicit invalid state.
-    Invalid = NumValidKinds,
-
-    // No active union element.
-    None,
-  };
+  using Kind = SemIR::IdKind::RawEnumType;
 
   // Returns the ID given its type.
   template <typename IdT>
-    requires(std::same_as<IdT, IdTypes> || ...)
+    requires SemIR::IdKind::Contains<IdT>
   constexpr auto As() const -> IdT {
     return IdT(index);
   }
 
   // Returns the ID given its kind.
-  template <Kind K>
-    requires(static_cast<size_t>(K) < sizeof...(IdTypes))
-  constexpr auto As() const {
-    using IdT = __type_pack_element<static_cast<size_t>(K), IdTypes...>;
-    return As<IdT>();
+  template <SemIR::IdKind::RawEnumType K>
+  constexpr auto As() const -> SemIR::IdKind::TypeFor<K> {
+    return As<SemIR::IdKind::TypeFor<K>>();
   }
 
   // Translates an ID type to the enum ID kind. It is a compile-time error if
   // `IdT` isn't a type that can be stored in this union.
   template <typename IdT>
   static constexpr auto KindFor() -> Kind {
-    // A bool for each type saying whether it matches. The result is the index
-    // of the first `true` in this list. If none matches, then the result is the
-    // length of the list, which is mapped to `Invalid`.
-    constexpr bool TypeMatches[] = {std::same_as<IdT, IdTypes>...};
-    constexpr int Index =
-        std::find(TypeMatches, TypeMatches + sizeof...(IdTypes), true) -
-        TypeMatches;
-    return static_cast<Kind>(Index);
+    return SemIR::IdKind::For<IdT>;
   }
 
  private:
@@ -361,9 +338,7 @@ class NodeStack {
   // that the parse node has no associated ID, in which case the *SoloNodeId
   // functions should be used to push and pop it. Id::Kind::Invalid indicates
   // that the parse node should not appear in the node stack at all.
-  using Id = IdUnion<SemIR::InstId, SemIR::InstBlockId, SemIR::FunctionId,
-                     SemIR::ClassId, SemIR::InterfaceId, SemIR::ImplId,
-                     SemIR::NameId, SemIR::TypeId>;
+  using Id = IdUnion;
 
   // An entry in stack_.
   struct Entry {

+ 5 - 1
toolchain/sem_ir/BUILD

@@ -26,7 +26,10 @@ cc_library(
 
 cc_library(
     name = "ids",
-    hdrs = ["ids.h"],
+    hdrs = [
+        "id_kind.h",
+        "ids.h",
+    ],
     deps = [
         "//common:check",
         "//common:ostream",
@@ -61,6 +64,7 @@ cc_library(
     deps = [
         ":block_value_store",
         ":builtin_kind",
+        ":ids",
         ":inst_kind",
         "//common:check",
         "//common:ostream",

+ 120 - 0
toolchain/sem_ir/id_kind.h

@@ -0,0 +1,120 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_SEM_IR_ID_KIND_H_
+#define CARBON_TOOLCHAIN_SEM_IR_ID_KIND_H_
+
+#include <algorithm>
+
+#include "toolchain/sem_ir/ids.h"
+
+namespace Carbon::SemIR {
+
+// An enum whose values are the specified types.
+template <typename... Types>
+class TypeEnum {
+ public:
+  static constexpr std::size_t NumTypes = sizeof...(Types);
+  static constexpr std::size_t NumValues = NumTypes + 2;
+
+  static_assert(NumValues <= 256, "Too many types for raw enum.");
+
+  // The underlying raw enumeration type.
+  enum class RawEnumType : uint8_t {
+    // The first sizeof...(Types) values correspond to the types.
+
+    // An explicitly invalid value.
+    Invalid = NumTypes,
+
+    // Indicates that no type should be used.
+    // TODO: This doesn't really fit the model of this type, but it's convenient
+    // for all of its users.
+    None,
+  };
+
+  // Accesses the type given an enum value.
+  template <RawEnumType K>
+    requires(K != RawEnumType::Invalid)
+  using TypeFor = __type_pack_element<static_cast<size_t>(K), Types...>;
+
+  // Workaround for Clang bug https://github.com/llvm/llvm-project/issues/85461
+  template <RawEnumType Value>
+  static constexpr auto FromRaw = TypeEnum(Value);
+
+  // Names for the `Invalid` and `None` enumeration values.
+  static constexpr const TypeEnum& Invalid = FromRaw<RawEnumType::Invalid>;
+  static constexpr const TypeEnum& None = FromRaw<RawEnumType::None>;
+
+  // Accesses the enumeration value for the type `IdT`. If `AllowInvalid` is
+  // set, any unexpected type is mapped to `Invalid`, otherwise an invalid type
+  // results in a compile error.
+  //
+  // The `Self` parameter is an implementation detail to allow `ForImpl` to be
+  // defined after this template, and should not be specified.
+  template <typename IdT, bool AllowInvalid = false, typename Self = TypeEnum>
+  static constexpr auto For = Self::template ForImpl<IdT, AllowInvalid>();
+
+  // This bool indicates whether the specified type corresponds to a value in
+  // this enum.
+  template <typename IdT>
+  static constexpr bool Contains = For<IdT, true>.is_valid();
+
+  // Explicitly convert from the raw enum type.
+  explicit constexpr TypeEnum(RawEnumType value) : value_(value) {}
+
+  // Implicitly convert to the raw enum type, for use in `switch`.
+  //
+  // NOLINTNEXTLINE(google-explicit-constructor)
+  constexpr operator RawEnumType() const { return value_; }
+
+  // Conversion to bool is deleted to prevent direct use in an `if` condition
+  // instead of comparing with another value.
+  explicit operator bool() const = delete;
+
+  // Returns the raw enum value.
+  constexpr auto ToRaw() const -> RawEnumType { return value_; }
+
+  // Returns a value that can be used as an array index. Returned value will be
+  // < NumValues.
+  constexpr auto ToIndex() const -> std::size_t {
+    return static_cast<std::size_t>(value_);
+  }
+
+  // Returns whether this is a valid value, not `Invalid`.
+  constexpr auto is_valid() const -> bool {
+    return value_ != RawEnumType::Invalid;
+  }
+
+ private:
+  // Translates a type to its enum value, or `Invalid`.
+  template <typename IdT, bool AllowInvalid>
+  static constexpr auto ForImpl() -> TypeEnum {
+    // A bool for each type saying whether it matches. The result is the index
+    // of the first `true` in this list. If none matches, then the result is the
+    // length of the list, which is mapped to `Invalid`.
+    constexpr bool TypeMatches[] = {std::same_as<IdT, Types>...};
+    constexpr int Index =
+        std::find(TypeMatches, TypeMatches + NumTypes, true) - TypeMatches;
+    static_assert(Index != NumTypes || AllowInvalid,
+                  "Unexpected type passed to TypeEnum::For<...>");
+    return TypeEnum(static_cast<RawEnumType>(Index));
+  }
+
+  RawEnumType value_;
+};
+
+// An enum of all the ID types used as instruction operands.
+using IdKind = TypeEnum<
+    // From sem_ir/builtin_kind.h.
+    BuiltinKind,
+    // From base/value_store.h.
+    IntId, RealId, StringLiteralValueId,
+    // From sem_ir/id.h.
+    InstId, ConstantId, BindNameId, FunctionId, ClassId, InterfaceId, ImplId,
+    ImportIRId, BoolValue, NameId, NameScopeId, InstBlockId, TypeId,
+    TypeBlockId, ElementIndex>;
+
+}  // namespace Carbon::SemIR
+
+#endif  // CARBON_TOOLCHAIN_SEM_IR_ID_KIND_H_

+ 18 - 0
toolchain/sem_ir/inst.cpp

@@ -32,4 +32,22 @@ auto Inst::Print(llvm::raw_ostream& out) const -> void {
   out << "}";
 }
 
+// Returns the IdKind of an instruction's argument, or None if there is no
+// argument with that index.
+template <typename InstKind, int ArgIndex>
+static constexpr auto IdKindFor() -> IdKind {
+  using Info = Internal::InstLikeTypeInfo<InstKind>;
+  if constexpr (ArgIndex < Info::NumArgs) {
+    return IdKind::For<typename Info::template ArgType<ArgIndex>>;
+  } else {
+    return IdKind::None;
+  }
+}
+
+const std::pair<IdKind, IdKind> Inst::ArgKindTable[] = {
+#define CARBON_SEM_IR_INST_KIND(Name) \
+  {IdKindFor<Name, 0>(), IdKindFor<Name, 1>()},
+#include "toolchain/sem_ir/inst_kind.def"
+};
+
 }  // namespace Carbon::SemIR

+ 16 - 0
toolchain/sem_ir/inst.h

@@ -14,6 +14,7 @@
 #include "toolchain/base/index_base.h"
 #include "toolchain/sem_ir/block_value_store.h"
 #include "toolchain/sem_ir/builtin_kind.h"
+#include "toolchain/sem_ir/id_kind.h"
 #include "toolchain/sem_ir/inst_kind.h"
 #include "toolchain/sem_ir/typed_insts.h"
 
@@ -209,6 +210,18 @@ class Inst : public Printable<Inst> {
   // Gets the type of the value produced by evaluating this instruction.
   auto type_id() const -> TypeId { return type_id_; }
 
+  // Gets the kinds of IDs used for arg0 and arg1 of the specified kind of
+  // instruction.
+  //
+  // TODO: This would ideally live on InstKind, but can't be there for layering
+  // reasons.
+  static auto ArgKinds(InstKind kind) -> std::pair<IdKind, IdKind> {
+    return ArgKindTable[kind.AsInt()];
+  }
+
+  // Gets the kinds of IDs used for arg0 and arg1 of this instruction.
+  auto ArgKinds() const -> std::pair<IdKind, IdKind> { return ArgKinds(kind_); }
+
   // Gets the first argument of the instruction. InvalidIndex if there is no
   // such argument.
   auto arg0() const -> int32_t { return arg0_; }
@@ -222,6 +235,9 @@ class Inst : public Printable<Inst> {
  private:
   friend class InstTestHelper;
 
+  // Table mapping instruction kinds to their argument kinds.
+  static const std::pair<IdKind, IdKind> ArgKindTable[];
+
   // Raw constructor, used for testing.
   explicit Inst(InstKind kind, TypeId type_id, int32_t arg0, int32_t arg1)
       : kind_(kind), type_id_(type_id), arg0_(arg0), arg1_(arg1) {}

+ 18 - 42
toolchain/sem_ir/inst_profile.cpp

@@ -65,54 +65,30 @@ static auto RealProfileArgFunction(llvm::FoldingSetNodeID& id,
   id.AddBoolean(real.is_decimal);
 }
 
-// Selects the function to use to profile argument N of instruction InstT. We
-// compute this in advance so that we can reuse the profiling code for all
-// instructions that are profiled in the same way. For example, all instructions
-// that take two IDs that are profiled by value use the same profiling code,
-// namely `ProfileArgs<DefaultProfileArgFunction, DefaultProfileArgFunction>`.
-template <typename InstT, int N>
-static constexpr auto SelectProfileArgFunction() -> ProfileArgFunction* {
-  if constexpr (N >= Internal::InstLikeTypeInfo<InstT>::NumArgs) {
-    // This argument is not used by this instruction; don't profile it.
-    return NullProfileArgFunction;
-  } else {
-    using ArgT = Internal::InstLikeTypeInfo<InstT>::template ArgType<N>;
-    if constexpr (std::is_same_v<ArgT, InstBlockId>) {
-      return InstBlockProfileArgFunction;
-    } else if constexpr (std::is_same_v<ArgT, TypeBlockId>) {
-      return TypeBlockProfileArgFunction;
-    } else if constexpr (std::is_same_v<ArgT, IntId>) {
-      return IntProfileArgFunction;
-    } else if constexpr (std::is_same_v<ArgT, RealId>) {
-      return RealProfileArgFunction;
-    } else {
-      return DefaultProfileArgFunction;
-    }
-  }
-}
-
-// Profiles the given instruction arguments using the specified functions.
-template <ProfileArgFunction* ProfileArg0, ProfileArgFunction* ProfileArg1>
-static auto ProfileArgs(llvm::FoldingSetNodeID& id, const File& sem_ir,
-                        int32_t arg0, int32_t arg1) -> void {
-  ProfileArg0(id, sem_ir, arg0);
-  ProfileArg1(id, sem_ir, arg1);
+// Profiles the given instruction argument, which is of the specified kind.
+static auto ProfileArg(llvm::FoldingSetNodeID& id, const File& sem_ir,
+                       IdKind arg_kind, int32_t arg) -> void {
+  static constexpr std::array<ProfileArgFunction*, IdKind::NumValues>
+      ProfileFunctions = [] {
+        std::array<ProfileArgFunction*, IdKind::NumValues> array;
+        array.fill(DefaultProfileArgFunction);
+        array[IdKind::None.ToIndex()] = NullProfileArgFunction;
+        array[IdKind::For<InstBlockId>.ToIndex()] = InstBlockProfileArgFunction;
+        array[IdKind::For<TypeBlockId>.ToIndex()] = TypeBlockProfileArgFunction;
+        array[IdKind::For<IntId>.ToIndex()] = IntProfileArgFunction;
+        array[IdKind::For<RealId>.ToIndex()] = RealProfileArgFunction;
+        return array;
+      }();
+  ProfileFunctions[arg_kind.ToIndex()](id, sem_ir, arg);
 }
 
 auto ProfileConstant(llvm::FoldingSetNodeID& id, const File& sem_ir, Inst inst)
     -> void {
-  using ProfileArgsFunction =
-      auto(llvm::FoldingSetNodeID&, const File&, int32_t, int32_t)->void;
-  static constexpr ProfileArgsFunction* ProfileFunctions[] = {
-#define CARBON_SEM_IR_INST_KIND(KindName)              \
-  ProfileArgs<SelectProfileArgFunction<KindName, 0>(), \
-              SelectProfileArgFunction<KindName, 1>()>,
-#include "toolchain/sem_ir/inst_kind.def"
-  };
-
   inst.kind().Profile(id);
   id.AddInteger(inst.type_id().index);
-  ProfileFunctions[inst.kind().AsInt()](id, sem_ir, inst.arg0(), inst.arg1());
+  auto arg_kinds = inst.ArgKinds();
+  ProfileArg(id, sem_ir, arg_kinds.first, inst.arg0());
+  ProfileArg(id, sem_ir, arg_kinds.second, inst.arg1());
 }
 
 }  // namespace Carbon::SemIR