Преглед на файлове

Support determining `IdKind` from `NodeCategory`, in addition to `NodeKind` (#3648)

The categories `Expr`, `MemberName`, `Decl`, `Statement`, and `Modifier`
are usable since they are associated with a consistent `IdKind`. The
mappings to `IdKind` for `NodeKind`s that have those categories are no
longer listed explicitly, ensuring that the `NodeCategory` mapping is
the source of truth.

Also: fixes the category of the `FunctionDefinitionStart` and
`ArrayExprStart` node kinds.

Note: I've added [a section on defining constexpr constants to the
Toolchain architecture
doc](https://docs.google.com/document/d/1RRYMm42osyqhI2LyjrjockYCutQ5dOf8Abu50kTrkX0/edit?resourcekey=0-kHyqOESbOHmzZphUbtLrTw&tab=t.0#heading=h.f7682a2tpvxr).

FUTURE:

* We should switch `TuplePattern` to put an `InstId` on the `NodeStack`
instead of an `InstBlockId`, so we can handle the pattern category.
* We should make a category for names to replace uses of the `NameId`
`IdKind`.
* We should make use of these new APIs more, and propagate more-precise
types through the codebase.

QUESTION: Should I use a different approach for determining the number
of members of the `NodeKind` enum?

---------

Co-authored-by: Richard Smith <richard@metafoo.co.uk>
josh11b преди 2 години
родител
ревизия
afd7115c0e
променени са 4 файла, в които са добавени 234 реда и са изтрити 102 реда
  1. 6 2
      toolchain/check/context.cpp
  2. 205 94
      toolchain/check/node_stack.h
  3. 17 1
      toolchain/parse/node_kind.h
  4. 6 5
      toolchain/parse/typed_nodes.h

+ 6 - 2
toolchain/check/context.cpp

@@ -581,12 +581,16 @@ auto Context::is_current_position_reachable() -> bool {
 auto Context::ParamOrArgStart() -> void { params_or_args_stack_.Push(); }
 
 auto Context::ParamOrArgComma() -> void {
-  ParamOrArgSave(node_stack_.PopExpr());
+  // Support expressions, parameters, and other nodes like `StructFieldValue`
+  // that produce InstIds.
+  ParamOrArgSave(node_stack_.Pop<SemIR::InstId>());
 }
 
 auto Context::ParamOrArgEndNoPop(Parse::NodeKind start_kind) -> void {
   if (!node_stack_.PeekIs(start_kind)) {
-    ParamOrArgSave(node_stack_.PopExpr());
+    // Support expressions, parameters, and other nodes like `StructFieldValue`
+    // that produce InstIds.
+    ParamOrArgSave(node_stack_.Pop<SemIR::InstId>());
   }
 }
 

+ 205 - 94
toolchain/check/node_stack.h

@@ -11,6 +11,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/typed_nodes.h"
 #include "toolchain/sem_ir/inst.h"
 
 namespace Carbon::Check {
@@ -23,10 +24,11 @@ namespace Carbon::Check {
 //
 // Pop APIs will run basic verification:
 //
-// - If receiving a pop_parse_kind, verify that the parse_node being popped is
-//   of pop_parse_kind.
-// - Validates presence of inst_id based on whether it's a solo
-//   parse_node.
+// - If receiving a Parse::NodeKind, verify that the parse_node being popped has
+//   that kind. Similarly, if receiving a Parse::NodeCategory, verify that the
+//   category of the popped parse_node overlaps that category.
+// - Validates the kind of id data in the node based on the kind or category of
+//   the parse_node.
 //
 // These should be assumed API constraints unless otherwise mentioned on a
 // method. The main exception is PopAndIgnore, which doesn't do verification.
@@ -77,6 +79,20 @@ class NodeStack {
     return PeekIs(RequiredParseKind);
   }
 
+  // Returns whether the node on the top of the stack has an overlapping
+  // category.
+  auto PeekIs(Parse::NodeCategory category) const -> bool {
+    return !stack_.empty() && !!(PeekParseNodeKind().category() & category);
+  }
+
+  // Returns whether the node on the top of the stack has an overlapping
+  // category. Templated for consistency with other functions taking a parse
+  // node category.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto PeekIs() const -> bool {
+    return PeekIs(RequiredParseCategory);
+  }
+
   // Returns whether there is a name on top of the stack.
   auto PeekIsName() const -> bool {
     return !stack_.empty() &&
@@ -130,9 +146,7 @@ class NodeStack {
 
   // Pops an expression from the top of the stack and returns the parse_node and
   // the ID.
-  auto PopExprWithParseNode() -> std::pair<Parse::NodeId, SemIR::InstId> {
-    return PopWithParseNode<SemIR::InstId>();
-  }
+  auto PopExprWithParseNode() -> std::pair<Parse::AnyExprId, SemIR::InstId>;
 
   // Pops a pattern from the top of the stack and returns the parse_node and
   // the ID.
@@ -194,6 +208,15 @@ class NodeStack {
                    << "; see value in ParseNodeKindToIdKind";
   }
 
+  // Pops the top of the stack and returns the parse_node and the ID.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto PopWithParseNode() -> auto {
+    auto id = Peek<RequiredParseCategory>();
+    Parse::NodeIdInCategory<RequiredParseCategory> parse_node(
+        stack_.pop_back_val().parse_node);
+    return std::make_pair(parse_node, id);
+  }
+
   // Pops the top of the stack and returns the parse_node and the ID if it is
   // of the specified kind.
   template <const Parse::NodeKind& RequiredParseKind>
@@ -205,8 +228,19 @@ class NodeStack {
     return PopWithParseNode<RequiredParseKind>();
   }
 
+  // Pops the top of the stack and returns the parse_node and the ID if it is
+  // of the specified category.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto PopWithParseNodeIf()
+      -> std::optional<decltype(PopWithParseNode<RequiredParseCategory>())> {
+    if (!PeekIs<RequiredParseCategory>()) {
+      return std::nullopt;
+    }
+    return PopWithParseNode<RequiredParseCategory>();
+  }
+
   // Pops an expression from the top of the stack and returns the ID.
-  // Expressions map multiple Parse::NodeKinds to SemIR::InstId always.
+  // Expressions always map Parse::NodeCategory::Expr nodes to SemIR::InstId.
   auto PopExpr() -> SemIR::InstId { return PopExprWithParseNode().second; }
 
   // Pops a pattern from the top of the stack and returns the ID.
@@ -218,15 +252,24 @@ class NodeStack {
   // Pops a name from the top of the stack and returns the ID.
   auto PopName() -> SemIR::NameId { return PopNameWithParseNode().second; }
 
-  // TODO: Can we add a `Pop<...>` that takes a parse node category? See
-  // https://github.com/carbon-language/carbon-lang/pull/3534/files#r1432067519
-
   // Pops the top of the stack and returns the ID.
   template <const Parse::NodeKind& RequiredParseKind>
   auto Pop() -> auto {
     return PopWithParseNode<RequiredParseKind>().second;
   }
 
+  // Pops the top of the stack and returns the ID.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto Pop() -> auto {
+    return PopWithParseNode<RequiredParseCategory>().second;
+  }
+
+  // Pops the top of the stack and returns the ID.
+  template <typename IdT>
+  auto Pop() -> IdT {
+    return PopWithParseNode<IdT>().second;
+  }
+
   // Pops the top of the stack if it has the given kind, and returns the ID.
   // Otherwise returns std::nullopt.
   template <const Parse::NodeKind& RequiredParseKind>
@@ -237,6 +280,16 @@ class NodeStack {
     return std::nullopt;
   }
 
+  // Pops the top of the stack if it has the given category, and returns the ID.
+  // Otherwise returns std::nullopt.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto PopIf() -> std::optional<decltype(Pop<RequiredParseCategory>())> {
+    if (PeekIs<RequiredParseCategory>()) {
+      return Pop<RequiredParseCategory>();
+    }
+    return std::nullopt;
+  }
+
   // Peeks at the parse node of the top of the node stack.
   auto PeekParseNode() const -> Parse::NodeId {
     return stack_.back().parse_node;
@@ -278,6 +331,23 @@ class NodeStack {
                    << "; see value in ParseNodeKindToIdKind";
   }
 
+  // Peeks at the ID associated with the top of the node stack.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto Peek() const -> auto {
+    Entry back = stack_.back();
+    RequireParseCategory<RequiredParseCategory>(back.parse_node);
+    constexpr std::optional<IdKind> RequiredIdKind =
+        ParseNodeCategoryToIdKind(RequiredParseCategory);
+    static_assert(RequiredIdKind.has_value());
+    if constexpr (*RequiredIdKind == IdKind::InstId) {
+      return back.id<SemIR::InstId>();
+    } else {
+      static_assert(*RequiredIdKind == IdKind::NameId,
+                    "Unpeekable IdKind for parse category");
+      return back.id<SemIR::NameId>();
+    }
+  }
+
   // Prints the stack for a stack dump.
   auto PrintForStackDump(llvm::raw_ostream& output) const -> void;
 
@@ -293,6 +363,7 @@ class NodeStack {
     ClassId,
     InterfaceId,
     NameId,
+    // NOTE: Currently unused.
     TypeId,
     // No associated ID type.
     SoloParseNode,
@@ -363,91 +434,114 @@ class NodeStack {
   };
   static_assert(sizeof(Entry) == 8, "Unexpected Entry size");
 
+  // Translate a parse node category to the enum ID kind it should always
+  // provide, if it is consistent.
+  static constexpr auto ParseNodeCategoryToIdKind(Parse::NodeCategory category)
+      -> std::optional<IdKind> {
+    // TODO: Patterns should also produce an `InstId`, but currently
+    // `TuplePattern` produces an `InstBlockId`.
+    if (!!(category & Parse::NodeCategory::Expr)) {
+      // Check for no consistent IdKind due to category with multiple bits set.
+      if (!!(category & ~Parse::NodeCategory::Expr)) {
+        return std::nullopt;
+      }
+      return IdKind::InstId;
+    }
+    if (!!(category & Parse::NodeCategory::MemberName)) {
+      // Check for no consistent IdKind due to category with multiple bits set.
+      if (!!(category & ~Parse::NodeCategory::MemberName)) {
+        return std::nullopt;
+      }
+      return IdKind::NameId;
+    }
+    constexpr Parse::NodeCategory UnusedCategories =
+        Parse::NodeCategory::Decl | Parse::NodeCategory::Statement |
+        Parse::NodeCategory::Modifier;
+    if (!!(category & UnusedCategories)) {
+      // Check for no consistent IdKind due to category with multiple bits set.
+      if (!!(category & ~UnusedCategories)) {
+        return std::nullopt;
+      }
+      return IdKind::Unused;
+    }
+    return std::nullopt;
+  }
+
+  static constexpr auto ComputeIdKindTable()
+      -> std::array<IdKind, Parse::NodeKind::ValidCount> {
+    std::array<IdKind, Parse::NodeKind::ValidCount> table = {};
+
+    auto to_id_kind =
+        [](const Parse::NodeKind::Definition& node_kind) -> IdKind {
+      if (auto from_category =
+              NodeStack::ParseNodeCategoryToIdKind(node_kind.category())) {
+        return *from_category;
+      }
+      switch (node_kind) {
+        case Parse::NodeKind::Addr:
+        case Parse::NodeKind::BindingPattern:
+        case Parse::NodeKind::CallExprStart:
+        case Parse::NodeKind::GenericBindingPattern:
+        case Parse::NodeKind::IfExprThen:
+        case Parse::NodeKind::ReturnType:
+        case Parse::NodeKind::ShortCircuitOperandAnd:
+        case Parse::NodeKind::ShortCircuitOperandOr:
+        case Parse::NodeKind::StructFieldValue:
+        case Parse::NodeKind::StructFieldType:
+        case Parse::NodeKind::VariableInitializer:
+          return IdKind::InstId;
+        case Parse::NodeKind::IfCondition:
+        case Parse::NodeKind::IfExprIf:
+        case Parse::NodeKind::ImplicitParamList:
+        case Parse::NodeKind::TuplePattern:
+        case Parse::NodeKind::WhileCondition:
+        case Parse::NodeKind::WhileConditionStart:
+          return IdKind::InstBlockId;
+        case Parse::NodeKind::FunctionDefinitionStart:
+          return IdKind::FunctionId;
+        case Parse::NodeKind::ClassDefinitionStart:
+          return IdKind::ClassId;
+        case Parse::NodeKind::InterfaceDefinitionStart:
+          return IdKind::InterfaceId;
+        case Parse::NodeKind::IdentifierName:
+        case Parse::NodeKind::SelfValueName:
+          return IdKind::NameId;
+        case Parse::NodeKind::ArrayExprSemi:
+        case Parse::NodeKind::ClassIntroducer:
+        case Parse::NodeKind::CodeBlockStart:
+        case Parse::NodeKind::ExprOpenParen:
+        case Parse::NodeKind::FunctionIntroducer:
+        case Parse::NodeKind::IfStatementElse:
+        case Parse::NodeKind::ImplicitParamListStart:
+        case Parse::NodeKind::InterfaceIntroducer:
+        case Parse::NodeKind::LetIntroducer:
+        case Parse::NodeKind::QualifiedName:
+        case Parse::NodeKind::ReturnedModifier:
+        case Parse::NodeKind::ReturnStatementStart:
+        case Parse::NodeKind::ReturnVarModifier:
+        case Parse::NodeKind::StructLiteralOrStructTypeLiteralStart:
+        case Parse::NodeKind::TuplePatternStart:
+        case Parse::NodeKind::VariableIntroducer:
+          return IdKind::SoloParseNode;
+        default:
+          return IdKind::Unused;
+      }
+    };
+
+#define CARBON_PARSE_NODE_KIND(Name) \
+  table[Parse::Name::Kind.AsInt()] = to_id_kind(Parse::Name::Kind);
+#include "toolchain/parse/node_kind.def"
+
+    return table;
+  }
+
+  // Lookup table to implement `ParseNodeKindToIdKind`. Initialized to the
+  // return value of `ComputeIdKindTable()`.
+  static const std::array<IdKind, Parse::NodeKind::ValidCount> IdKindTable;
+
   // Translate a parse node kind to the enum ID kind it should always provide.
   static constexpr auto ParseNodeKindToIdKind(Parse::NodeKind kind) -> IdKind {
-    switch (kind) {
-      case Parse::NodeKind::Addr:
-      case Parse::NodeKind::ArrayExpr:
-      case Parse::NodeKind::BindingPattern:
-      case Parse::NodeKind::CallExpr:
-      case Parse::NodeKind::CallExprStart:
-      case Parse::NodeKind::GenericBindingPattern:
-      case Parse::NodeKind::IdentifierNameExpr:
-      case Parse::NodeKind::IfExprThen:
-      case Parse::NodeKind::IfExprElse:
-      case Parse::NodeKind::IndexExpr:
-      case Parse::NodeKind::MemberAccessExpr:
-      case Parse::NodeKind::PackageExpr:
-      case Parse::NodeKind::ParenExpr:
-      case Parse::NodeKind::ReturnType:
-      case Parse::NodeKind::SelfTypeNameExpr:
-      case Parse::NodeKind::SelfValueNameExpr:
-      case Parse::NodeKind::ShortCircuitOperandAnd:
-      case Parse::NodeKind::ShortCircuitOperandOr:
-      case Parse::NodeKind::ShortCircuitOperatorAnd:
-      case Parse::NodeKind::ShortCircuitOperatorOr:
-      case Parse::NodeKind::StructFieldValue:
-      case Parse::NodeKind::StructLiteral:
-      case Parse::NodeKind::StructFieldType:
-      case Parse::NodeKind::StructTypeLiteral:
-      case Parse::NodeKind::TupleLiteral:
-      case Parse::NodeKind::VariableInitializer:
-        return IdKind::InstId;
-      case Parse::NodeKind::IfCondition:
-      case Parse::NodeKind::IfExprIf:
-      case Parse::NodeKind::ImplicitParamList:
-      case Parse::NodeKind::TuplePattern:
-      case Parse::NodeKind::WhileCondition:
-      case Parse::NodeKind::WhileConditionStart:
-        return IdKind::InstBlockId;
-      case Parse::NodeKind::FunctionDefinitionStart:
-        return IdKind::FunctionId;
-      case Parse::NodeKind::ClassDefinitionStart:
-        return IdKind::ClassId;
-      case Parse::NodeKind::InterfaceDefinitionStart:
-        return IdKind::InterfaceId;
-      case Parse::NodeKind::BaseName:
-      case Parse::NodeKind::IdentifierName:
-      case Parse::NodeKind::SelfValueName:
-        return IdKind::NameId;
-      case Parse::NodeKind::ArrayExprSemi:
-      case Parse::NodeKind::ClassIntroducer:
-      case Parse::NodeKind::CodeBlockStart:
-      case Parse::NodeKind::ExprOpenParen:
-      case Parse::NodeKind::FunctionIntroducer:
-      case Parse::NodeKind::IfStatementElse:
-      case Parse::NodeKind::ImplicitParamListStart:
-      case Parse::NodeKind::InterfaceIntroducer:
-      case Parse::NodeKind::LetIntroducer:
-      case Parse::NodeKind::QualifiedName:
-      case Parse::NodeKind::ReturnedModifier:
-      case Parse::NodeKind::ReturnStatementStart:
-      case Parse::NodeKind::ReturnVarModifier:
-      case Parse::NodeKind::StructLiteralOrStructTypeLiteralStart:
-      case Parse::NodeKind::TuplePatternStart:
-      case Parse::NodeKind::VariableIntroducer:
-        return IdKind::SoloParseNode;
-// Use x-macros to handle boilerplate cases.
-#define CARBON_PARSE_NODE_KIND(...)
-#define CARBON_PARSE_NODE_KIND_INFIX_OPERATOR(Name, ...) \
-  case Parse::NodeKind::InfixOperator##Name:             \
-    return IdKind::InstId;
-#define CARBON_PARSE_NODE_KIND_POSTFIX_OPERATOR(Name, ...) \
-  case Parse::NodeKind::PostfixOperator##Name:             \
-    return IdKind::InstId;
-#define CARBON_PARSE_NODE_KIND_PREFIX_OPERATOR(Name, ...) \
-  case Parse::NodeKind::PrefixOperator##Name:             \
-    return IdKind::InstId;
-#define CARBON_PARSE_NODE_KIND_TOKEN_LITERAL(Name, ...) \
-  case Parse::NodeKind::Name:                           \
-    return IdKind::InstId;
-#define CARBON_PARSE_NODE_KIND_TOKEN_MODIFIER(Name, ...) \
-  case Parse::NodeKind::Name##Modifier:                  \
-    return IdKind::SoloParseNode;
-#include "toolchain/parse/node_kind.def"
-      default:
-        return IdKind::Unused;
-    }
+    return IdKindTable[kind.AsInt()];
   }
 
   // Translates an ID type to the enum ID kind for comparison with
@@ -510,6 +604,15 @@ class NodeStack {
         << "Expected " << RequiredParseKind << ", found " << actual_kind;
   }
 
+  // Require an entry to have the given Parse::NodeCategory.
+  template <Parse::NodeCategory RequiredParseCategory>
+  auto RequireParseCategory(Parse::NodeId parse_node) const -> void {
+    auto kind = parse_tree_->node_kind(parse_node);
+    CARBON_CHECK(!!(RequiredParseCategory & kind.category()))
+        << "Expected " << RequiredParseCategory << ", found " << kind
+        << " with category " << kind.category();
+  }
+
   // The file's parse tree.
   const Parse::Tree* parse_tree_;
 
@@ -522,6 +625,14 @@ class NodeStack {
   llvm::SmallVector<Entry> stack_;
 };
 
+constexpr std::array<NodeStack::IdKind, Parse::NodeKind::ValidCount>
+    NodeStack::IdKindTable = NodeStack::ComputeIdKindTable();
+
+inline auto NodeStack::PopExprWithParseNode()
+    -> std::pair<Parse::AnyExprId, SemIR::InstId> {
+  return PopWithParseNode<Parse::NodeCategory::Expr>();
+}
+
 }  // namespace Carbon::Check
 
 #endif  // CARBON_TOOLCHAIN_CHECK_NODE_STACK_H_

+ 17 - 1
toolchain/parse/node_kind.h

@@ -32,7 +32,7 @@ enum class NodeCategory : uint32_t {
   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/Statement)
 };
 
-inline auto operator!(NodeCategory k) -> bool {
+inline constexpr auto operator!(NodeCategory k) -> bool {
   return !static_cast<uint32_t>(k);
 }
 
@@ -70,6 +70,10 @@ class NodeKind : public CARBON_ENUM_BASE(NodeKind) {
   // Returns which categories this node kind is in.
   auto category() const -> NodeCategory;
 
+  // Number of different kinds, usable in a constexpr context.
+  static const int ValidCount;
+
+  using EnumBase::AsInt;
   using EnumBase::Create;
 
   class Definition;
@@ -88,6 +92,18 @@ class NodeKind : public CARBON_ENUM_BASE(NodeKind) {
   CARBON_ENUM_CONSTANT_DEFINITION(NodeKind, Name)
 #include "toolchain/parse/node_kind.def"
 
+constexpr int NodeKind::ValidCount = 0
+// NOLINTNEXTLINE(bugprone-macro-parentheses)
+#define CARBON_PARSE_NODE_KIND(Name) +1
+#include "toolchain/parse/node_kind.def"
+    ;
+
+static_assert(
+    NodeKind::ValidCount != 0,
+    "The above `constexpr` definition of `ValidCount` makes it available in "
+    "a `constexpr` context despite being declared as merely `const`. We use it "
+    "in a static assert here to ensure that.");
+
 // We expect the parse node kind to fit compactly into 8 bits.
 static_assert(sizeof(NodeKind) == 1, "Kind objects include padding!");
 

+ 6 - 5
toolchain/parse/typed_nodes.h

@@ -283,9 +283,9 @@ struct ReturnType {
 };
 
 // A function signature: `fn F() -> i32`.
-template <const NodeKind& KindT>
+template <const NodeKind& KindT, NodeCategory Category>
 struct FunctionSignature {
-  static constexpr auto Kind = KindT.Define(NodeCategory::Decl);
+  static constexpr auto Kind = KindT.Define(Category);
 
   FunctionIntroducerId introducer;
   llvm::SmallVector<AnyModifierId> modifiers;
@@ -296,9 +296,10 @@ struct FunctionSignature {
   std::optional<ReturnTypeId> return_type;
 };
 
-using FunctionDecl = FunctionSignature<NodeKind::FunctionDecl>;
+using FunctionDecl =
+    FunctionSignature<NodeKind::FunctionDecl, NodeCategory::Decl>;
 using FunctionDefinitionStart =
-    FunctionSignature<NodeKind::FunctionDefinitionStart>;
+    FunctionSignature<NodeKind::FunctionDefinitionStart, NodeCategory::None>;
 
 // A function definition: `fn F() -> i32 { ... }`.
 struct FunctionDefinition {
@@ -484,7 +485,7 @@ struct WhileStatement {
 // Expression nodes
 // ----------------
 
-using ArrayExprStart = LeafNode<NodeKind::ArrayExprStart, NodeCategory::Expr>;
+using ArrayExprStart = LeafNode<NodeKind::ArrayExprStart>;
 
 // The start of an array type, `[i32;`.
 //