Ver código fonte

Typed wrappers around parse tree nodes (#3534)

These are intended to allow the structure of a parse tree node to be
described more precisely in code, to support these use cases:

- Automated checking that the parse tree conforms to the expected
structure. (Added to `Tree::Verify`.)
- Easier reading and understanding of the structure of the parse tree by
toolchain developers. (See `parse/typed_nodes.h`.)
- Easier navigation of the parse tree, for example for tooling uses and
for use when forming diagnostics.

On this last point, an object representing the file may be inspecting
using `Tree::ExtractFile`, as in:
```
auto file = tree->ExtractFile();
for (AnyDeclId decl_id : file.decls) {
  // `decl_id` is convertible to a `NodeId`.
  if (std::optional<FunctionDecl> fn_decl =
      tree->ExtractAs<FunctionDecl>(decl_id)) {
    // fn_decl->params is a `TuplePatternId` (which extends `NodeId`)
    // that is guaranteed to reference a `TuplePattern`.
    std::optional<TuplePattern> params = tree->Extract(fn_decl->params);
    // `params` has a value unless there was an error in that node.
  } else if (auto class_def = tree->ExtractAs<ClassDefinition>(decl_id)) {
    // ...
  }
}
```

The `Extract...` functions collect the child nodes into the typed parse
node's fields (internally using a `Tree::SiblingIterator`) for easy
access. However, this is not as fast as directly observing the tree
structure using the postorder strategy being used by the check stage.

These functions rely on using struct reflection on the typed parse node
definitions from `parse/typed_nodes.h` to get the expected structure of
child nodes and then populate them.

Note that validating these in `Tree::Verify` adds significant cost to
it, and is currently included in the parsing stage. Without this change,
a 10 mloc test case of lex & parse takes 4.129 s ± 0.041 s. With this
change, it takes 5.768 s ± 0.036 s.

This builds upon and completes #3393.

Co-authored-by: Richard Smith <richard@metafoo.co.uk>

---------

Co-authored-by: Richard Smith <richard@metafoo.co.uk>
Co-authored-by: Chandler Carruth <chandlerc@gmail.com>
josh11b 2 anos atrás
pai
commit
2e97f27b8d
38 arquivos alterados com 1865 adições e 76 exclusões
  1. 14 2
      common/struct_reflection.h
  2. 22 1
      common/struct_reflection_test.cpp
  3. 1 0
      toolchain/check/check.cpp
  4. 1 1
      toolchain/check/decl_state.h
  5. 5 0
      toolchain/check/node_stack.h
  6. 25 1
      toolchain/parse/BUILD
  7. 342 0
      toolchain/parse/extract.cpp
  8. 1 0
      toolchain/parse/handle_brace_expr.cpp
  9. 7 5
      toolchain/parse/handle_let.cpp
  10. 99 0
      toolchain/parse/node_ids.h
  11. 13 0
      toolchain/parse/node_kind.cpp
  12. 21 13
      toolchain/parse/node_kind.def
  13. 71 0
      toolchain/parse/node_kind.h
  14. 1 1
      toolchain/parse/testdata/let/fail_empty.carbon
  15. 1 1
      toolchain/parse/testdata/struct/fail_comma_only.carbon
  16. 1 1
      toolchain/parse/testdata/struct/fail_comma_repeat_in_type.carbon
  17. 1 1
      toolchain/parse/testdata/struct/fail_comma_repeat_in_value.carbon
  18. 1 1
      toolchain/parse/testdata/struct/fail_dot_only.carbon
  19. 1 1
      toolchain/parse/testdata/struct/fail_dot_string_colon.carbon
  20. 1 1
      toolchain/parse/testdata/struct/fail_dot_string_equals.carbon
  21. 1 1
      toolchain/parse/testdata/struct/fail_identifier_colon.carbon
  22. 1 1
      toolchain/parse/testdata/struct/fail_identifier_equals.carbon
  23. 1 1
      toolchain/parse/testdata/struct/fail_identifier_only.carbon
  24. 1 1
      toolchain/parse/testdata/struct/fail_missing_type.carbon
  25. 1 1
      toolchain/parse/testdata/struct/fail_missing_value.carbon
  26. 1 1
      toolchain/parse/testdata/struct/fail_mix_type_and_value.carbon
  27. 1 1
      toolchain/parse/testdata/struct/fail_mix_value_and_type.carbon
  28. 2 2
      toolchain/parse/testdata/struct/fail_mix_with_unknown.carbon
  29. 1 1
      toolchain/parse/testdata/struct/fail_no_colon_or_equals.carbon
  30. 1 1
      toolchain/parse/testdata/struct/fail_type_no_designator.carbon
  31. 23 0
      toolchain/parse/tree.cpp
  32. 123 18
      toolchain/parse/tree.h
  33. 909 0
      toolchain/parse/typed_nodes.h
  34. 152 0
      toolchain/parse/typed_nodes_test.cpp
  35. 5 5
      toolchain/sem_ir/inst.h
  36. 1 1
      toolchain/sem_ir/inst_kind.cpp
  37. 6 6
      toolchain/sem_ir/typed_insts.h
  38. 6 5
      toolchain/sem_ir/typed_insts_test.cpp

+ 14 - 2
common/struct_reflection.h

@@ -21,7 +21,7 @@
 // - Only simple aggregate structs are supported. Types with base classes,
 // - Only simple aggregate structs are supported. Types with base classes,
 //   non-public data members, constructors, or virtual functions are not
 //   non-public data members, constructors, or virtual functions are not
 //   supported.
 //   supported.
-// - Structs with more than 5 fields are not supported. This limit is easy to
+// - Structs with more than 6 fields are not supported. This limit is easy to
 //   increase if needed, but removing it entirely is hard.
 //   increase if needed, but removing it entirely is hard.
 // - Structs containing a reference to the same type are not supported.
 // - Structs containing a reference to the same type are not supported.
 
 
@@ -75,7 +75,8 @@ constexpr auto CountFields() -> int {
   if constexpr (CanListInitialize<T, Fields...>(0)) {
   if constexpr (CanListInitialize<T, Fields...>(0)) {
     return CountFields<T, true, Fields..., AnyField<T>>();
     return CountFields<T, true, Fields..., AnyField<T>>();
   } else if constexpr (AnyWorkedSoFar) {
   } else if constexpr (AnyWorkedSoFar) {
-    static_assert(sizeof...(Fields) <= 5,
+    // Note: Compare against the maximum number of fields supported *PLUS 1*.
+    static_assert(sizeof...(Fields) <= 7,
                   "Unsupported: too many fields in struct");
                   "Unsupported: too many fields in struct");
     return sizeof...(Fields) - 1;
     return sizeof...(Fields) - 1;
   } else if constexpr (sizeof...(Fields) > 32) {
   } else if constexpr (sizeof...(Fields) > 32) {
@@ -150,6 +151,17 @@ struct FieldAccessor<5> {
   }
   }
 };
 };
 
 
+template <>
+struct FieldAccessor<6> {
+  template <typename T>
+  static auto Get(T& value) -> auto {
+    auto& [field0, field1, field2, field3, field4, field5] = value;
+    return std::tuple<decltype(field0), decltype(field1), decltype(field2),
+                      decltype(field3), decltype(field4), decltype(field5)>(
+        field0, field1, field2, field3, field4, field5);
+  }
+};
+
 }  // namespace Internal
 }  // namespace Internal
 
 
 // Get the fields of the struct `T` as a tuple.
 // Get the fields of the struct `T` as a tuple.

+ 22 - 1
common/struct_reflection_test.cpp

@@ -20,6 +20,15 @@ struct TwoFields {
   int y;
   int y;
 };
 };
 
 
+struct SixFields {
+  int one;
+  int two;
+  int three;
+  int four;
+  int five;
+  int six;
+};
+
 struct ReferenceField {
 struct ReferenceField {
   int& ref;
   int& ref;
 };
 };
@@ -60,6 +69,7 @@ TEST(StructReflectionTest, CountFields) {
   static_assert(Internal::CountFields<ZeroFields>() == 0);
   static_assert(Internal::CountFields<ZeroFields>() == 0);
   static_assert(Internal::CountFields<OneField>() == 1);
   static_assert(Internal::CountFields<OneField>() == 1);
   static_assert(Internal::CountFields<TwoFields>() == 2);
   static_assert(Internal::CountFields<TwoFields>() == 2);
+  static_assert(Internal::CountFields<SixFields>() == 6);
   static_assert(Internal::CountFields<ReferenceField>() == 1);
   static_assert(Internal::CountFields<ReferenceField>() == 1);
   static_assert(Internal::CountFields<OneFieldNoDefaultConstructor>() == 1);
   static_assert(Internal::CountFields<OneFieldNoDefaultConstructor>() == 1);
 }
 }
@@ -74,12 +84,23 @@ TEST(StructReflectionTest, OneField) {
   EXPECT_EQ(std::get<0>(fields), 1);
   EXPECT_EQ(std::get<0>(fields), 1);
 }
 }
 
 
-TEST(StructReflectionTest, TwoField) {
+TEST(StructReflectionTest, TwoFields) {
   std::tuple<int, int> fields = AsTuple(TwoFields{.x = 1, .y = 2});
   std::tuple<int, int> fields = AsTuple(TwoFields{.x = 1, .y = 2});
   EXPECT_EQ(std::get<0>(fields), 1);
   EXPECT_EQ(std::get<0>(fields), 1);
   EXPECT_EQ(std::get<1>(fields), 2);
   EXPECT_EQ(std::get<1>(fields), 2);
 }
 }
 
 
+TEST(StructReflectionTest, SixFields) {
+  std::tuple<int, int, int, int, int, int> fields = AsTuple(SixFields{
+      .one = 1, .two = 2, .three = 3, .four = 4, .five = 5, .six = 6});
+  EXPECT_EQ(std::get<0>(fields), 1);
+  EXPECT_EQ(std::get<1>(fields), 2);
+  EXPECT_EQ(std::get<2>(fields), 3);
+  EXPECT_EQ(std::get<3>(fields), 4);
+  EXPECT_EQ(std::get<4>(fields), 5);
+  EXPECT_EQ(std::get<5>(fields), 6);
+}
+
 TEST(StructReflectionTest, NoDefaultConstructor) {
 TEST(StructReflectionTest, NoDefaultConstructor) {
   std::tuple<NoDefaultConstructor, NoDefaultConstructor> fields =
   std::tuple<NoDefaultConstructor, NoDefaultConstructor> fields =
       AsTuple(TwoFieldsNoDefaultConstructor{.x = NoDefaultConstructor(1),
       AsTuple(TwoFieldsNoDefaultConstructor{.x = NoDefaultConstructor(1),

+ 1 - 0
toolchain/check/check.cpp

@@ -155,6 +155,7 @@ static auto ProcessParseNodes(Context& context,
     // clang warns on unhandled enum values; clang-tidy is incorrect here.
     // clang warns on unhandled enum values; clang-tidy is incorrect here.
     // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
     // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
     switch (auto parse_kind = context.parse_tree().node_kind(parse_node)) {
     switch (auto parse_kind = context.parse_tree().node_kind(parse_node)) {
+      // TODO: Switch to `Parse::Name##Id(parse_node)` here.
 #define CARBON_PARSE_NODE_KIND(Name)                                         \
 #define CARBON_PARSE_NODE_KIND(Name)                                         \
   case Parse::NodeKind::Name: {                                              \
   case Parse::NodeKind::Name: {                                              \
     if (!Check::Handle##Name(context, parse_node)) {                         \
     if (!Check::Handle##Name(context, parse_node)) {                         \

+ 1 - 1
toolchain/check/decl_state.h

@@ -43,7 +43,7 @@ enum class KeywordModifierSet : uint32_t {
 };
 };
 
 
 inline auto operator!(KeywordModifierSet k) -> bool {
 inline auto operator!(KeywordModifierSet k) -> bool {
-  return !static_cast<unsigned>(k);
+  return !static_cast<uint32_t>(k);
 }
 }
 
 
 // State stored for each declaration we are currently in: the kind of
 // State stored for each declaration we are currently in: the kind of

+ 5 - 0
toolchain/check/node_stack.h

@@ -92,6 +92,7 @@ class NodeStack {
   }
   }
 
 
   // Pops the top of the stack and returns the parse_node.
   // Pops the top of the stack and returns the parse_node.
+  // TODO: return a parse::NodeIdForKind<RequiredParseKind> instead.
   template <const Parse::NodeKind& RequiredParseKind>
   template <const Parse::NodeKind& RequiredParseKind>
   auto PopForSoloParseNode() -> Parse::NodeId {
   auto PopForSoloParseNode() -> Parse::NodeId {
     Entry back = PopEntry<SemIR::InstId>();
     Entry back = PopEntry<SemIR::InstId>();
@@ -102,6 +103,7 @@ class NodeStack {
 
 
   // Pops the top of the stack if it is the given kind, and returns the
   // Pops the top of the stack if it is the given kind, and returns the
   // parse_node. Otherwise, returns std::nullopt.
   // parse_node. Otherwise, returns std::nullopt.
+  // TODO: Return a `Parse::NodeIdForKind<RequiredParseKind>` instead.
   template <const Parse::NodeKind& RequiredParseKind>
   template <const Parse::NodeKind& RequiredParseKind>
   auto PopForSoloParseNodeIf() -> std::optional<Parse::NodeId> {
   auto PopForSoloParseNodeIf() -> std::optional<Parse::NodeId> {
     if (PeekIs<RequiredParseKind>()) {
     if (PeekIs<RequiredParseKind>()) {
@@ -200,6 +202,9 @@ class NodeStack {
   // Pops a name from the top of the stack and returns the ID.
   // Pops a name from the top of the stack and returns the ID.
   auto PopName() -> SemIR::NameId { return PopNameWithParseNode().second; }
   auto PopName() -> SemIR::NameId { return PopNameWithParseNode().second; }
 
 
+  // TODO: Can we add a `Pop<...>` that takes a parse node category? See
+  // https://github.com/carbon-language/carbon-lang/pull/3534/files#r1432067519
+
   // Pops the top of the stack and returns the ID.
   // Pops the top of the stack and returns the ID.
   template <const Parse::NodeKind& RequiredParseKind>
   template <const Parse::NodeKind& RequiredParseKind>
   auto Pop() -> auto {
   auto Pop() -> auto {

+ 25 - 1
toolchain/parse/BUILD

@@ -16,12 +16,34 @@ filegroup(
 cc_library(
 cc_library(
     name = "node_kind",
     name = "node_kind",
     srcs = ["node_kind.cpp"],
     srcs = ["node_kind.cpp"],
-    hdrs = ["node_kind.h"],
+    hdrs = [
+        "node_ids.h",
+        "node_kind.h",
+        "typed_nodes.h",
+    ],
     textual_hdrs = ["node_kind.def"],
     textual_hdrs = ["node_kind.def"],
     deps = [
     deps = [
         "//common:check",
         "//common:check",
         "//common:enum_base",
         "//common:enum_base",
+        "//toolchain/base:index_base",
         "//toolchain/lex:token_kind",
         "//toolchain/lex:token_kind",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_test(
+    name = "typed_nodes_test",
+    size = "small",
+    srcs = ["typed_nodes_test.cpp"],
+    deps = [
+        ":node_kind",
+        ":tree",
+        "//testing/base:gtest_main",
+        "//toolchain/diagnostics:diagnostic_emitter",
+        "//toolchain/diagnostics:mocks",
+        "//toolchain/lex",
+        "//toolchain/lex:tokenized_buffer",
+        "@com_google_googletest//:gtest",
     ],
     ],
 )
 )
 
 
@@ -38,6 +60,7 @@ cc_library(
     srcs = [
     srcs = [
         "context.cpp",
         "context.cpp",
         "context.h",
         "context.h",
+        "extract.cpp",
         "tree.cpp",
         "tree.cpp",
     ] +
     ] +
     # Glob handler files to avoid missing any.
     # Glob handler files to avoid missing any.
@@ -52,6 +75,7 @@ cc_library(
         "//common:check",
         "//common:check",
         "//common:error",
         "//common:error",
         "//common:ostream",
         "//common:ostream",
+        "//common:struct_reflection",
         "//common:vlog",
         "//common:vlog",
         "//toolchain/base:pretty_stack_trace_function",
         "//toolchain/base:pretty_stack_trace_function",
         "//toolchain/base:value_store",
         "//toolchain/base:value_store",

+ 342 - 0
toolchain/parse/extract.cpp

@@ -0,0 +1,342 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <tuple>
+#include <utility>
+
+#include "common/error.h"
+#include "common/struct_reflection.h"
+#include "toolchain/parse/tree.h"
+#include "toolchain/parse/typed_nodes.h"
+
+namespace Carbon::Parse {
+
+// A trait type that should be specialized by types that can be extracted
+// from a parse tree. A specialization should provide the following API:
+//
+// ```cpp
+// template<>
+// struct Extractable<T> {
+//   // Extract a value of this type from the sequence of nodes starting at
+//   // `it`, and increment `it` past this type. Returns `std::nullopt` if
+//   // the tree is malformed. If `trace != nullptr`, writes what actions
+//   // were taken to `*trace`.
+//   static auto Extract(Tree* tree, Tree::SiblingIterator& it,
+//                       Tree::SiblingIterator end,
+//                       ErrorBuilder* trace) -> std::optional<T>;
+// };
+// ```
+//
+// Note that `Tree::SiblingIterator`s iterate in reverse order through the
+// children of a node.
+//
+// This class is only in this file.
+template <typename T>
+struct Extractable;
+
+// Extract a `NodeId` as a single child.
+template <>
+struct Extractable<NodeId> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<NodeId> {
+    if (it == end) {
+      if (trace) {
+        *trace << "NodeId error: no more children\n";
+      }
+      return std::nullopt;
+    }
+    if (trace) {
+      *trace << "NodeId: " << tree->node_kind(*it) << " consumed\n";
+    }
+    return NodeId(*it++);
+  }
+};
+
+// Extract a `FooId`, which is the same as `NodeIdForKind<NodeKind::Foo>`,
+// as a single required child.
+template <const NodeKind& Kind>
+struct Extractable<NodeIdForKind<Kind>> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<NodeIdForKind<Kind>> {
+    if (it == end || tree->node_kind(*it) != Kind) {
+      if (trace) {
+        if (it == end) {
+          *trace << "NodeIdForKind error: no more children, expected " << Kind
+                 << "\n";
+        } else {
+          *trace << "NodeIdForKind error: wrong kind " << tree->node_kind(*it)
+                 << ", expected " << Kind << "\n";
+        }
+      }
+      return std::nullopt;
+    }
+    if (trace) {
+      *trace << "NodeIdForKind: " << Kind << " consumed\n";
+    }
+    return NodeIdForKind<Kind>(*it++);
+  }
+};
+
+// Extract a `NodeIdInCategory<Category>` as a single child.
+template <NodeCategory Category>
+struct Extractable<NodeIdInCategory<Category>> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<NodeIdInCategory<Category>> {
+    if (trace) {
+      *trace << "NodeIdInCategory";
+      // TODO: Make NodeCategory printable instead.
+      if (!Category) {
+        *trace << " <none>";
+      }
+#define CARBON_NODE_CATEGORY(Name)         \
+  if (!!(Category & NodeCategory::Name)) { \
+    *trace << " " #Name;                   \
+  }
+      CARBON_NODE_CATEGORY(Decl);
+      CARBON_NODE_CATEGORY(Expr);
+      CARBON_NODE_CATEGORY(Modifier);
+      CARBON_NODE_CATEGORY(NameComponent);
+      CARBON_NODE_CATEGORY(Pattern);
+      CARBON_NODE_CATEGORY(Statement);
+#undef CARBON_NODE_CATEGORY
+    }
+
+    if (it == end || !(tree->node_kind(*it).category() & Category)) {
+      if (trace) {
+        if (it == end) {
+          *trace << " error: no more children\n";
+        } else {
+          *trace << " error: kind " << tree->node_kind(*it)
+                 << " doesn't match\n";
+        }
+      }
+      return std::nullopt;
+    }
+    if (trace) {
+      *trace << ": kind " << tree->node_kind(*it) << " consumed\n";
+    }
+    return NodeIdInCategory<Category>(*it++);
+  }
+};
+
+// Extract a `NodeIdOneOf<T, U>` as a single required child.
+template <typename T, typename U>
+struct Extractable<NodeIdOneOf<T, U>> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<NodeIdOneOf<T, U>> {
+    auto kind = tree->node_kind(*it);
+    if (it == end || (kind != T::Kind && kind != U::Kind)) {
+      if (trace) {
+        if (it == end) {
+          *trace << "NodeIdOneOf error: no more children, expected " << T::Kind
+                 << " or " << U::Kind << "\n";
+        } else {
+          *trace << "NodeIdOneOf error: wrong kind " << tree->node_kind(*it)
+                 << ", expected " << T::Kind << " or " << U::Kind << "\n";
+        }
+      }
+      return std::nullopt;
+    }
+    if (trace) {
+      *trace << "NodeIdOneOf " << T::Kind << " or " << U::Kind << ": "
+             << tree->node_kind(*it) << " consumed";
+    }
+    return NodeIdOneOf<T, U>(*it++);
+  }
+};
+
+// Extract a `NodeIdNot<T>` as a single required child.
+template <typename T>
+struct Extractable<NodeIdNot<T>> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<NodeIdNot<T>> {
+    if (it == end || tree->node_kind(*it) == T::Kind) {
+      if (trace) {
+        if (it == end) {
+          *trace << "NodeIdNot " << T::Kind << " error: no more children\n";
+        } else {
+          *trace << "NodeIdNot error: unexpected " << T::Kind << "\n";
+        }
+      }
+      return std::nullopt;
+    }
+    if (trace) {
+      *trace << "NodeIdNot " << T::Kind << ": " << tree->node_kind(*it)
+             << " consumed\n";
+    }
+    return NodeIdNot<T>(*it++);
+  }
+};
+
+// Extract an `llvm::SmallVector<T>` by extracting `T`s until we can't.
+template <typename T>
+struct Extractable<llvm::SmallVector<T>> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<llvm::SmallVector<T>> {
+    if (trace) {
+      *trace << "Vector: begin\n";
+    }
+    llvm::SmallVector<T> result;
+    while (it != end) {
+      auto old_it = it;
+      auto item = Extractable<T>::Extract(tree, it, end, trace);
+      if (!item.has_value()) {
+        it = old_it;
+        break;
+      }
+      result.push_back(*item);
+    }
+    std::reverse(result.begin(), result.end());
+    if (trace) {
+      *trace << "Vector: end\n";
+    }
+    return result;
+  }
+};
+
+// Extract an `optional<T>` from a list of child nodes by attempting to extract
+// a `T`, and extracting nothing if that fails.
+template <typename T>
+struct Extractable<std::optional<T>> {
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<std::optional<T>> {
+    if (trace) {
+      *trace << "Optional" << typeid(T).name() << ": begin\n";
+    }
+    auto old_it = it;
+    std::optional<T> value = Extractable<T>::Extract(tree, it, end, trace);
+    if (value) {
+      if (trace) {
+        *trace << "Optional" << typeid(T).name() << ": found\n";
+      }
+      return value;
+    }
+    if (trace) {
+      *trace << "Optional" << typeid(T).name() << ": missing\n";
+    }
+    it = old_it;
+    return value;
+  }
+};
+
+// Extract a `tuple<T...>` from a list of child nodes by extracting each `T` in
+// reverse order.
+template <typename... T>
+struct Extractable<std::tuple<T...>> {
+  template <std::size_t... Index>
+  static auto ExtractImpl(const Tree* tree, Tree::SiblingIterator& it,
+                          Tree::SiblingIterator end, ErrorBuilder* trace,
+                          std::index_sequence<Index...>)
+      -> std::optional<std::tuple<T...>> {
+    std::tuple<std::optional<T>...> fields;
+    if (trace) {
+      *trace << "Tuple: begin\n";
+    }
+
+    // Use a fold over the `=` operator to parse fields from right to left.
+    [[maybe_unused]] int unused;
+    bool ok = true;
+    static_cast<void>(
+        ((ok && (ok = (std::get<Index>(fields) =
+                           Extractable<T>::Extract(tree, it, end, trace))
+                          .has_value()),
+          unused) = ... = 0));
+
+    if (!ok) {
+      if (trace) {
+        *trace << "Tuple: error\n";
+      }
+      return std::nullopt;
+    }
+
+    if (trace) {
+      *trace << "Tuple: success\n";
+    }
+    return std::tuple<T...>{std::move(std::get<Index>(fields).value())...};
+  }
+
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<std::tuple<T...>> {
+    return ExtractImpl(tree, it, end, trace,
+                       std::make_index_sequence<sizeof...(T)>());
+  }
+};
+
+// Extract the fields of a simple aggregate type.
+template <typename T>
+struct Extractable {
+  static_assert(std::is_aggregate_v<T>, "Unsupported child type");
+  static auto ExtractImpl(const Tree* tree, Tree::SiblingIterator& it,
+                          Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<T> {
+    if (trace) {
+      *trace << "Aggregate " << typeid(T).name() << ": begin\n";
+    }
+    // Extract the corresponding tuple type.
+    using TupleType = decltype(StructReflection::AsTuple(std::declval<T>()));
+    auto tuple = Extractable<TupleType>::Extract(tree, it, end, trace);
+    if (!tuple.has_value()) {
+      if (trace) {
+        *trace << "Aggregate" << typeid(T).name() << ": error\n";
+      }
+      return std::nullopt;
+    }
+
+    if (trace) {
+      *trace << "Aggregate" << typeid(T).name() << ": success\n";
+    }
+    // Convert the tuple to the struct type.
+    return std::apply(
+        [](auto&&... value) {
+          return T{std::forward<decltype(value)>(value)...};
+        },
+        *tuple);
+  }
+
+  static auto Extract(const Tree* tree, Tree::SiblingIterator& it,
+                      Tree::SiblingIterator end, ErrorBuilder* trace)
+      -> std::optional<T> {
+    static_assert(!HasKindMember<T>, "Missing Id suffix");
+    return ExtractImpl(tree, it, end, trace);
+  }
+};
+
+template <typename T>
+auto Tree::TryExtractNodeFromChildren(
+    llvm::iterator_range<Tree::SiblingIterator> children,
+    ErrorBuilder* trace) const -> std::optional<T> {
+  auto it = children.begin();
+  auto result = Extractable<T>::ExtractImpl(this, it, children.end(), trace);
+  if (it != children.end()) {
+    if (trace) {
+      *trace << "Error: " << node_kind(*it) << " node left unconsumed.";
+    }
+    return std::nullopt;
+  }
+  return result;
+}
+
+// Manually instantiate Tree::TryExtractNodeFromChildren
+#define CARBON_PARSE_NODE_KIND(KindName)                    \
+  template auto Tree::TryExtractNodeFromChildren<KindName>( \
+      llvm::iterator_range<Tree::SiblingIterator> children, \
+      ErrorBuilder * trace) const -> std::optional<KindName>;
+
+// Also instantiate for `File`, even though it isn't a parse node.
+CARBON_PARSE_NODE_KIND(File)
+#include "toolchain/parse/node_kind.def"
+
+auto Tree::ExtractFile() const -> File {
+  return ExtractNodeFromChildren<File>(roots());
+}
+
+}  // namespace Carbon::Parse

+ 1 - 0
toolchain/parse/handle_brace_expr.cpp

@@ -149,6 +149,7 @@ static auto HandleBraceExprParamFinish(Context& context, NodeKind node_kind,
   if (state.has_error) {
   if (state.has_error) {
     context.AddLeafNode(NodeKind::InvalidParse, state.token,
     context.AddLeafNode(NodeKind::InvalidParse, state.token,
                         /*has_error=*/true);
                         /*has_error=*/true);
+    context.ReturnErrorOnState();
   } else {
   } else {
     context.AddNode(node_kind, state.token, state.subtree_start,
     context.AddNode(node_kind, state.token, state.subtree_start,
                     /*has_error=*/false);
                     /*has_error=*/false);

+ 7 - 5
toolchain/parse/handle_let.cpp

@@ -30,11 +30,13 @@ auto HandleLetAfterPattern(Context& context) -> void {
   if (auto equals = context.ConsumeIf(Lex::TokenKind::Equal)) {
   if (auto equals = context.ConsumeIf(Lex::TokenKind::Equal)) {
     context.AddLeafNode(NodeKind::LetInitializer, *equals);
     context.AddLeafNode(NodeKind::LetInitializer, *equals);
     context.PushState(State::Expr);
     context.PushState(State::Expr);
-  } else if (!state.has_error) {
-    CARBON_DIAGNOSTIC(
-        ExpectedInitializerAfterLet, Error,
-        "Expected `=`; `let` declaration must have an initializer.");
-    context.emitter().Emit(*context.position(), ExpectedInitializerAfterLet);
+  } else {
+    if (!state.has_error) {
+      CARBON_DIAGNOSTIC(
+          ExpectedInitializerAfterLet, Error,
+          "Expected `=`; `let` declaration must have an initializer.");
+      context.emitter().Emit(*context.position(), ExpectedInitializerAfterLet);
+    }
     context.ReturnErrorOnState();
     context.ReturnErrorOnState();
   }
   }
 }
 }

+ 99 - 0
toolchain/parse/node_ids.h

@@ -0,0 +1,99 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_PARSE_NODE_IDS_H_
+#define CARBON_TOOLCHAIN_PARSE_NODE_IDS_H_
+
+#include "toolchain/base/index_base.h"
+#include "toolchain/parse/node_kind.h"
+
+namespace Carbon::Parse {
+
+// A lightweight handle representing a node in the tree.
+//
+// Objects of this type are small and cheap to copy and store. They don't
+// contain any of the information about the node, and serve as a handle that
+// can be used with the underlying tree to query for detailed information.
+struct NodeId : public IdBase {
+  // An explicitly invalid instance.
+  static const NodeId Invalid;
+
+  using IdBase::IdBase;
+};
+
+constexpr NodeId NodeId::Invalid = NodeId(NodeId::InvalidIndex);
+
+// For looking up the type associated with a given id type.
+template <typename T>
+struct NodeForId;
+
+// `<KindName>Id` is a typed version of `NodeId` that references a node of kind
+// `<KindName>`:
+template <const NodeKind&>
+struct NodeIdForKind : public NodeId {
+  static const NodeIdForKind Invalid;
+
+  explicit NodeIdForKind(NodeId node_id) : NodeId(node_id) {}
+};
+template <const NodeKind& Kind>
+constexpr NodeIdForKind<Kind> NodeIdForKind<Kind>::Invalid =
+    NodeIdForKind(NodeId::Invalid.index);
+
+#define CARBON_PARSE_NODE_KIND(KindName) \
+  using KindName##Id = NodeIdForKind<NodeKind::KindName>;
+#include "toolchain/parse/node_kind.def"
+
+// NodeId that matches any NodeKind whose `category()` overlaps with `Category`.
+template <NodeCategory Category>
+struct NodeIdInCategory : public NodeId {
+  // An explicitly invalid instance.
+  static const NodeIdInCategory<Category> Invalid;
+
+  explicit NodeIdInCategory(NodeId node_id) : NodeId(node_id) {}
+};
+
+template <NodeCategory Category>
+constexpr NodeIdInCategory<Category> NodeIdInCategory<Category>::Invalid =
+    NodeIdInCategory<Category>(NodeId::InvalidIndex);
+
+// Aliases for `NodeIdInCategory` to describe particular categories of nodes.
+using AnyDeclId = NodeIdInCategory<NodeCategory::Decl>;
+using AnyExprId = NodeIdInCategory<NodeCategory::Expr>;
+using AnyModifierId = NodeIdInCategory<NodeCategory::Modifier>;
+using AnyNameComponentId = NodeIdInCategory<NodeCategory::NameComponent>;
+using AnyPatternId = NodeIdInCategory<NodeCategory::Pattern>;
+using AnyStatementId = NodeIdInCategory<NodeCategory::Statement>;
+
+// NodeId with kind that matches either T::Kind or U::Kind.
+template <typename T, typename U>
+struct NodeIdOneOf : public NodeId {
+  // An explicitly invalid instance.
+  static const NodeIdOneOf<T, U> Invalid;
+
+  explicit NodeIdOneOf(NodeId node_id) : NodeId(node_id) {}
+};
+
+template <typename T, typename U>
+constexpr NodeIdOneOf<T, U> NodeIdOneOf<T, U>::Invalid =
+    NodeIdOneOf<T, U>(NodeId::InvalidIndex);
+
+// NodeId with kind that is anything but T::Kind.
+template <typename T>
+struct NodeIdNot : public NodeId {
+  // An explicitly invalid instance.
+  static const NodeIdNot<T> Invalid;
+
+  explicit NodeIdNot(NodeId node_id) : NodeId(node_id) {}
+};
+
+template <typename T>
+constexpr NodeIdNot<T> NodeIdNot<T>::Invalid =
+    NodeIdNot<T>(NodeId::InvalidIndex);
+
+// Note that the support for extracting these types using the `Tree::Extract*`
+// functions is defined in `extract.cpp`.
+
+}  // namespace Carbon::Parse
+
+#endif  // CARBON_TOOLCHAIN_PARSE_NODE_IDS_H_

+ 13 - 0
toolchain/parse/node_kind.cpp

@@ -5,6 +5,7 @@
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/node_kind.h"
 
 
 #include "common/check.h"
 #include "common/check.h"
+#include "toolchain/parse/typed_nodes.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
 
 
@@ -76,4 +77,16 @@ auto NodeKind::CheckMatchesTokenKind(Lex::TokenKind token_kind, bool has_error)
       << ", but expected token kind " << expected_token_kind;
       << ", but expected token kind " << expected_token_kind;
 }
 }
 
 
+auto NodeKind::category() const -> NodeCategory {
+  return definition().category();
+}
+
+auto NodeKind::definition() const -> const Definition& {
+  static constexpr const Definition* Table[] = {
+#define CARBON_PARSE_NODE_KIND(Name) &Parse::Name::Kind,
+#include "toolchain/parse/node_kind.def"
+  };
+  return *Table[AsInt()];
+}
+
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 21 - 13
toolchain/parse/node_kind.def

@@ -203,6 +203,7 @@ CARBON_PARSE_NODE_KIND_BRACKET(ImportDirective, ImportIntroducer,
 // `library` as directive:
 // `library` as directive:
 //   LibraryIntroducer
 //   LibraryIntroducer
 //   DefaultLibrary or _external_: LibraryName
 //   DefaultLibrary or _external_: LibraryName
+//   PackageApi or PackageImpl
 // LibraryDirective
 // LibraryDirective
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(DefaultLibrary, 0, Default)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(DefaultLibrary, 0, Default)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(LibraryIntroducer, 0, Library)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(LibraryIntroducer, 0, Library)
@@ -210,14 +211,14 @@ CARBON_PARSE_NODE_KIND_BRACKET(LibraryDirective, LibraryIntroducer,
                                CARBON_IF_VALID(Semi))
                                CARBON_IF_VALID(Semi))
 
 
 // `library` in `package` or `import`:
 // `library` in `package` or `import`:
-//   _external_: LibraryName
+//   _external_: LibraryName or DefaultLibrary
 // LibrarySpecifier
 // LibrarySpecifier
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(LibrarySpecifier, 1, Library)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(LibrarySpecifier, 1, Library)
 
 
 // `namespace`:
 // `namespace`:
 //   NamespaceStart
 //   NamespaceStart
 //   _repeated_ _external_: modifier
 //   _repeated_ _external_: modifier
-//   _external_: Name or QualifiedDecl
+//   _external_: IdentifierName or QualifiedDecl
 // Namespace
 // Namespace
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(NamespaceStart, 0, Namespace)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(NamespaceStart, 0, Namespace)
 CARBON_PARSE_NODE_KIND_BRACKET(Namespace, NamespaceStart, CARBON_IF_VALID(Semi))
 CARBON_PARSE_NODE_KIND_BRACKET(Namespace, NamespaceStart, CARBON_IF_VALID(Semi))
@@ -234,7 +235,8 @@ CARBON_PARSE_NODE_KIND_BRACKET(CodeBlock, CodeBlockStart,
 // `fn`:
 // `fn`:
 //     FunctionIntroducer
 //     FunctionIntroducer
 //     _repeated_ _external_: modifier
 //     _repeated_ _external_: modifier
-//     _external_: Name or QualifiedDecl
+//     _external_: IdentifierName or QualifiedDecl
+//     _optional_ _external_: ImplicitParamList
 //     _external_: TuplePattern
 //     _external_: TuplePattern
 //       _external_: type expression
 //       _external_: type expression
 //     ReturnType
 //     ReturnType
@@ -291,7 +293,7 @@ CARBON_PARSE_NODE_KIND_CHILD_COUNT(ArrayExprSemi, 2, CARBON_IF_VALID(Semi))
 CARBON_PARSE_NODE_KIND_BRACKET(ArrayExpr, ArrayExprSemi, CloseSquareBracket)
 CARBON_PARSE_NODE_KIND_BRACKET(ArrayExpr, ArrayExprSemi, CloseSquareBracket)
 
 
 // A binding pattern, such as `name: Type`:
 // A binding pattern, such as `name: Type`:
-//       Name or SelfValueName
+//       IdentifierName or SelfValueName
 //       _external_: type expression
 //       _external_: type expression
 //     [Generic]BindingPattern
 //     [Generic]BindingPattern
 //   _optional_ Address
 //   _optional_ Address
@@ -446,8 +448,8 @@ CARBON_PARSE_NODE_KIND_CHILD_COUNT(CallExprComma, 0, Comma)
 CARBON_PARSE_NODE_KIND_BRACKET(CallExpr, CallExprStart, CloseParen)
 CARBON_PARSE_NODE_KIND_BRACKET(CallExpr, CallExprStart, CloseParen)
 
 
 // A qualified declaration, such as `a.b`:
 // A qualified declaration, such as `a.b`:
-//   _external_: Name or QualifiedDecl
-//   _external_: Name
+//   _external_: IdentifierName or QualifiedDecl
+//   _external_: IdentifierName
 // QualifiedDecl
 // QualifiedDecl
 //
 //
 // TODO: This will eventually more general expressions, for example with
 // TODO: This will eventually more general expressions, for example with
@@ -458,14 +460,14 @@ CARBON_PARSE_NODE_KIND_CHILD_COUNT(QualifiedDecl, 2, Period)
 // `GetObject().(Interface.member)`:
 // `GetObject().(Interface.member)`:
 //   _external_: lhs expression
 //   _external_: lhs expression
 //   _external_: rhs expression
 //   _external_: rhs expression
-// QualifiedExpr
+// MemberAccessExpr
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(MemberAccessExpr, 2, Period)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(MemberAccessExpr, 2, Period)
 
 
 // A pointer member access expression, such as `a->b` or
 // A pointer member access expression, such as `a->b` or
 // `GetObject()->(Interface.member)`:
 // `GetObject()->(Interface.member)`:
 //   _external_: lhs expression
 //   _external_: lhs expression
 //   _external_: rhs expression
 //   _external_: rhs expression
-// QualifiedExpr
+// PointerMemberAccessExpr
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(PointerMemberAccessExpr, 2, MinusGreater)
 CARBON_PARSE_NODE_KIND_CHILD_COUNT(PointerMemberAccessExpr, 2, MinusGreater)
 
 
 // A value literal.
 // A value literal.
@@ -558,7 +560,7 @@ CARBON_PARSE_NODE_KIND_CHILD_COUNT(IfExprElse, 3, CARBON_IF_VALID(Else))
 
 
 // Struct literals, such as `{.a = 0}`:
 // Struct literals, such as `{.a = 0}`:
 //   StructLiteralOrStructTypeLiteralStart
 //   StructLiteralOrStructTypeLiteralStart
-//         _external_: Name
+//         _external_: IdentifierName or BaseName
 //       StructFieldDesignator
 //       StructFieldDesignator
 //       _external_: expression
 //       _external_: expression
 //     StructFieldValue
 //     StructFieldValue
@@ -568,7 +570,7 @@ CARBON_PARSE_NODE_KIND_CHILD_COUNT(IfExprElse, 3, CARBON_IF_VALID(Else))
 //
 //
 // Struct type literals, such as `{.a: i32}`:
 // Struct type literals, such as `{.a: i32}`:
 //   StructLiteralOrStructTypeLiteralStart
 //   StructLiteralOrStructTypeLiteralStart
-//         _external_: Name
+//         _external_: IdentifierName or BaseName
 //       StructFieldDesignator
 //       StructFieldDesignator
 //       _external_: type expression
 //       _external_: type expression
 //     StructFieldType
 //     StructFieldType
@@ -609,7 +611,9 @@ CARBON_PARSE_NODE_KIND_TOKEN_MODIFIER(Virtual)
 // `class`:
 // `class`:
 //     ClassIntroducer
 //     ClassIntroducer
 //     _repeated_ _external_: modifier
 //     _repeated_ _external_: modifier
-//     _external_: Name or QualifiedDecl
+//     _external_: IdentifierName or QualifiedDecl
+//     _optional_ _external_: ImplicitParamList
+//     _optional_ _external_: TuplePattern
 //   ClassDefinitionStart
 //   ClassDefinitionStart
 //   _external_: declarations
 //   _external_: declarations
 // ClassDefinition
 // ClassDefinition
@@ -638,7 +642,9 @@ CARBON_PARSE_NODE_KIND_BRACKET(BaseDecl, BaseIntroducer, CARBON_IF_VALID(Semi))
 // `interface`:
 // `interface`:
 //     InterfaceIntroducer
 //     InterfaceIntroducer
 //     _repeated_ _external_: modifier
 //     _repeated_ _external_: modifier
-//     _external_: Name or QualifiedDecl
+//     _external_: IdentifierName or QualifiedDecl
+//     _optional_ _external_: ImplicitParamList
+//     _optional_ _external_: TuplePattern
 //   InterfaceDefinitionStart
 //   InterfaceDefinitionStart
 //   _external_: declarations
 //   _external_: declarations
 // InterfaceDefinition
 // InterfaceDefinition
@@ -684,7 +690,9 @@ CARBON_PARSE_NODE_KIND_BRACKET(ImplDecl, ImplIntroducer, CARBON_IF_VALID(Semi))
 // `constraint`:
 // `constraint`:
 //     NamedConstraintIntroducer
 //     NamedConstraintIntroducer
 //     _repeated_ _external_: modifier
 //     _repeated_ _external_: modifier
-//     _external_: Name or QualifiedDecl
+//     _external_: IdentifierName or QualifiedDecl
+//     _optional_ _external_: ImplicitParamList
+//     _optional_ _external_: TuplePattern
 //   NamedConstraintDefinitionStart
 //   NamedConstraintDefinitionStart
 //   _external_: declarations
 //   _external_: declarations
 // NamedConstraintDefinition
 // NamedConstraintDefinition

+ 71 - 0
toolchain/parse/node_kind.h

@@ -8,10 +8,33 @@
 #include <cstdint>
 #include <cstdint>
 
 
 #include "common/enum_base.h"
 #include "common/enum_base.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "toolchain/lex/token_kind.h"
 #include "toolchain/lex/token_kind.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
 
 
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+
+// Represents a set of keyword modifiers, using a separate bit per modifier.
+//
+// We expect this to grow, so are using a bigger size than needed.
+// NOLINTNEXTLINE(performance-enum-size)
+enum class NodeCategory : uint32_t {
+  Decl = 1 << 0,
+  Expr = 1 << 1,
+  Modifier = 1 << 2,
+  NameComponent = 1 << 3,
+  Pattern = 1 << 4,
+  Statement = 1 << 5,
+  None = 0,
+
+  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/Statement)
+};
+
+inline auto operator!(NodeCategory k) -> bool {
+  return !static_cast<uint32_t>(k);
+}
+
 CARBON_DEFINE_RAW_ENUM_CLASS(NodeKind, uint8_t) {
 CARBON_DEFINE_RAW_ENUM_CLASS(NodeKind, uint8_t) {
 #define CARBON_PARSE_NODE_KIND(Name) CARBON_RAW_ENUM_ENUMERATOR(Name)
 #define CARBON_PARSE_NODE_KIND(Name) CARBON_RAW_ENUM_ENUMERATOR(Name)
 #include "toolchain/parse/node_kind.def"
 #include "toolchain/parse/node_kind.def"
@@ -40,7 +63,21 @@ class NodeKind : public CARBON_ENUM_BASE(NodeKind) {
   // that has_bracket is false.
   // that has_bracket is false.
   auto child_count() const -> int32_t;
   auto child_count() const -> int32_t;
 
 
+  // Returns which categories this node kind is in.
+  auto category() const -> NodeCategory;
+
   using EnumBase::Create;
   using EnumBase::Create;
+
+  class Definition;
+
+  // Provides a definition for this parse node kind. Should only be called
+  // once, to construct the kind as part of defining it in `typed_nodes.h`.
+  constexpr auto Define(NodeCategory category = NodeCategory::None) const
+      -> Definition;
+
+ private:
+  // Looks up the definition for this instruction kind.
+  auto definition() const -> const Definition&;
 };
 };
 
 
 #define CARBON_PARSE_NODE_KIND(Name) \
 #define CARBON_PARSE_NODE_KIND(Name) \
@@ -50,6 +87,40 @@ class NodeKind : public CARBON_ENUM_BASE(NodeKind) {
 // We expect the parse node kind to fit compactly into 8 bits.
 // We expect the parse node kind to fit compactly into 8 bits.
 static_assert(sizeof(NodeKind) == 1, "Kind objects include padding!");
 static_assert(sizeof(NodeKind) == 1, "Kind objects include padding!");
 
 
+// A definition of a parse node kind. This is a NodeKind value, plus
+// ancillary data such as the name to use for the node kind in LLVM IR. These
+// are not copyable, and only one instance of this type is expected to exist per
+// parse node kind, specifically `TypedNode::Kind`. Use `NodeKind` instead as a
+// thin wrapper around a parse node kind index.
+class NodeKind::Definition : public NodeKind {
+ public:
+  // Not copyable.
+  Definition(const Definition&) = delete;
+  auto operator=(const Definition&) -> Definition& = delete;
+
+  // Returns which categories this node kind is in.
+  constexpr auto category() const -> NodeCategory { return category_; }
+
+ private:
+  friend class NodeKind;
+
+  constexpr Definition(NodeKind kind, NodeCategory category)
+      : NodeKind(kind), category_(category) {}
+
+  NodeCategory category_;
+};
+
+constexpr auto NodeKind::Define(NodeCategory category) const -> Definition {
+  return Definition(*this, category);
+}
+
+// HasKindMember<T> is true if T has a `static const NodeKind::Definition Kind`
+// member.
+template <typename T, typename KindType = const NodeKind::Definition*>
+inline constexpr bool HasKindMember = false;
+template <typename T>
+inline constexpr bool HasKindMember<T, decltype(&T::Kind)> = true;
+
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse
 
 
 #endif  // CARBON_TOOLCHAIN_PARSE_NODE_KIND_H_
 #endif  // CARBON_TOOLCHAIN_PARSE_NODE_KIND_H_

+ 1 - 1
toolchain/parse/testdata/let/fail_empty.carbon

@@ -16,6 +16,6 @@ let;
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: ';', has_error: yes},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: ';', has_error: yes},
 // CHECK:STDOUT:         {kind: 'InvalidParse', text: ';', has_error: yes},
 // CHECK:STDOUT:         {kind: 'InvalidParse', text: ';', has_error: yes},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ';', has_error: yes, subtree_size: 3},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ';', has_error: yes, subtree_size: 3},
-// CHECK:STDOUT:     {kind: 'LetDecl', text: ';', subtree_size: 5},
+// CHECK:STDOUT:     {kind: 'LetDecl', text: ';', has_error: yes, subtree_size: 5},
 // CHECK:STDOUT:     {kind: 'FileEnd', text: ''},
 // CHECK:STDOUT:     {kind: 'FileEnd', text: ''},
 // CHECK:STDOUT:   ]
 // CHECK:STDOUT:   ]

+ 1 - 1
toolchain/parse/testdata/struct/fail_comma_only.carbon

@@ -17,7 +17,7 @@ var x: {,} = {};
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ',', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ',', has_error: yes},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 4},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 4},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 6},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 6},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_comma_repeat_in_type.carbon

@@ -22,7 +22,7 @@ var x: {.a: i32,,} = {};
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ',', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ',', has_error: yes},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
-// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', subtree_size: 9},
+// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', has_error: yes, subtree_size: 9},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 11},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 11},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_comma_repeat_in_value.carbon

@@ -22,7 +22,7 @@ var x: {.a = 0,,} = {};
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ',', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ',', has_error: yes},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 9},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 9},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 11},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 11},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_dot_only.carbon

@@ -18,7 +18,7 @@ var x: {.} = {};
 // CHECK:STDOUT:             {kind: 'IdentifierName', text: '}', has_error: yes},
 // CHECK:STDOUT:             {kind: 'IdentifierName', text: '}', has_error: yes},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '.', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '.', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 5},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 7},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 7},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_dot_string_colon.carbon

@@ -24,7 +24,7 @@ var x: {."hello": i32, .y: i32} = {};
 // CHECK:STDOUT:             {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:             {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:             {kind: 'IntTypeLiteral', text: 'i32'},
 // CHECK:STDOUT:             {kind: 'IntTypeLiteral', text: 'i32'},
 // CHECK:STDOUT:           {kind: 'StructFieldType', text: ':', subtree_size: 4},
 // CHECK:STDOUT:           {kind: 'StructFieldType', text: ':', subtree_size: 4},
-// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', subtree_size: 11},
+// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', has_error: yes, subtree_size: 11},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 13},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 13},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_dot_string_equals.carbon

@@ -24,7 +24,7 @@ var x: {."hello" = 0, .y = 4} = {};
 // CHECK:STDOUT:             {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:             {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:             {kind: 'IntLiteral', text: '4'},
 // CHECK:STDOUT:             {kind: 'IntLiteral', text: '4'},
 // CHECK:STDOUT:           {kind: 'StructFieldValue', text: '=', subtree_size: 4},
 // CHECK:STDOUT:           {kind: 'StructFieldValue', text: '=', subtree_size: 4},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 11},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 11},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 13},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 13},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_identifier_colon.carbon

@@ -16,7 +16,7 @@ var x: {a:} = {};
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'a', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'a', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 3},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 3},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_identifier_equals.carbon

@@ -16,7 +16,7 @@ var x: {a=} = {};
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'a', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'a', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 3},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 3},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_identifier_only.carbon

@@ -16,7 +16,7 @@ var x: {a} = {};
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'a', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'a', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 3},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 3},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_missing_type.carbon

@@ -19,7 +19,7 @@ var x: {.a:} = {};
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '}', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '}', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ':', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: ':', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', subtree_size: 6},
+// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', has_error: yes, subtree_size: 6},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 8},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 8},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_missing_value.carbon

@@ -19,7 +19,7 @@ var x: {.a=} = {};
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '}', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '}', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '=', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '=', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 6},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 6},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 8},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 8},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_mix_type_and_value.carbon

@@ -23,7 +23,7 @@ var x: {.a: i32, .b = 0} = {};
 // CHECK:STDOUT:             {kind: 'IdentifierName', text: 'b'},
 // CHECK:STDOUT:             {kind: 'IdentifierName', text: 'b'},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '.', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '.', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', subtree_size: 10},
+// CHECK:STDOUT:         {kind: 'StructTypeLiteral', text: '}', has_error: yes, subtree_size: 10},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 12},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 12},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_mix_value_and_type.carbon

@@ -21,7 +21,7 @@ var x: {.a = 0, b: i32} = {};
 // CHECK:STDOUT:           {kind: 'StructFieldValue', text: '=', subtree_size: 4},
 // CHECK:STDOUT:           {kind: 'StructFieldValue', text: '=', subtree_size: 4},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'StructComma', text: ','},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'b', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'b', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 8},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 8},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 10},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 10},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 2 - 2
toolchain/parse/testdata/struct/fail_mix_with_unknown.carbon

@@ -41,7 +41,7 @@ var x: i32 = {.a: i32, .b, .c = 1};
 // CHECK:STDOUT:           {kind: 'IdentifierName', text: 'c'},
 // CHECK:STDOUT:           {kind: 'IdentifierName', text: 'c'},
 // CHECK:STDOUT:         {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:         {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:         {kind: 'InvalidParse', text: '.', has_error: yes},
 // CHECK:STDOUT:         {kind: 'InvalidParse', text: '.', has_error: yes},
-// CHECK:STDOUT:       {kind: 'StructLiteral', text: '}', subtree_size: 14},
+// CHECK:STDOUT:       {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 14},
 // CHECK:STDOUT:     {kind: 'VariableDecl', text: ';', subtree_size: 20},
 // CHECK:STDOUT:     {kind: 'VariableDecl', text: ';', subtree_size: 20},
 // CHECK:STDOUT:       {kind: 'VariableIntroducer', text: 'var'},
 // CHECK:STDOUT:       {kind: 'VariableIntroducer', text: 'var'},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
@@ -61,7 +61,7 @@ var x: i32 = {.a: i32, .b, .c = 1};
 // CHECK:STDOUT:           {kind: 'IdentifierName', text: 'c'},
 // CHECK:STDOUT:           {kind: 'IdentifierName', text: 'c'},
 // CHECK:STDOUT:         {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:         {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:         {kind: 'InvalidParse', text: '.', has_error: yes},
 // CHECK:STDOUT:         {kind: 'InvalidParse', text: '.', has_error: yes},
-// CHECK:STDOUT:       {kind: 'StructTypeLiteral', text: '}', subtree_size: 14},
+// CHECK:STDOUT:       {kind: 'StructTypeLiteral', text: '}', has_error: yes, subtree_size: 14},
 // CHECK:STDOUT:     {kind: 'VariableDecl', text: ';', subtree_size: 20},
 // CHECK:STDOUT:     {kind: 'VariableDecl', text: ';', subtree_size: 20},
 // CHECK:STDOUT:     {kind: 'FileEnd', text: ''},
 // CHECK:STDOUT:     {kind: 'FileEnd', text: ''},
 // CHECK:STDOUT:   ]
 // CHECK:STDOUT:   ]

+ 1 - 1
toolchain/parse/testdata/struct/fail_no_colon_or_equals.carbon

@@ -18,7 +18,7 @@ var x: {.a} = {};
 // CHECK:STDOUT:             {kind: 'IdentifierName', text: 'a'},
 // CHECK:STDOUT:             {kind: 'IdentifierName', text: 'a'},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'StructFieldDesignator', text: '.', subtree_size: 2},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '.', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: '.', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 5},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 7},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 7},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 1 - 1
toolchain/parse/testdata/struct/fail_type_no_designator.carbon

@@ -16,7 +16,7 @@ var x: {i32} = {};
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:         {kind: 'IdentifierName', text: 'x'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'i32', has_error: yes},
 // CHECK:STDOUT:           {kind: 'InvalidParse', text: 'i32', has_error: yes},
-// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', subtree_size: 3},
+// CHECK:STDOUT:         {kind: 'StructLiteral', text: '}', has_error: yes, subtree_size: 3},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'BindingPattern', text: ':', subtree_size: 5},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:       {kind: 'VariableInitializer', text: '='},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},
 // CHECK:STDOUT:         {kind: 'StructLiteralOrStructTypeLiteralStart', text: '{'},

+ 23 - 0
toolchain/parse/tree.cpp

@@ -12,6 +12,7 @@
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/context.h"
 #include "toolchain/parse/context.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/node_kind.h"
+#include "toolchain/parse/typed_nodes.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
 
 
@@ -219,6 +220,16 @@ auto Tree::Print(llvm::raw_ostream& output, bool preorder) const -> void {
   output << "  ]\n";
   output << "  ]\n";
 }
 }
 
 
+static auto TestExtract(const Tree* tree, NodeId node_id, NodeKind kind,
+                        ErrorBuilder* trace) -> bool {
+  switch (kind) {
+#define CARBON_PARSE_NODE_KIND(Name) \
+  case NodeKind::Name:               \
+    return tree->VerifyExtractAs<Name>(node_id, trace).has_value();
+#include "toolchain/parse/node_kind.def"
+  }
+}
+
 auto Tree::Verify() const -> ErrorOr<Success> {
 auto Tree::Verify() const -> ErrorOr<Success> {
   llvm::SmallVector<NodeId> nodes;
   llvm::SmallVector<NodeId> nodes;
   // Traverse the tree in postorder.
   // Traverse the tree in postorder.
@@ -235,6 +246,18 @@ auto Tree::Verify() const -> ErrorOr<Success> {
       return Error(llvm::formatv(
       return Error(llvm::formatv(
           "Node #{0} is a placeholder node that wasn't replaced.", n.index));
           "Node #{0} is a placeholder node that wasn't replaced.", n.index));
     }
     }
+    // Should extract successfully if node not marked as having an error.
+    // Without this code, a 10 mloc test case of lex & parse takes
+    // 4.129 s ±  0.041 s. With this additional verification, it takes
+    // 5.768 s ±  0.036 s.
+    if (!n_impl.has_error && !TestExtract(this, n, n_impl.kind, nullptr)) {
+      ErrorBuilder trace;
+      trace << llvm::formatv(
+          "NodeId #{0} couldn't be extracted as a {1}. Trace:\n", n,
+          n_impl.kind);
+      TestExtract(this, n, n_impl.kind, &trace);
+      return trace;
+    }
 
 
     int subtree_size = 1;
     int subtree_size = 1;
     if (n_impl.kind.has_bracket()) {
     if (n_impl.kind.has_bracket()) {

+ 123 - 18
toolchain/parse/tree.h

@@ -7,6 +7,7 @@
 
 
 #include <iterator>
 #include <iterator>
 
 
+#include "common/check.h"
 #include "common/error.h"
 #include "common/error.h"
 #include "common/ostream.h"
 #include "common/ostream.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVector.h"
@@ -14,23 +15,13 @@
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/ADT/iterator_range.h"
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
+#include "toolchain/parse/node_ids.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/node_kind.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
 
 
-// A lightweight handle representing a node in the tree.
-//
-// Objects of this type are small and cheap to copy and store. They don't
-// contain any of the information about the node, and serve as a handle that
-// can be used with the underlying tree to query for detailed information.
-struct NodeId : public IdBase {
-  // An explicitly invalid instance.
-  static const NodeId Invalid;
-
-  using IdBase::IdBase;
-};
-
-constexpr NodeId NodeId::Invalid = NodeId(NodeId::InvalidIndex);
+// Defined in typed_nodes.h. Include that to call `Tree::ExtractFile()`.
+struct File;
 
 
 // A tree of parsed tokens based on the language grammar.
 // A tree of parsed tokens based on the language grammar.
 //
 //
@@ -120,6 +111,19 @@ class Tree : public Printable<Tree> {
 
 
   auto node_subtree_size(NodeId n) const -> int32_t;
   auto node_subtree_size(NodeId n) const -> int32_t;
 
 
+  // Returns whether this node is a valid node of the specified type.
+  template <typename T>
+  auto IsValid(NodeId node_id) const -> bool {
+    return node_kind(node_id) == T::Kind && !node_has_error(node_id);
+  }
+
+  template <typename IdT>
+  auto IsValid(IdT id) const -> bool {
+    using T = typename NodeForId<IdT>::TypedNode;
+    CARBON_DCHECK(node_kind(id) == T::Kind);
+    return !node_has_error(id);
+  }
+
   auto packaging_directive() const -> const std::optional<PackagingDirective>& {
   auto packaging_directive() const -> const std::optional<PackagingDirective>& {
     return packaging_directive_;
     return packaging_directive_;
   }
   }
@@ -168,14 +172,56 @@ class Tree : public Printable<Tree> {
   // line-oriented shell tools from `grep` to `awk`.
   // line-oriented shell tools from `grep` to `awk`.
   auto Print(llvm::raw_ostream& output, bool preorder) const -> void;
   auto Print(llvm::raw_ostream& output, bool preorder) const -> void;
 
 
+  // The following `Extract*` function provide an alternative way of accessing
+  // the nodes of a tree. It is intended to be more convenient and type-safe,
+  // but slower and can't be used on nodes that are marked as having an error.
+  // It is appropriate for uses that are less performance sensitive, like
+  // diagnostics. Example usage:
+  // ```
+  // auto file = tree->ExtractFile();
+  // for (AnyDeclId decl_id : file.decls) {
+  //   // `decl_id` is convertible to a `NodeId`.
+  //   if (std::optional<FunctionDecl> fn_decl =
+  //       tree->ExtractAs<FunctionDecl>(decl_id)) {
+  //     // fn_decl->params is a `TuplePatternId` (which extends `NodeId`)
+  //     // that is guaranteed to reference a `TuplePattern`.
+  //     std::optional<TuplePattern> params = tree->Extract(fn_decl->params);
+  //     // `params` has a value unless there was an error in that node.
+  //   } else if (auto class_def = tree->ExtractAs<ClassDefinition>(decl_id)) {
+  //     // ...
+  //   }
+  // }
+  // ```
+
+  // Extract a `File` object representing the parse tree for the whole file.
+  // #include "toolchain/parse/typed_nodes.h" to get the definition of `File`
+  // and the types representing its children nodes.
+  auto ExtractFile() const -> File;
+
+  // Converts this node_id to a typed node of a specified type, if it is a valid
+  // node of that kind.
+  template <typename T>
+  auto ExtractAs(NodeId node_id) const -> std::optional<T>;
+
+  // Converts to a typed node, if it is not an error.
+  template <typename IdT>
+  auto Extract(IdT id) const
+      -> std::optional<typename NodeForId<IdT>::TypedNode>;
+
   // Verifies the parse tree structure. Checks invariants of the parse tree
   // Verifies the parse tree structure. Checks invariants of the parse tree
   // structure and returns verification errors.
   // structure and returns verification errors.
   //
   //
-  // This is primarily intended to be used as a
-  // debugging aid. This routine doesn't directly CHECK so that it can be used
-  // within a debugger.
+  // This is fairly slow, and is primarily intended to be used as a debugging
+  // aid. This routine doesn't directly CHECK so that it can be used within a
+  // debugger.
   auto Verify() const -> ErrorOr<Success>;
   auto Verify() const -> ErrorOr<Success>;
 
 
+  // Like ExtractAs(), but malformed tree errors are not fatal. Should only be
+  // used by `Verify()`.
+  template <typename T>
+  auto VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
+      -> std::optional<T>;
+
  private:
  private:
   friend class Context;
   friend class Context;
 
 
@@ -245,6 +291,20 @@ class Tree : public Printable<Tree> {
   auto PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
   auto PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
                  bool preorder) const -> bool;
                  bool preorder) const -> bool;
 
 
+  // Extract a node of type `T` from a sibling range. This is expected to
+  // consume the complete sibling range. Malformed tree errors are written
+  // to `*trace`, if `trace != nullptr`.
+  template <typename T>
+  auto TryExtractNodeFromChildren(
+      llvm::iterator_range<Tree::SiblingIterator> children,
+      ErrorBuilder* trace) const -> std::optional<T>;
+
+  // Extract a node of type `T` from a sibling range. This is expected to
+  // consume the complete sibling range. Malformed tree errors are fatal.
+  template <typename T>
+  auto ExtractNodeFromChildren(
+      llvm::iterator_range<Tree::SiblingIterator> children) const -> T;
+
   // Depth-first postorder sequence of node implementation data.
   // Depth-first postorder sequence of node implementation data.
   llvm::SmallVector<NodeImpl> node_impls_;
   llvm::SmallVector<NodeImpl> node_impls_;
 
 
@@ -270,7 +330,7 @@ class Tree : public Printable<Tree> {
 class Tree::PostorderIterator
 class Tree::PostorderIterator
     : public llvm::iterator_facade_base<PostorderIterator,
     : public llvm::iterator_facade_base<PostorderIterator,
                                         std::random_access_iterator_tag, NodeId,
                                         std::random_access_iterator_tag, NodeId,
-                                        int, NodeId*, NodeId>,
+                                        int, const NodeId*, NodeId>,
       public Printable<Tree::PostorderIterator> {
       public Printable<Tree::PostorderIterator> {
  public:
  public:
   PostorderIterator() = delete;
   PostorderIterator() = delete;
@@ -322,7 +382,7 @@ class Tree::PostorderIterator
 class Tree::SiblingIterator
 class Tree::SiblingIterator
     : public llvm::iterator_facade_base<SiblingIterator,
     : public llvm::iterator_facade_base<SiblingIterator,
                                         std::forward_iterator_tag, NodeId, int,
                                         std::forward_iterator_tag, NodeId, int,
-                                        NodeId*, NodeId>,
+                                        const NodeId*, NodeId>,
       public Printable<Tree::SiblingIterator> {
       public Printable<Tree::SiblingIterator> {
  public:
  public:
   explicit SiblingIterator() = delete;
   explicit SiblingIterator() = delete;
@@ -353,6 +413,51 @@ class Tree::SiblingIterator
   NodeId node_;
   NodeId node_;
 };
 };
 
 
+template <typename T>
+auto Tree::ExtractNodeFromChildren(
+    llvm::iterator_range<Tree::SiblingIterator> children) const -> T {
+  auto result = TryExtractNodeFromChildren<T>(children, nullptr);
+  if (!result.has_value()) {
+    // On error try again, this time capturing a trace.
+    ErrorBuilder trace;
+    TryExtractNodeFromChildren<T>(children, &trace);
+    CARBON_FATAL() << "Malformed parse node:\n" << Error(trace).message();
+  }
+  return *result;
+}
+
+template <typename T>
+auto Tree::ExtractAs(NodeId node_id) const -> std::optional<T> {
+  static_assert(HasKindMember<T>, "Not a parse node type");
+  if (!IsValid<T>(node_id)) {
+    return std::nullopt;
+  }
+
+  return ExtractNodeFromChildren<T>(children(node_id));
+}
+
+template <typename T>
+auto Tree::VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
+    -> std::optional<T> {
+  static_assert(HasKindMember<T>, "Not a parse node type");
+  if (!IsValid<T>(node_id)) {
+    return std::nullopt;
+  }
+
+  return TryExtractNodeFromChildren<T>(children(node_id), trace);
+}
+
+template <typename IdT>
+auto Tree::Extract(IdT id) const
+    -> std::optional<typename NodeForId<IdT>::TypedNode> {
+  if (!IsValid(id)) {
+    return std::nullopt;
+  }
+
+  using T = typename NodeForId<IdT>::TypedNode;
+  return ExtractNodeFromChildren<T>(children(id));
+}
+
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse
 
 
 #endif  // CARBON_TOOLCHAIN_PARSE_TREE_H_
 #endif  // CARBON_TOOLCHAIN_PARSE_TREE_H_

+ 909 - 0
toolchain/parse/typed_nodes.h

@@ -0,0 +1,909 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_PARSE_TYPED_NODES_H_
+#define CARBON_TOOLCHAIN_PARSE_TYPED_NODES_H_
+
+#include "toolchain/parse/node_ids.h"
+#include "toolchain/parse/node_kind.h"
+
+namespace Carbon::Parse {
+
+// Helpers for defining different kinds of parse nodes.
+// ----------------------------------------------------
+
+// A pair of a list item and its optional following comma.
+template <typename Element, typename Comma>
+struct ListItem {
+  Element value;
+  std::optional<Comma> comma;
+};
+
+// A list of items, parameterized by the kind of the elements and comma.
+template <typename Element, typename Comma>
+using CommaSeparatedList = llvm::SmallVector<ListItem<Element, Comma>>;
+
+// This class provides a shorthand for defining parse node kinds for leaf nodes.
+template <const NodeKind& KindT, NodeCategory Category = NodeCategory::None>
+struct LeafNode {
+  static constexpr auto Kind = KindT.Define(Category);
+};
+
+// ----------------------------------------------------------------------------
+// Each node kind (in node_kind.def) should have a corresponding type defined
+// here which describes the expected child structure of that parse node.
+//
+// Each of these types should start with a `static constexpr Kind` member
+// initialized by calling `Define` on the corresponding `NodeKind`, and passing
+// in the `NodeCategory` of that kind.  This will both associate the category
+// with the node kind and create the necessary kind object for the typed node.
+//
+// This should be followed by field declarations that describe the child nodes,
+// in order, that occur in the parse tree. The `Extract...` functions on the
+// parse tree use struct reflection on these fields to guide the extraction of
+// the child nodes from the tree into an object of this type with these fields
+// for convenient access.
+//
+// The types of these fields are special and describe the specific child node
+// structure of the parse node. Many of these types are defined in `node_ids.h`.
+//
+// Valid primitive types here are:
+// - `NodeId` to match any single child node
+// - `FooId` to require that child to have kind `NodeKind::Foo`
+// - `AnyCatId` to require that child to have a kind in category `Cat`
+// - `NodeIdOneOf<A, B>` to require the child to have kind `NodeKind::A` or
+// `NodeKind::B`
+// - `NodeIdNot<A>` to match any single child whose kind is not `NodeKind::A`
+//
+// There a few, restricted composite field types allowed that compose types in
+// various ways, where all of the `T`s and `U`s below are themselves valid field
+// types:
+// - `llvm::SmallVector<T>` to match any number of children matching `T`
+// - `std::optional<T>` to match 0 or 1 children matching `T`
+// - `std::tuple<T...>` to match children matching `T...`
+// - Any provided `Aggregate` type that is a simple aggregate type such as
+// `struct Aggregate { T x; U y; }`,
+//   to match children with types `T` and `U`.
+// ----------------------------------------------------------------------------
+
+// Error nodes
+// -----------
+
+// An invalid parse. Used to balance the parse tree. This type is here only to
+// ensure we have a type for each parse node kind. This node kind always has an
+// error, so can never be extracted.
+using InvalidParse =
+    LeafNode<NodeKind::InvalidParse, NodeCategory::Decl | NodeCategory::Expr>;
+
+// An invalid subtree. Always has an error so can never be extracted.
+using InvalidParseStart = LeafNode<NodeKind::InvalidParseStart>;
+struct InvalidParseSubtree {
+  static constexpr auto Kind =
+      NodeKind::InvalidParseSubtree.Define(NodeCategory::Decl);
+
+  InvalidParseStartId start;
+  llvm::SmallVector<NodeIdNot<InvalidParseStart>> extra;
+};
+
+// A placeholder node to be replaced; it will never exist in a valid parse tree.
+// Its token kind is not enforced even when valid.
+using Placeholder = LeafNode<NodeKind::Placeholder>;
+
+// File nodes
+// ----------
+
+// The start of the file.
+using FileStart = LeafNode<NodeKind::FileStart>;
+
+// The end of the file.
+using FileEnd = LeafNode<NodeKind::FileEnd>;
+
+// General-purpose nodes
+// ---------------------
+
+// An empty declaration, such as `;`.
+using EmptyDecl =
+    LeafNode<NodeKind::EmptyDecl, NodeCategory::Decl | NodeCategory::Statement>;
+
+// A name in a non-expression context, such as a declaration.
+using IdentifierName =
+    LeafNode<NodeKind::IdentifierName, NodeCategory::NameComponent>;
+
+// A name in an expression context.
+using IdentifierNameExpr =
+    LeafNode<NodeKind::IdentifierNameExpr, NodeCategory::Expr>;
+
+// The `self` value and `Self` type identifier keywords. Typically of the form
+// `self: Self`.
+using SelfValueName = LeafNode<NodeKind::SelfValueName>;
+using SelfValueNameExpr =
+    LeafNode<NodeKind::SelfValueNameExpr, NodeCategory::Expr>;
+using SelfTypeNameExpr =
+    LeafNode<NodeKind::SelfTypeNameExpr, NodeCategory::Expr>;
+
+// The `base` value keyword, introduced by `base: B`. Typically referenced in
+// an expression, as in `x.base` or `{.base = ...}`, but can also be used as a
+// declared name, as in `{.base: partial B}`.
+using BaseName = LeafNode<NodeKind::BaseName>;
+
+// A qualified name: `A.B`.
+//
+// TODO: This is not a declaration. Rename this parse node.
+struct QualifiedDecl {
+  static constexpr auto Kind =
+      NodeKind::QualifiedDecl.Define(NodeCategory::NameComponent);
+
+  // For now, this is either an IdentifierName or a QualifiedDecl.
+  AnyNameComponentId lhs;
+
+  // TODO: This will eventually need to support more general expressions, for
+  // example `GenericType(type_args).ChildType(child_type_args).Name`.
+  IdentifierNameId rhs;
+};
+
+// Library, package, import
+// ------------------------
+
+// The `package` keyword in an expression.
+using PackageExpr = LeafNode<NodeKind::PackageExpr, NodeCategory::Expr>;
+
+// The name of a package or library for `package`, `import`, and `library`.
+using PackageName = LeafNode<NodeKind::PackageName>;
+using LibraryName = LeafNode<NodeKind::LibraryName>;
+using DefaultLibrary = LeafNode<NodeKind::DefaultLibrary>;
+
+using PackageIntroducer = LeafNode<NodeKind::PackageIntroducer>;
+using PackageApi = LeafNode<NodeKind::PackageApi>;
+using PackageImpl = LeafNode<NodeKind::PackageImpl>;
+
+// `library` in `package` or `import`.
+struct LibrarySpecifier {
+  static constexpr auto Kind = NodeKind::LibrarySpecifier.Define();
+
+  NodeIdOneOf<LibraryName, DefaultLibrary> name;
+};
+
+// First line of the file, such as:
+//   `package MyPackage library "MyLibrary" impl;`
+struct PackageDirective {
+  static constexpr auto Kind = NodeKind::PackageDirective.Define();
+
+  PackageIntroducerId introducer;
+  std::optional<PackageNameId> name;
+  std::optional<LibrarySpecifierId> library;
+  NodeIdOneOf<PackageApi, PackageImpl> api_or_impl;
+};
+
+// `import TheirPackage library "TheirLibrary";`
+using ImportIntroducer = LeafNode<NodeKind::ImportIntroducer>;
+struct ImportDirective {
+  static constexpr auto Kind = NodeKind::ImportDirective.Define();
+
+  ImportIntroducerId introducer;
+  std::optional<PackageNameId> name;
+  std::optional<LibrarySpecifierId> library;
+};
+
+// `library` as directive.
+using LibraryIntroducer = LeafNode<NodeKind::LibraryIntroducer>;
+struct LibraryDirective {
+  static constexpr auto Kind = NodeKind::LibraryDirective.Define();
+
+  LibraryIntroducerId introducer;
+  NodeIdOneOf<LibraryName, DefaultLibrary> library_name;
+  NodeIdOneOf<PackageApi, PackageImpl> api_or_impl;
+};
+
+// Namespace nodes
+// ---------------
+
+using NamespaceStart = LeafNode<NodeKind::NamespaceStart>;
+
+// A namespace: `namespace N;`.
+struct Namespace {
+  static constexpr auto Kind = NodeKind::Namespace.Define(NodeCategory::Decl);
+
+  NamespaceStartId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  NodeIdOneOf<IdentifierName, QualifiedDecl> name;
+};
+
+// Pattern nodes
+// -------------
+
+// A pattern binding, such as `name: Type`.
+struct BindingPattern {
+  static constexpr auto Kind =
+      NodeKind::BindingPattern.Define(NodeCategory::Pattern);
+
+  NodeIdOneOf<IdentifierName, SelfValueName> name;
+  AnyExprId type;
+};
+
+// `name:! Type`
+struct GenericBindingPattern {
+  static constexpr auto Kind =
+      NodeKind::GenericBindingPattern.Define(NodeCategory::Pattern);
+
+  NodeIdOneOf<IdentifierName, SelfValueName> name;
+  AnyExprId type;
+};
+
+// An address-of binding: `addr self: Self*`.
+struct Address {
+  static constexpr auto Kind = NodeKind::Address.Define(NodeCategory::Pattern);
+
+  AnyPatternId inner;
+};
+
+// A template binding: `template T:! type`.
+struct Template {
+  static constexpr auto Kind = NodeKind::Template.Define(NodeCategory::Pattern);
+
+  // This is a GenericBindingPatternId in any valid program.
+  // TODO: Should the parser enforce that?
+  AnyPatternId inner;
+};
+
+using TuplePatternStart = LeafNode<NodeKind::TuplePatternStart>;
+using PatternListComma = LeafNode<NodeKind::PatternListComma>;
+
+// A parameter list or tuple pattern: `(a: i32, b: i32)`.
+struct TuplePattern {
+  static constexpr auto Kind =
+      NodeKind::TuplePattern.Define(NodeCategory::Pattern);
+
+  TuplePatternStartId left_paren;
+  CommaSeparatedList<AnyPatternId, PatternListCommaId> params;
+};
+
+using ImplicitParamListStart = LeafNode<NodeKind::ImplicitParamListStart>;
+
+// An implicit parameter list: `[T:! type, self: Self]`.
+struct ImplicitParamList {
+  static constexpr auto Kind = NodeKind::ImplicitParamList.Define();
+
+  ImplicitParamListStartId left_square;
+  CommaSeparatedList<AnyPatternId, PatternListCommaId> params;
+};
+
+// Function nodes
+// --------------
+
+using FunctionIntroducer = LeafNode<NodeKind::FunctionIntroducer>;
+
+// A return type: `-> i32`.
+struct ReturnType {
+  static constexpr auto Kind = NodeKind::ReturnType.Define();
+
+  AnyExprId type;
+};
+
+// A function signature: `fn F() -> i32`.
+template <const NodeKind& KindT>
+struct FunctionSignature {
+  static constexpr auto Kind = KindT.Define(NodeCategory::Decl);
+
+  FunctionIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  // For now, this is either an IdentifierName or a QualifiedDecl.
+  AnyNameComponentId name;
+  std::optional<ImplicitParamListId> implicit_params;
+  TuplePatternId params;
+  std::optional<ReturnTypeId> return_type;
+};
+
+using FunctionDecl = FunctionSignature<NodeKind::FunctionDecl>;
+using FunctionDefinitionStart =
+    FunctionSignature<NodeKind::FunctionDefinitionStart>;
+
+// A function definition: `fn F() -> i32 { ... }`.
+struct FunctionDefinition {
+  static constexpr auto Kind =
+      NodeKind::FunctionDefinition.Define(NodeCategory::Decl);
+
+  FunctionDefinitionStartId signature;
+  llvm::SmallVector<AnyStatementId> body;
+};
+
+// `let` nodes
+// -----------
+
+using LetIntroducer = LeafNode<NodeKind::LetIntroducer>;
+using LetInitializer = LeafNode<NodeKind::LetInitializer>;
+
+// A `let` declaration: `let a: i32 = 5;`.
+struct LetDecl {
+  static constexpr auto Kind =
+      NodeKind::LetDecl.Define(NodeCategory::Decl | NodeCategory::Statement);
+
+  LetIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  AnyPatternId pattern;
+  LetInitializerId equals;
+  AnyExprId initializer;
+};
+
+// `var` nodes
+// -----------
+
+using VariableIntroducer = LeafNode<NodeKind::VariableIntroducer>;
+using ReturnedModifier = LeafNode<NodeKind::ReturnedModifier>;
+using VariableInitializer = LeafNode<NodeKind::VariableInitializer>;
+
+// A `var` declaration: `var a: i32;` or `var a: i32 = 5;`.
+struct VariableDecl {
+  static constexpr auto Kind = NodeKind::VariableDecl.Define(
+      NodeCategory::Decl | NodeCategory::Statement);
+
+  VariableIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  std::optional<ReturnedModifierId> returned;
+  AnyPatternId pattern;
+
+  struct Initializer {
+    VariableInitializerId equals;
+    AnyExprId value;
+  };
+  std::optional<Initializer> initializer;
+};
+
+// Statement nodes
+// ---------------
+
+using CodeBlockStart = LeafNode<NodeKind::CodeBlockStart>;
+
+// A code block: `{ statement; statement; ... }`.
+struct CodeBlock {
+  static constexpr auto Kind = NodeKind::CodeBlock.Define();
+
+  CodeBlockStartId left_brace;
+  llvm::SmallVector<AnyStatementId> statements;
+};
+
+// An expression statement: `F(x);`.
+struct ExprStatement {
+  static constexpr auto Kind =
+      NodeKind::ExprStatement.Define(NodeCategory::Statement);
+
+  AnyExprId expr;
+};
+
+using BreakStatementStart = LeafNode<NodeKind::BreakStatementStart>;
+
+// A break statement: `break;`.
+struct BreakStatement {
+  static constexpr auto Kind =
+      NodeKind::BreakStatement.Define(NodeCategory::Statement);
+
+  BreakStatementStartId introducer;
+};
+
+using ContinueStatementStart = LeafNode<NodeKind::ContinueStatementStart>;
+
+// A continue statement: `continue;`.
+struct ContinueStatement {
+  static constexpr auto Kind =
+      NodeKind::ContinueStatement.Define(NodeCategory::Statement);
+
+  ContinueStatementStartId introducer;
+};
+
+using ReturnStatementStart = LeafNode<NodeKind::ReturnStatementStart>;
+using ReturnVarModifier = LeafNode<NodeKind::ReturnVarModifier>;
+
+// A return statement: `return;` or `return expr;` or `return var;`.
+struct ReturnStatement {
+  static constexpr auto Kind =
+      NodeKind::ReturnStatement.Define(NodeCategory::Statement);
+
+  ReturnStatementStartId introducer;
+  std::optional<AnyExprId> expr;
+  std::optional<ReturnVarModifierId> var;
+};
+
+using ForHeaderStart = LeafNode<NodeKind::ForHeaderStart>;
+
+// The `var ... in` portion of a `for` statement.
+struct ForIn {
+  static constexpr auto Kind = NodeKind::ForIn.Define();
+
+  VariableIntroducerId introducer;
+  AnyPatternId pattern;
+};
+
+// The `for (var ... in ...)` portion of a `for` statement.
+struct ForHeader {
+  static constexpr auto Kind = NodeKind::ForHeader.Define();
+
+  ForHeaderStartId introducer;
+  ForInId var;
+  AnyExprId range;
+};
+
+// A complete `for (...) { ... }` statement.
+struct ForStatement {
+  static constexpr auto Kind =
+      NodeKind::ForStatement.Define(NodeCategory::Statement);
+
+  ForHeaderId header;
+  CodeBlockId body;
+};
+
+using IfConditionStart = LeafNode<NodeKind::IfConditionStart>;
+
+// The condition portion of an `if` statement: `(expr)`.
+struct IfCondition {
+  static constexpr auto Kind = NodeKind::IfCondition.Define();
+
+  IfConditionStartId left_paren;
+  AnyExprId condition;
+};
+
+using IfStatementElse = LeafNode<NodeKind::IfStatementElse>;
+
+// An `if` statement: `if (expr) { ... } else { ... }`.
+struct IfStatement {
+  static constexpr auto Kind =
+      NodeKind::IfStatement.Define(NodeCategory::Statement);
+
+  IfConditionId head;
+  CodeBlockId then;
+
+  struct Else {
+    IfStatementElseId else_token;
+    NodeIdOneOf<CodeBlock, IfStatement> body;
+  };
+  std::optional<Else> else_clause;
+};
+
+using WhileConditionStart = LeafNode<NodeKind::WhileConditionStart>;
+
+// The condition portion of a `while` statement: `(expr)`.
+struct WhileCondition {
+  static constexpr auto Kind = NodeKind::WhileCondition.Define();
+
+  WhileConditionStartId left_paren;
+  AnyExprId condition;
+};
+
+// A `while` statement: `while (expr) { ... }`.
+struct WhileStatement {
+  static constexpr auto Kind =
+      NodeKind::WhileStatement.Define(NodeCategory::Statement);
+
+  WhileConditionId head;
+  CodeBlockId body;
+};
+
+// Expression nodes
+// ----------------
+
+using ArrayExprStart = LeafNode<NodeKind::ArrayExprStart, NodeCategory::Expr>;
+
+// The start of an array type, `[i32;`.
+//
+// TODO: Consider flattening this into `ArrayExpr`.
+struct ArrayExprSemi {
+  static constexpr auto Kind = NodeKind::ArrayExprSemi.Define();
+
+  ArrayExprStartId left_square;
+  AnyExprId type;
+};
+
+// An array type, such as  `[i32; 3]` or `[i32;]`.
+struct ArrayExpr {
+  static constexpr auto Kind = NodeKind::ArrayExpr.Define(NodeCategory::Expr);
+
+  ArrayExprSemiId start;
+  std::optional<AnyExprId> bound;
+};
+
+// The opening portion of an indexing expression: `a[`.
+//
+// TODO: Consider flattening this into `IndexExpr`.
+struct IndexExprStart {
+  static constexpr auto Kind = NodeKind::IndexExprStart.Define();
+
+  AnyExprId sequence;
+};
+
+// An indexing expression, such as `a[1]`.
+struct IndexExpr {
+  static constexpr auto Kind = NodeKind::IndexExpr.Define(NodeCategory::Expr);
+
+  IndexExprStartId start;
+  AnyExprId index;
+};
+
+using ExprOpenParen = LeafNode<NodeKind::ExprOpenParen>;
+
+// A parenthesized expression: `(a)`.
+struct ParenExpr {
+  static constexpr auto Kind = NodeKind::ParenExpr.Define(NodeCategory::Expr);
+
+  ExprOpenParenId left_paren;
+  AnyExprId expr;
+};
+
+using TupleLiteralComma = LeafNode<NodeKind::TupleLiteralComma>;
+
+// A tuple literal: `()`, `(a, b, c)`, or `(a,)`.
+struct TupleLiteral {
+  static constexpr auto Kind =
+      NodeKind::TupleLiteral.Define(NodeCategory::Expr);
+
+  ExprOpenParenId left_paren;
+  CommaSeparatedList<AnyExprId, TupleLiteralCommaId> elements;
+};
+
+// The opening portion of a call expression: `F(`.
+//
+// TODO: Consider flattening this into `CallExpr`.
+struct CallExprStart {
+  static constexpr auto Kind = NodeKind::CallExprStart.Define();
+
+  AnyExprId callee;
+};
+
+using CallExprComma = LeafNode<NodeKind::CallExprComma>;
+
+// A call expression: `F(a, b, c)`.
+struct CallExpr {
+  static constexpr auto Kind = NodeKind::CallExpr.Define(NodeCategory::Expr);
+
+  CallExprStartId start;
+  CommaSeparatedList<AnyExprId, CallExprCommaId> arguments;
+};
+
+// A simple member access expression: `a.b`.
+struct MemberAccessExpr {
+  static constexpr auto Kind =
+      NodeKind::MemberAccessExpr.Define(NodeCategory::Expr);
+
+  AnyExprId lhs;
+  // TODO: Figure out which nodes can appear here
+  NodeId rhs;
+};
+
+// A simple indirect member access expression: `a->b`.
+struct PointerMemberAccessExpr {
+  static constexpr auto Kind =
+      NodeKind::PointerMemberAccessExpr.Define(NodeCategory::Expr);
+
+  AnyExprId lhs;
+  // TODO: Figure out which nodes can appear here
+  NodeId rhs;
+};
+
+// A prefix operator expression.
+template <const NodeKind& KindT>
+struct PrefixOperator {
+  static constexpr auto Kind = KindT.Define(NodeCategory::Expr);
+
+  AnyExprId operand;
+};
+
+// An infix operator expression.
+template <const NodeKind& KindT>
+struct InfixOperator {
+  static constexpr auto Kind = KindT.Define(NodeCategory::Expr);
+
+  AnyExprId lhs;
+  AnyExprId rhs;
+};
+
+// A postfix operator expression.
+template <const NodeKind& KindT>
+struct PostfixOperator {
+  static constexpr auto Kind = KindT.Define(NodeCategory::Expr);
+
+  AnyExprId operand;
+};
+
+// Literals, operators, and modifiers
+
+#define CARBON_PARSE_NODE_KIND(...)
+#define CARBON_PARSE_NODE_KIND_TOKEN_LITERAL(Name, ...) \
+  using Name = LeafNode<NodeKind::Name, NodeCategory::Expr>;
+#define CARBON_PARSE_NODE_KIND_TOKEN_MODIFIER(Name, ...) \
+  using Name##Modifier =                                 \
+      LeafNode<NodeKind::Name##Modifier, NodeCategory::Modifier>;
+#define CARBON_PARSE_NODE_KIND_PREFIX_OPERATOR(Name, ...) \
+  using PrefixOperator##Name = PrefixOperator<NodeKind::PrefixOperator##Name>;
+#define CARBON_PARSE_NODE_KIND_INFIX_OPERATOR(Name, ...) \
+  using InfixOperator##Name = InfixOperator<NodeKind::InfixOperator##Name>;
+#define CARBON_PARSE_NODE_KIND_POSTFIX_OPERATOR(Name, ...) \
+  using PostfixOperator##Name =                            \
+      PostfixOperator<NodeKind::PostfixOperator##Name>;
+#include "toolchain/parse/node_kind.def"
+
+// The first operand of a short-circuiting infix operator: `a and` or `a or`.
+// The complete operator expression will be an InfixOperator with this as the
+// `lhs`.
+// TODO: Make this be a template if we ever need to write generic code to cover
+// both cases at once, say in check.
+struct ShortCircuitOperandAnd {
+  static constexpr auto Kind = NodeKind::ShortCircuitOperandAnd.Define();
+
+  AnyExprId operand;
+};
+
+struct ShortCircuitOperandOr {
+  static constexpr auto Kind = NodeKind::ShortCircuitOperandOr.Define();
+
+  AnyExprId operand;
+};
+
+struct ShortCircuitOperatorAnd {
+  static constexpr auto Kind =
+      NodeKind::ShortCircuitOperatorAnd.Define(NodeCategory::Expr);
+
+  ShortCircuitOperandAndId lhs;
+  AnyExprId rhs;
+};
+
+struct ShortCircuitOperatorOr {
+  static constexpr auto Kind =
+      NodeKind::ShortCircuitOperatorOr.Define(NodeCategory::Expr);
+
+  ShortCircuitOperandOrId lhs;
+  AnyExprId rhs;
+};
+
+// The `if` portion of an `if` expression: `if expr`.
+struct IfExprIf {
+  static constexpr auto Kind = NodeKind::IfExprIf.Define();
+
+  AnyExprId condition;
+};
+
+// The `then` portion of an `if` expression: `then expr`.
+struct IfExprThen {
+  static constexpr auto Kind = NodeKind::IfExprThen.Define();
+
+  AnyExprId result;
+};
+
+// A full `if` expression: `if expr then expr else expr`.
+struct IfExprElse {
+  static constexpr auto Kind = NodeKind::IfExprElse.Define(NodeCategory::Expr);
+
+  IfExprIfId start;
+  IfExprThenId then;
+  AnyExprId else_result;
+};
+
+// Struct literals and struct type literals
+// ----------------------------------------
+
+// `{`
+using StructLiteralOrStructTypeLiteralStart =
+    LeafNode<NodeKind::StructLiteralOrStructTypeLiteralStart>;
+// `,`
+using StructComma = LeafNode<NodeKind::StructComma>;
+
+// `.a`
+struct StructFieldDesignator {
+  static constexpr auto Kind = NodeKind::StructFieldDesignator.Define();
+
+  NodeIdOneOf<IdentifierName, BaseName> name;
+};
+
+// `.a = 0`
+struct StructFieldValue {
+  static constexpr auto Kind = NodeKind::StructFieldValue.Define();
+
+  StructFieldDesignatorId designator;
+  AnyExprId expr;
+};
+
+// `.a: i32`
+struct StructFieldType {
+  static constexpr auto Kind = NodeKind::StructFieldType.Define();
+
+  StructFieldDesignatorId designator;
+  AnyExprId type_expr;
+};
+
+// Struct literals, such as `{.a = 0}`.
+struct StructLiteral {
+  static constexpr auto Kind =
+      NodeKind::StructLiteral.Define(NodeCategory::Expr);
+
+  StructLiteralOrStructTypeLiteralStartId introducer;
+  CommaSeparatedList<StructFieldValueId, StructCommaId> fields;
+};
+
+// Struct type literals, such as `{.a: i32}`.
+struct StructTypeLiteral {
+  static constexpr auto Kind =
+      NodeKind::StructTypeLiteral.Define(NodeCategory::Expr);
+
+  StructLiteralOrStructTypeLiteralStartId introducer;
+  CommaSeparatedList<StructFieldTypeId, StructCommaId> fields;
+};
+
+// `class` declarations and definitions
+// ------------------------------------
+
+// `class`
+using ClassIntroducer = LeafNode<NodeKind::ClassIntroducer>;
+
+// A class signature `class C`
+template <const NodeKind& KindT, NodeCategory Category>
+struct ClassSignature {
+  static constexpr auto Kind = KindT.Define(Category);
+
+  ClassIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  AnyNameComponentId name;
+  std::optional<ImplicitParamListId> implicit_params;
+  std::optional<TuplePatternId> params;
+};
+
+// `class C;`
+using ClassDecl = ClassSignature<NodeKind::ClassDecl, NodeCategory::Decl>;
+// `class C {`
+using ClassDefinitionStart =
+    ClassSignature<NodeKind::ClassDefinitionStart, NodeCategory::None>;
+
+// `class C { ... }`
+struct ClassDefinition {
+  static constexpr auto Kind =
+      NodeKind::ClassDefinition.Define(NodeCategory::Decl);
+
+  ClassDefinitionStartId signature;
+  llvm::SmallVector<AnyDeclId> members;
+};
+
+// Base class declaration
+// ----------------------
+
+// `base`
+using BaseIntroducer = LeafNode<NodeKind::BaseIntroducer>;
+using BaseColon = LeafNode<NodeKind::BaseColon>;
+// `extend base: BaseClass;`
+struct BaseDecl {
+  static constexpr auto Kind = NodeKind::BaseDecl.Define(NodeCategory::Decl);
+
+  BaseIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  BaseColonId colon;
+  AnyExprId base_class;
+};
+
+// Interface declarations and definitions
+// --------------------------------------
+
+// `interface`
+using InterfaceIntroducer = LeafNode<NodeKind::InterfaceIntroducer>;
+
+// `interface I`
+template <const NodeKind& KindT, NodeCategory Category>
+struct InterfaceSignature {
+  static constexpr auto Kind = KindT.Define(Category);
+
+  InterfaceIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  AnyNameComponentId name;
+  std::optional<ImplicitParamListId> implicit_params;
+  std::optional<TuplePatternId> params;
+};
+
+// `interface I;`
+using InterfaceDecl =
+    InterfaceSignature<NodeKind::InterfaceDecl, NodeCategory::Decl>;
+// `interface I {`
+using InterfaceDefinitionStart =
+    InterfaceSignature<NodeKind::InterfaceDefinitionStart, NodeCategory::None>;
+
+// `interface I { ... }`
+struct InterfaceDefinition {
+  static constexpr auto Kind =
+      NodeKind::InterfaceDefinition.Define(NodeCategory::Decl);
+
+  InterfaceDefinitionStartId signature;
+  llvm::SmallVector<AnyDeclId> members;
+};
+
+// `impl`...`as` declarations and definitions
+// ------------------------------------------
+
+// `impl`
+using ImplIntroducer = LeafNode<NodeKind::ImplIntroducer>;
+// `as`
+using ImplAs = LeafNode<NodeKind::ImplAs>;
+
+// `forall [...]`
+struct ImplForall {
+  static constexpr auto Kind = NodeKind::ImplForall.Define();
+
+  ImplicitParamListId params;
+};
+
+// `impl T as I`
+template <const NodeKind& KindT, NodeCategory Category>
+struct ImplSignature {
+  static constexpr auto Kind = KindT.Define(Category);
+
+  ImplIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  std::optional<ImplForallId> forall;
+  std::optional<AnyExprId> type_expr;
+  ImplAsId as;
+  AnyExprId interface;
+};
+
+// `impl T as I;`
+using ImplDecl = ImplSignature<NodeKind::ImplDecl, NodeCategory::Decl>;
+// `impl T as I {`
+using ImplDefinitionStart =
+    ImplSignature<NodeKind::ImplDefinitionStart, NodeCategory::None>;
+
+// `impl T as I { ... }`
+struct ImplDefinition {
+  static constexpr auto Kind =
+      NodeKind::ImplDefinition.Define(NodeCategory::Decl);
+
+  ImplDefinitionStartId signature;
+  llvm::SmallVector<AnyDeclId> members;
+};
+
+// Named constraint declarations and definitions
+// ---------------------------------------------
+
+// `constraint`
+using NamedConstraintIntroducer = LeafNode<NodeKind::NamedConstraintIntroducer>;
+
+// `constraint NC`
+template <const NodeKind& KindT, NodeCategory Category>
+struct NamedConstraintSignature {
+  static constexpr auto Kind = KindT.Define(Category);
+
+  NamedConstraintIntroducerId introducer;
+  llvm::SmallVector<AnyModifierId> modifiers;
+  AnyNameComponentId name;
+  std::optional<ImplicitParamListId> implicit_params;
+  std::optional<TuplePatternId> params;
+};
+
+// `constraint NC;`
+using NamedConstraintDecl =
+    NamedConstraintSignature<NodeKind::NamedConstraintDecl, NodeCategory::Decl>;
+// `constraint NC {`
+using NamedConstraintDefinitionStart =
+    NamedConstraintSignature<NodeKind::NamedConstraintDefinitionStart,
+                             NodeCategory::None>;
+
+// `constraint NC { ... }`
+struct NamedConstraintDefinition {
+  static constexpr auto Kind =
+      NodeKind::NamedConstraintDefinition.Define(NodeCategory::Decl);
+
+  NamedConstraintDefinitionStartId signature;
+  llvm::SmallVector<AnyDeclId> members;
+};
+
+// ---------------------------------------------------------------------------
+
+// A complete source file. Note that there is no corresponding parse node for
+// the file. The file is instead the complete contents of the parse tree.
+struct File {
+  FileStartId start;
+  llvm::SmallVector<AnyDeclId> decls;
+  FileEndId end;
+};
+
+// Define `Foo` as the node type for the ID type `FooId`.
+#define CARBON_PARSE_NODE_KIND(KindName) \
+  template <>                            \
+  struct NodeForId<KindName##Id> {       \
+    using TypedNode = KindName;          \
+  };
+#include "toolchain/parse/node_kind.def"
+
+}  // namespace Carbon::Parse
+
+#endif  // CARBON_TOOLCHAIN_PARSE_TYPED_NODES_H_

+ 152 - 0
toolchain/parse/typed_nodes_test.cpp

@@ -0,0 +1,152 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "toolchain/parse/typed_nodes.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <forward_list>
+
+#include "toolchain/diagnostics/mocks.h"
+#include "toolchain/lex/lex.h"
+#include "toolchain/lex/tokenized_buffer.h"
+#include "toolchain/parse/tree.h"
+
+namespace Carbon::Parse {
+namespace {
+
+// Check that each node kind defines a Kind member using the correct
+// NodeKind enumerator.
+#define CARBON_PARSE_NODE_KIND(Name) \
+  static_assert(Name::Kind == NodeKind::Name, #Name);
+#include "toolchain/parse/node_kind.def"
+
+class TypedNodeTest : public ::testing::Test {
+ protected:
+  auto GetSourceBuffer(llvm::StringRef t) -> SourceBuffer& {
+    CARBON_CHECK(fs_.addFile("test.carbon", /*ModificationTime=*/0,
+                             llvm::MemoryBuffer::getMemBuffer(t)));
+    source_storage_.push_front(std::move(
+        *SourceBuffer::CreateFromFile(fs_, "test.carbon", consumer_)));
+    return source_storage_.front();
+  }
+
+  auto GetTokenizedBuffer(llvm::StringRef t) -> Lex::TokenizedBuffer& {
+    token_storage_.push_front(
+        Lex::Lex(value_stores_, GetSourceBuffer(t), consumer_));
+    return token_storage_.front();
+  }
+
+  auto GetTree(llvm::StringRef t) -> Tree& {
+    tree_storage_.push_front(Tree::Parse(GetTokenizedBuffer(t), consumer_,
+                                         /*vlog_stream=*/nullptr));
+    return tree_storage_.front();
+  }
+
+  SharedValueStores value_stores_;
+  llvm::vfs::InMemoryFileSystem fs_;
+  std::forward_list<SourceBuffer> source_storage_;
+  std::forward_list<Lex::TokenizedBuffer> token_storage_;
+  std::forward_list<Tree> tree_storage_;
+  DiagnosticConsumer& consumer_ = ConsoleDiagnosticConsumer();
+};
+
+TEST_F(TypedNodeTest, Empty) {
+  auto* tree = &GetTree("");
+  auto file = tree->ExtractFile();
+
+  EXPECT_TRUE(tree->IsValid(file.start));
+  EXPECT_TRUE(tree->ExtractAs<FileStart>(file.start).has_value());
+  EXPECT_TRUE(tree->Extract(file.start).has_value());
+
+  EXPECT_TRUE(tree->IsValid(file.end));
+  EXPECT_TRUE(tree->ExtractAs<FileEnd>(file.end).has_value());
+  EXPECT_TRUE(tree->Extract(file.end).has_value());
+
+  EXPECT_FALSE(tree->IsValid<FileEnd>(file.start));
+  EXPECT_FALSE(tree->ExtractAs<FileEnd>(file.start).has_value());
+}
+
+TEST_F(TypedNodeTest, Function) {
+  auto* tree = &GetTree(R"carbon(
+    fn F() {}
+    virtual fn G() -> i32;
+  )carbon");
+  auto file = tree->ExtractFile();
+
+  ASSERT_EQ(file.decls.size(), 2);
+
+  auto f_fn = tree->ExtractAs<FunctionDefinition>(file.decls[0]);
+  ASSERT_TRUE(f_fn.has_value());
+  auto f_sig = tree->Extract(f_fn->signature);
+  ASSERT_TRUE(f_sig.has_value());
+  EXPECT_FALSE(f_sig->return_type.has_value());
+  EXPECT_TRUE(f_sig->modifiers.empty());
+
+  auto g_fn = tree->ExtractAs<FunctionDecl>(file.decls[1]);
+  ASSERT_TRUE(g_fn.has_value());
+  EXPECT_TRUE(g_fn->return_type.has_value());
+  EXPECT_FALSE(g_fn->modifiers.empty());
+}
+
+TEST_F(TypedNodeTest, ModifierOrder) {
+  auto* tree = &GetTree(R"carbon(
+    private abstract virtual default interface I;
+  )carbon");
+  auto file = tree->ExtractFile();
+
+  ASSERT_EQ(file.decls.size(), 1);
+
+  auto decl = tree->ExtractAs<InterfaceDecl>(file.decls[0]);
+  ASSERT_TRUE(decl.has_value());
+  ASSERT_EQ(decl->modifiers.size(), 4);
+  // Note that the order here matches the source order, but is reversed from
+  // sibling iteration order.
+  ASSERT_TRUE(tree->ExtractAs<PrivateModifier>(decl->modifiers[0]).has_value());
+  ASSERT_TRUE(
+      tree->ExtractAs<AbstractModifier>(decl->modifiers[1]).has_value());
+  ASSERT_TRUE(tree->ExtractAs<VirtualModifier>(decl->modifiers[2]).has_value());
+  ASSERT_TRUE(tree->ExtractAs<DefaultModifier>(decl->modifiers[3]).has_value());
+}
+
+TEST_F(TypedNodeTest, For) {
+  auto* tree = &GetTree(R"carbon(
+    fn F(arr: [i32; 5]) {
+      for (var v: i32 in arr) {
+        Print(v);
+      }
+    }
+  )carbon");
+  auto file = tree->ExtractFile();
+
+  ASSERT_EQ(file.decls.size(), 1);
+  auto fn = tree->ExtractAs<FunctionDefinition>(file.decls[0]);
+  ASSERT_TRUE(fn.has_value());
+  ASSERT_EQ(fn->body.size(), 1);
+  auto for_stmt = tree->ExtractAs<ForStatement>(fn->body[0]);
+  ASSERT_TRUE(for_stmt.has_value());
+  auto for_header = tree->Extract(for_stmt->header);
+  ASSERT_TRUE(for_header.has_value());
+  auto for_var = tree->Extract(for_header->var);
+  ASSERT_TRUE(for_var.has_value());
+  auto for_var_binding = tree->ExtractAs<BindingPattern>(for_var->pattern);
+  ASSERT_TRUE(for_var_binding.has_value());
+  auto for_var_name = tree->ExtractAs<IdentifierName>(for_var_binding->name);
+  ASSERT_TRUE(for_var_name.has_value());
+}
+
+auto CategoryMatches(const NodeKind::Definition& def, NodeKind kind,
+                     const char* name) {
+  EXPECT_EQ(def.category(), kind.category()) << name;
+}
+
+TEST_F(TypedNodeTest, CategoryMatches) {
+#define CARBON_PARSE_NODE_KIND(Name) \
+  CategoryMatches(Name::Kind, NodeKind::Name, #Name);
+#include "toolchain/parse/node_kind.def"
+}
+
+}  // namespace
+}  // namespace Carbon::Parse

+ 5 - 5
toolchain/sem_ir/inst.h

@@ -27,7 +27,7 @@ struct TypedInstArgsInfo {
   using Tuple = decltype(StructReflection::AsTuple(std::declval<TypedInst>()));
   using Tuple = decltype(StructReflection::AsTuple(std::declval<TypedInst>()));
 
 
   static constexpr int FirstArgField =
   static constexpr int FirstArgField =
-      HasParseNode<TypedInst> + HasTypeId<TypedInst>;
+      HasParseNodeMember<TypedInst> + HasTypeIdMember<TypedInst>;
 
 
   static constexpr int NumArgs = std::tuple_size_v<Tuple> - FirstArgField;
   static constexpr int NumArgs = std::tuple_size_v<Tuple> - FirstArgField;
   static_assert(NumArgs <= 2,
   static_assert(NumArgs <= 2,
@@ -73,10 +73,10 @@ class Inst : public Printable<Inst> {
         type_id_(TypeId::Invalid),
         type_id_(TypeId::Invalid),
         arg0_(InstId::InvalidIndex),
         arg0_(InstId::InvalidIndex),
         arg1_(InstId::InvalidIndex) {
         arg1_(InstId::InvalidIndex) {
-    if constexpr (HasParseNode<TypedInst>) {
+    if constexpr (HasParseNodeMember<TypedInst>) {
       parse_node_ = typed_inst.parse_node;
       parse_node_ = typed_inst.parse_node;
     }
     }
-    if constexpr (HasTypeId<TypedInst>) {
+    if constexpr (HasTypeIdMember<TypedInst>) {
       type_id_ = typed_inst.type_id;
       type_id_ = typed_inst.type_id;
     }
     }
     if constexpr (Info::NumArgs > 0) {
     if constexpr (Info::NumArgs > 0) {
@@ -100,7 +100,7 @@ class Inst : public Printable<Inst> {
     CARBON_CHECK(Is<TypedInst>()) << "Casting inst of kind " << kind()
     CARBON_CHECK(Is<TypedInst>()) << "Casting inst of kind " << kind()
                                   << " to wrong kind " << TypedInst::Kind;
                                   << " to wrong kind " << TypedInst::Kind;
     auto build_with_type_id_and_args = [&](auto... type_id_and_args) {
     auto build_with_type_id_and_args = [&](auto... type_id_and_args) {
-      if constexpr (HasParseNode<TypedInst>) {
+      if constexpr (HasParseNodeMember<TypedInst>) {
         return TypedInst{parse_node(), type_id_and_args...};
         return TypedInst{parse_node(), type_id_and_args...};
       } else {
       } else {
         return TypedInst{type_id_and_args...};
         return TypedInst{type_id_and_args...};
@@ -108,7 +108,7 @@ class Inst : public Printable<Inst> {
     };
     };
 
 
     auto build_with_args = [&](auto... args) {
     auto build_with_args = [&](auto... args) {
-      if constexpr (HasTypeId<TypedInst>) {
+      if constexpr (HasTypeIdMember<TypedInst>) {
         return build_with_type_id_and_args(type_id(), args...);
         return build_with_type_id_and_args(type_id(), args...);
       } else {
       } else {
         return build_with_type_id_and_args(args...);
         return build_with_type_id_and_args(args...);

+ 1 - 1
toolchain/sem_ir/inst_kind.cpp

@@ -20,7 +20,7 @@ auto InstKind::ir_name() const -> llvm::StringLiteral {
 auto InstKind::value_kind() const -> InstValueKind {
 auto InstKind::value_kind() const -> InstValueKind {
   static constexpr InstValueKind Table[] = {
   static constexpr InstValueKind Table[] = {
 #define CARBON_SEM_IR_INST_KIND(Name) \
 #define CARBON_SEM_IR_INST_KIND(Name) \
-  HasTypeId<SemIR::Name> ? InstValueKind::Typed : InstValueKind::None,
+  HasTypeIdMember<SemIR::Name> ? InstValueKind::Typed : InstValueKind::None,
 #include "toolchain/sem_ir/inst_kind.def"
 #include "toolchain/sem_ir/inst_kind.def"
   };
   };
   return Table[AsInt()];
   return Table[AsInt()];

+ 6 - 6
toolchain/sem_ir/typed_insts.h

@@ -616,17 +616,17 @@ struct VarStorage {
   NameId name_id;
   NameId name_id;
 };
 };
 
 
-// HasParseNode<T> is true if T has a `Parse::NodeId parse_node` field.
+// HasParseNodeMember<T> is true if T has a `Parse::NodeId parse_node` field.
 template <typename T, typename ParseNodeType = Parse::NodeId T::*>
 template <typename T, typename ParseNodeType = Parse::NodeId T::*>
-inline constexpr bool HasParseNode = false;
+inline constexpr bool HasParseNodeMember = false;
 template <typename T>
 template <typename T>
-inline constexpr bool HasParseNode<T, decltype(&T::parse_node)> = true;
+inline constexpr bool HasParseNodeMember<T, decltype(&T::parse_node)> = true;
 
 
-// HasTypeId<T> is true if T has a `TypeId type_id` field.
+// HasTypeIdMember<T> is true if T has a `TypeId type_id` field.
 template <typename T, typename TypeIdType = TypeId T::*>
 template <typename T, typename TypeIdType = TypeId T::*>
-inline constexpr bool HasTypeId = false;
+inline constexpr bool HasTypeIdMember = false;
 template <typename T>
 template <typename T>
-inline constexpr bool HasTypeId<T, decltype(&T::type_id)> = true;
+inline constexpr bool HasTypeIdMember<T, decltype(&T::type_id)> = true;
 
 
 }  // namespace Carbon::SemIR
 }  // namespace Carbon::SemIR
 
 

+ 6 - 5
toolchain/sem_ir/typed_insts_test.cpp

@@ -43,10 +43,10 @@ template <typename TypedInst>
 auto CommonFieldOrder() -> void {
 auto CommonFieldOrder() -> void {
   Inst inst = MakeInstWithNumberedFields(TypedInst::Kind);
   Inst inst = MakeInstWithNumberedFields(TypedInst::Kind);
   auto typed = inst.As<TypedInst>();
   auto typed = inst.As<TypedInst>();
-  if constexpr (HasParseNode<TypedInst>) {
+  if constexpr (HasParseNodeMember<TypedInst>) {
     EXPECT_EQ(typed.parse_node, Parse::NodeId(1));
     EXPECT_EQ(typed.parse_node, Parse::NodeId(1));
   }
   }
-  if constexpr (HasTypeId<TypedInst>) {
+  if constexpr (HasTypeIdMember<TypedInst>) {
     EXPECT_EQ(typed.type_id, TypeId(2));
     EXPECT_EQ(typed.type_id, TypeId(2));
   }
   }
 }
 }
@@ -77,7 +77,8 @@ auto RoundTrip() -> void {
   auto typed1 = inst1.As<TypedInst>();
   auto typed1 = inst1.As<TypedInst>();
   Inst inst2 = typed1;
   Inst inst2 = typed1;
 
 
-  ExpectEqInsts(inst1, inst2, HasParseNode<TypedInst>, HasTypeId<TypedInst>);
+  ExpectEqInsts(inst1, inst2, HasParseNodeMember<TypedInst>,
+                HasTypeIdMember<TypedInst>);
 
 
   // If the typed instruction has no padding, we should get exactly the same
   // If the typed instruction has no padding, we should get exactly the same
   // thing if we convert back from an instruction.
   // thing if we convert back from an instruction.
@@ -129,8 +130,8 @@ auto StructLayout() -> void {
   if constexpr (std::has_unique_object_representations_v<TypedInst>) {
   if constexpr (std::has_unique_object_representations_v<TypedInst>) {
     auto typed =
     auto typed =
         MakeInstWithNumberedFields(TypedInst::Kind).template As<TypedInst>();
         MakeInstWithNumberedFields(TypedInst::Kind).template As<TypedInst>();
-    StructLayoutHelper(&typed, sizeof(typed), HasParseNode<TypedInst>,
-                       HasTypeId<TypedInst>);
+    StructLayoutHelper(&typed, sizeof(typed), HasParseNodeMember<TypedInst>,
+                       HasTypeIdMember<TypedInst>);
   }
   }
 }
 }