Procházet zdrojové kódy

Separate subtree size information from parse nodes. (#4174)

Move subtree sizes over to TreeAndSubtrees, using a separate structure
to represent the additional parse work that occurs and to make it clear
which functions require the extra information. My intent is to make it
hard to use this by accident.

The subtree size is still tracked during Parse::Tree construction. I
think a lot of that can be cleaned up, although we use it during
placeholder assignment so it may take some work. I wanted to see what
people thought about this before taking action on such a change.

I'm using a 1M-line source file generated by #4124 for testing. The
command is `time bazel-bin/toolchain/install/prefix_root/bin/carbon compile
--phase=check --dump-mem-usage ~/tmp/data.carbon`

At head, what I'm seeing is:

```
...
parse_tree_.node_impls_:
  used_bytes:      61516116
  reserved_bytes:  61516116
...
Total:
  used_bytes:      447814230
  reserved_bytes:  551663894
...
1.43s user 0.14s system 99% cpu 1.565 total
```

With `Tree::Verify` disabled completely, it looks like:
```
parse_tree_.node_impls_:
  used_bytes:      41010744
  reserved_bytes:  41010744
...
Total:
  used_bytes:      427308858
  reserved_bytes:  531158522
...
1.20s user 0.13s system 99% cpu 1.332 total
```

Re-enabling just the basic verification (what is now `Tree::Verify`),
I'm seeing it run maybe 0.05s slower, but that's within noise for my
system. I do see variability in my timing results, and overall I think
this is a 0.2s +/- 0.1s improvement versus the earlier implementation
(which always tested the `Extract` code). These numbers are for opt
builds; debug builds will be unaffected, because the same checking occurs
as before.

Note, the subtree size is a third of the node representation, which is
why I'm showing the decrease in memory usage here (`node_impls_` drops
from 61516116 to 41010744 bytes).
Jon Ross-Perkins před 1 rokem
rodič
revize
f67791cfee
40 změnil soubory, kde provedl 767 přidání a 692 odebrání
  1. 8 4
      language_server/language_server.cpp
  2. 3 2
      toolchain/check/check.cpp
  3. 3 0
      toolchain/check/check.h
  4. 5 2
      toolchain/check/context.cpp
  5. 11 0
      toolchain/check/context.h
  6. 2 1
      toolchain/check/handle_impl.cpp
  7. 1 0
      toolchain/driver/BUILD
  8. 30 6
      toolchain/driver/driver.cpp
  9. 5 1
      toolchain/parse/BUILD
  10. 15 23
      toolchain/parse/context.cpp
  11. 4 6
      toolchain/parse/context.h
  12. 24 19
      toolchain/parse/extract.cpp
  13. 2 4
      toolchain/parse/handle_array_expr.cpp
  14. 3 5
      toolchain/parse/handle_binding_pattern.cpp
  15. 2 4
      toolchain/parse/handle_brace_expr.cpp
  16. 2 4
      toolchain/parse/handle_call_expr.cpp
  17. 4 4
      toolchain/parse/handle_choice.cpp
  18. 2 4
      toolchain/parse/handle_code_block.cpp
  19. 2 4
      toolchain/parse/handle_decl_definition.cpp
  20. 2 3
      toolchain/parse/handle_decl_name_and_params.cpp
  21. 2 4
      toolchain/parse/handle_decl_scope_loop.cpp
  22. 9 15
      toolchain/parse/handle_expr.cpp
  23. 6 9
      toolchain/parse/handle_function.cpp
  24. 2 4
      toolchain/parse/handle_impl.cpp
  25. 2 2
      toolchain/parse/handle_import_and_package.cpp
  26. 1 1
      toolchain/parse/handle_index_expr.cpp
  27. 1 2
      toolchain/parse/handle_let.cpp
  28. 18 28
      toolchain/parse/handle_match.cpp
  29. 2 4
      toolchain/parse/handle_paren_expr.cpp
  30. 1 1
      toolchain/parse/handle_pattern_list.cpp
  31. 3 4
      toolchain/parse/handle_period.cpp
  32. 5 9
      toolchain/parse/handle_statement.cpp
  33. 2 4
      toolchain/parse/handle_var.cpp
  34. 22 245
      toolchain/parse/tree.cpp
  35. 7 246
      toolchain/parse/tree.h
  36. 244 0
      toolchain/parse/tree_and_subtrees.cpp
  37. 277 0
      toolchain/parse/tree_and_subtrees.h
  38. 16 8
      toolchain/parse/tree_node_diagnostic_converter.h
  39. 5 2
      toolchain/parse/tree_test.cpp
  40. 12 8
      toolchain/parse/typed_nodes_test.cpp

+ 8 - 4
language_server/language_server.cpp

@@ -10,6 +10,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/source/source_buffer.h"
 
 namespace Carbon::LS {
@@ -79,11 +80,12 @@ auto LanguageServer::onReply(llvm::json::Value /*id*/,
 // Returns the text of first child of kind Parse::NodeKind::IdentifierName.
 static auto GetIdentifierName(const SharedValueStores& value_stores,
                               const Lex::TokenizedBuffer& tokens,
-                              const Parse::Tree& p, Parse::NodeId node)
+                              const Parse::TreeAndSubtrees& p,
+                              Parse::NodeId node)
     -> std::optional<llvm::StringRef> {
   for (auto ch : p.children(node)) {
-    if (p.node_kind(ch) == Parse::NodeKind::IdentifierName) {
-      auto token = p.node_token(ch);
+    if (p.tree().node_kind(ch) == Parse::NodeKind::IdentifierName) {
+      auto token = p.tree().node_token(ch);
       if (tokens.GetKind(token) == Lex::TokenKind::Identifier) {
         return value_stores.identifiers().Get(tokens.GetIdentifier(token));
       }
@@ -104,6 +106,7 @@ void LanguageServer::OnDocumentSymbol(
   auto buf = SourceBuffer::MakeFromFile(vfs, file, NullDiagnosticConsumer());
   auto lexed = Lex::Lex(value_stores, *buf, NullDiagnosticConsumer());
   auto parsed = Parse::Parse(lexed, NullDiagnosticConsumer(), nullptr);
+  Parse::TreeAndSubtrees tree_and_subtrees(lexed, parsed);
   std::vector<clang::clangd::DocumentSymbol> result;
   for (const auto& node : parsed.postorder()) {
     clang::clangd::SymbolKind symbol_kind;
@@ -126,7 +129,8 @@ void LanguageServer::OnDocumentSymbol(
         continue;
     }
 
-    if (auto name = GetIdentifierName(value_stores, lexed, parsed, node)) {
+    if (auto name =
+            GetIdentifierName(value_stores, lexed, tree_and_subtrees, node)) {
       auto tok = parsed.node_token(node);
       clang::clangd::Position pos{lexed.GetLineNumber(tok) - 1,
                                   lexed.GetColumnNumber(tok) - 1};

+ 3 - 2
toolchain/check/check.cpp

@@ -65,7 +65,7 @@ struct UnitInfo {
       : check_ir_id(check_ir_id),
         unit(&unit),
         converter(unit.tokens, unit.tokens->source().filename(),
-                  unit.parse_tree),
+                  unit.get_parse_tree_and_subtrees),
         err_tracker(*unit.consumer),
         emitter(converter, err_tracker) {}
 
@@ -891,7 +891,8 @@ static auto CheckParseTree(
   SemIRDiagnosticConverter converter(node_converters, &sem_ir);
   Context::DiagnosticEmitter emitter(converter, unit_info.err_tracker);
   Context context(*unit_info.unit->tokens, emitter, *unit_info.unit->parse_tree,
-                  sem_ir, vlog_stream);
+                  unit_info.unit->get_parse_tree_and_subtrees, sem_ir,
+                  vlog_stream);
   PrettyStackTraceFunction context_dumper(
       [&](llvm::raw_ostream& output) { context.PrintForStackDump(output); });
 

+ 3 - 0
toolchain/check/check.h

@@ -10,6 +10,7 @@
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/sem_ir/file.h"
 
 namespace Carbon::Check {
@@ -20,6 +21,8 @@ struct Unit {
   const Lex::TokenizedBuffer* tokens;
   const Parse::Tree* parse_tree;
   DiagnosticConsumer* consumer;
+  // Returns a lazily constructed TreeAndSubtrees.
+  std::function<const Parse::TreeAndSubtrees&()> get_parse_tree_and_subtrees;
   // The generated IR. Unset on input, set on output.
   std::optional<SemIR::File>* sem_ir;
 };

+ 5 - 2
toolchain/check/context.cpp

@@ -36,11 +36,14 @@
 namespace Carbon::Check {
 
 Context::Context(const Lex::TokenizedBuffer& tokens, DiagnosticEmitter& emitter,
-                 const Parse::Tree& parse_tree, SemIR::File& sem_ir,
-                 llvm::raw_ostream* vlog_stream)
+                 const Parse::Tree& parse_tree,
+                 llvm::function_ref<const Parse::TreeAndSubtrees&()>
+                     get_parse_tree_and_subtrees,
+                 SemIR::File& sem_ir, llvm::raw_ostream* vlog_stream)
     : tokens_(&tokens),
       emitter_(&emitter),
       parse_tree_(&parse_tree),
+      get_parse_tree_and_subtrees_(get_parse_tree_and_subtrees),
       sem_ir_(&sem_ir),
       vlog_stream_(vlog_stream),
       node_stack_(parse_tree, vlog_stream),

+ 11 - 0
toolchain/check/context.h

@@ -19,6 +19,7 @@
 #include "toolchain/check/scope_stack.h"
 #include "toolchain/parse/node_ids.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/sem_ir/file.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/import_ir.h"
@@ -53,6 +54,8 @@ class Context {
   // Stores references for work.
   explicit Context(const Lex::TokenizedBuffer& tokens,
                    DiagnosticEmitter& emitter, const Parse::Tree& parse_tree,
+                   llvm::function_ref<const Parse::TreeAndSubtrees&()>
+                       get_parse_tree_and_subtrees,
                    SemIR::File& sem_ir, llvm::raw_ostream* vlog_stream);
 
   // Marks an implementation TODO. Always returns false.
@@ -360,6 +363,10 @@ class Context {
 
   auto parse_tree() -> const Parse::Tree& { return *parse_tree_; }
 
+  auto parse_tree_and_subtrees() -> const Parse::TreeAndSubtrees& {
+    return get_parse_tree_and_subtrees_();
+  }
+
   auto sem_ir() -> SemIR::File& { return *sem_ir_; }
 
   auto node_stack() -> NodeStack& { return node_stack_; }
@@ -486,6 +493,10 @@ class Context {
   // The file's parse tree.
   const Parse::Tree* parse_tree_;
 
+  // Returns a lazily constructed TreeAndSubtrees.
+  llvm::function_ref<const Parse::TreeAndSubtrees&()>
+      get_parse_tree_and_subtrees_;
+
   // The SemIR::File being added to.
   SemIR::File* sem_ir_;
 

+ 2 - 1
toolchain/check/handle_impl.cpp

@@ -141,7 +141,8 @@ static auto ExtendImpl(Context& context, Parse::NodeId extend_node,
     // The explicit self type is the same as the default self type, so suggest
     // removing it and recover as if it were not present.
     if (auto self_as =
-            context.parse_tree().ExtractAs<Parse::TypeImplAs>(self_type_node)) {
+            context.parse_tree_and_subtrees().ExtractAs<Parse::TypeImplAs>(
+                self_type_node)) {
       CARBON_DIAGNOSTIC(ExtendImplSelfAsDefault, Note,
                         "Remove the explicit `Self` type here.");
       diag.Note(self_as->type_expr, ExtendImplSelfAsDefault);

+ 1 - 0
toolchain/driver/BUILD

@@ -74,6 +74,7 @@ cc_library(
         "//toolchain/lex",
         "//toolchain/lower",
         "//toolchain/parse",
+        "//toolchain/parse:tree",
         "//toolchain/sem_ir:file",
         "//toolchain/sem_ir:formatter",
         "//toolchain/sem_ir:inst_namer",

+ 30 - 6
toolchain/driver/driver.cpp

@@ -27,6 +27,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lower/lower.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/sem_ir/formatter.h"
 #include "toolchain/sem_ir/inst_namer.h"
 #include "toolchain/source/source_buffer.h"
@@ -599,7 +600,12 @@ class Driver::CompilationUnit {
     });
     if (options_.dump_parse_tree && IncludeInDumps()) {
       consumer_->Flush();
-      parse_tree_->Print(driver_->output_stream_, options_.preorder_parse_tree);
+      const auto& tree_and_subtrees = GetParseTreeAndSubtrees();
+      if (options_.preorder_parse_tree) {
+        tree_and_subtrees.PrintPreorder(driver_->output_stream_);
+      } else {
+        tree_and_subtrees.Print(driver_->output_stream_);
+      }
     }
     if (mem_usage_) {
       mem_usage_->Collect("parse_tree_", *parse_tree_);
@@ -613,11 +619,15 @@ class Driver::CompilationUnit {
   // Returns information needed to check this unit.
   auto GetCheckUnit() -> Check::Unit {
     CARBON_CHECK(parse_tree_);
-    return {.value_stores = &value_stores_,
-            .tokens = &*tokens_,
-            .parse_tree = &*parse_tree_,
-            .consumer = consumer_,
-            .sem_ir = &sem_ir_};
+    return {
+        .value_stores = &value_stores_,
+        .tokens = &*tokens_,
+        .parse_tree = &*parse_tree_,
+        .consumer = consumer_,
+        .get_parse_tree_and_subtrees = [&]() -> const Parse::TreeAndSubtrees& {
+          return GetParseTreeAndSubtrees();
+        },
+        .sem_ir = &sem_ir_};
   }
 
   // Runs post-check logic. Returns true if checking succeeded for the IR.
@@ -778,6 +788,19 @@ class Driver::CompilationUnit {
     return true;
   }
 
+  // The TreeAndSubtrees is mainly used for debugging and diagnostics, and has
+  // significant overhead. Avoid constructing it when unused.
+  auto GetParseTreeAndSubtrees() -> const Parse::TreeAndSubtrees& {
+    if (!parse_tree_and_subtrees_) {
+      parse_tree_and_subtrees_ = Parse::TreeAndSubtrees(*tokens_, *parse_tree_);
+      if (mem_usage_) {
+        mem_usage_->Collect("parse_tree_and_subtrees_",
+                            *parse_tree_and_subtrees_);
+      }
+    }
+    return *parse_tree_and_subtrees_;
+  }
+
   // Wraps a call with log statements to indicate start and end.
   auto LogCall(llvm::StringLiteral label, llvm::function_ref<void()> fn)
       -> void {
@@ -814,6 +837,7 @@ class Driver::CompilationUnit {
   std::optional<SourceBuffer> source_;
   std::optional<Lex::TokenizedBuffer> tokens_;
   std::optional<Parse::Tree> parse_tree_;
+  std::optional<Parse::TreeAndSubtrees> parse_tree_and_subtrees_;
   std::optional<SemIR::File> sem_ir_;
   std::unique_ptr<llvm::LLVMContext> llvm_context_;
   std::unique_ptr<llvm::Module> module_;

+ 5 - 1
toolchain/parse/BUILD

@@ -106,8 +106,12 @@ cc_library(
     srcs = [
         "extract.cpp",
         "tree.cpp",
+        "tree_and_subtrees.cpp",
+    ],
+    hdrs = [
+        "tree.h",
+        "tree_and_subtrees.h",
     ],
-    hdrs = ["tree.h"],
     deps = [
         ":node_kind",
         "//common:check",

+ 15 - 23
toolchain/parse/context.cpp

@@ -68,18 +68,15 @@ Context::Context(Tree& tree, Lex::TokenizedBuffer& tokens,
 
 auto Context::AddLeafNode(NodeKind kind, Lex::TokenIndex token, bool has_error)
     -> void {
-  tree_->node_impls_.push_back(
-      Tree::NodeImpl(kind, has_error, token, /*subtree_size=*/1));
+  tree_->node_impls_.push_back(Tree::NodeImpl(kind, has_error, token));
   if (has_error) {
     tree_->has_errors_ = true;
   }
 }
 
-auto Context::AddNode(NodeKind kind, Lex::TokenIndex token, int subtree_start,
-                      bool has_error) -> void {
-  int subtree_size = tree_->size() - subtree_start + 1;
-  tree_->node_impls_.push_back(
-      Tree::NodeImpl(kind, has_error, token, subtree_size));
+auto Context::AddNode(NodeKind kind, Lex::TokenIndex token, bool has_error)
+    -> void {
+  tree_->node_impls_.push_back(Tree::NodeImpl(kind, has_error, token));
   if (has_error) {
     tree_->has_errors_ = true;
   }
@@ -91,7 +88,6 @@ auto Context::ReplacePlaceholderNode(int32_t position, NodeKind kind,
   CARBON_CHECK(position >= 0 && position < tree_->size())
       << "position: " << position << " size: " << tree_->size();
   auto* node_impl = &tree_->node_impls_[position];
-  CARBON_CHECK(node_impl->subtree_size == 1);
   CARBON_CHECK(node_impl->kind == NodeKind::Placeholder);
   node_impl->kind = kind;
   node_impl->has_error = has_error;
@@ -123,9 +119,9 @@ auto Context::ConsumeAndAddCloseSymbol(Lex::TokenIndex expected_open,
   Lex::TokenKind open_token_kind = tokens().GetKind(expected_open);
 
   if (!open_token_kind.is_opening_symbol()) {
-    AddNode(close_kind, state.token, state.subtree_start, /*has_error=*/true);
+    AddNode(close_kind, state.token, /*has_error=*/true);
   } else if (auto close_token = ConsumeIf(open_token_kind.closing_symbol())) {
-    AddNode(close_kind, *close_token, state.subtree_start, state.has_error);
+    AddNode(close_kind, *close_token, state.has_error);
   } else {
     // TODO: Include the location of the matching opening delimiter in the
     // diagnostic.
@@ -135,7 +131,7 @@ auto Context::ConsumeAndAddCloseSymbol(Lex::TokenIndex expected_open,
                    open_token_kind.closing_symbol().fixed_spelling());
 
     SkipTo(tokens().GetMatchedClosingToken(expected_open));
-    AddNode(close_kind, Consume(), state.subtree_start, /*has_error=*/true);
+    AddNode(close_kind, Consume(), /*has_error=*/true);
   }
 }
 
@@ -415,7 +411,7 @@ auto Context::AddNodeExpectingDeclSemi(StateStackEntry state,
   }
 
   if (auto semi = ConsumeIf(Lex::TokenKind::Semi)) {
-    AddNode(node_kind, *semi, state.subtree_start, /*has_error=*/false);
+    AddNode(node_kind, *semi, /*has_error=*/false);
   } else {
     if (is_def_allowed) {
       DiagnoseExpectedDeclSemiOrDefinition(decl_kind);
@@ -433,8 +429,7 @@ auto Context::RecoverFromDeclError(StateStackEntry state, NodeKind node_kind,
   if (skip_past_likely_end) {
     token = SkipPastLikelyEnd(token);
   }
-  AddNode(node_kind, token, state.subtree_start,
-          /*has_error=*/true);
+  AddNode(node_kind, token, /*has_error=*/true);
 }
 
 auto Context::ParseLibraryName(bool accept_default)
@@ -464,13 +459,11 @@ auto Context::ParseLibraryName(bool accept_default)
 auto Context::ParseLibrarySpecifier(bool accept_default)
     -> std::optional<StringLiteralValueId> {
   auto library_token = ConsumeChecked(Lex::TokenKind::Library);
-  auto library_subtree_start = tree().size();
   auto library_id = ParseLibraryName(accept_default);
   if (!library_id) {
     AddLeafNode(NodeKind::LibraryName, *position_, /*has_error=*/true);
   }
-  AddNode(NodeKind::LibrarySpecifier, library_token, library_subtree_start,
-          /*has_error=*/false);
+  AddNode(NodeKind::LibrarySpecifier, library_token, /*has_error=*/false);
   return library_id;
 }
 
@@ -503,8 +496,7 @@ static auto ParsingInDeferredDefinitionScope(Context& context) -> bool {
          state == State::DeclDefinitionFinishAsNamedConstraint;
 }
 
-auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token,
-                                         int subtree_start, bool has_error)
+auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token, bool has_error)
     -> void {
   if (ParsingInDeferredDefinitionScope(*this)) {
     deferred_definition_stack_.push_back(tree_->deferred_definitions_.Add(
@@ -512,11 +504,11 @@ auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token,
              FunctionDefinitionStartId(NodeId(tree_->node_impls_.size()))}));
   }
 
-  AddNode(NodeKind::FunctionDefinitionStart, token, subtree_start, has_error);
+  AddNode(NodeKind::FunctionDefinitionStart, token, has_error);
 }
 
-auto Context::AddFunctionDefinition(Lex::TokenIndex token, int subtree_start,
-                                    bool has_error) -> void {
+auto Context::AddFunctionDefinition(Lex::TokenIndex token, bool has_error)
+    -> void {
   if (ParsingInDeferredDefinitionScope(*this)) {
     auto definition_index = deferred_definition_stack_.pop_back_val();
     auto& definition = tree_->deferred_definitions_.Get(definition_index);
@@ -526,7 +518,7 @@ auto Context::AddFunctionDefinition(Lex::TokenIndex token, int subtree_start,
         DeferredDefinitionIndex(tree_->deferred_definitions().size());
   }
 
-  AddNode(NodeKind::FunctionDefinition, token, subtree_start, has_error);
+  AddNode(NodeKind::FunctionDefinition, token, has_error);
 }
 
 auto Context::PrintForStackDump(llvm::raw_ostream& output) const -> void {

+ 4 - 6
toolchain/parse/context.h

@@ -100,8 +100,7 @@ class Context {
       -> void;
 
   // Adds a node to the parse tree that has children.
-  auto AddNode(NodeKind kind, Lex::TokenIndex token, int subtree_start,
-               bool has_error) -> void;
+  auto AddNode(NodeKind kind, Lex::TokenIndex token, bool has_error) -> void;
 
   // Replaces the placeholder node at the indicated position with a leaf node.
   //
@@ -328,12 +327,11 @@ class Context {
 
   // Adds a function definition start node, and begins tracking a deferred
   // definition if necessary.
-  auto AddFunctionDefinitionStart(Lex::TokenIndex token, int subtree_start,
-                                  bool has_error) -> void;
+  auto AddFunctionDefinitionStart(Lex::TokenIndex token, bool has_error)
+      -> void;
   // Adds a function definition node, and ends tracking a deferred definition if
   // necessary.
-  auto AddFunctionDefinition(Lex::TokenIndex token, int subtree_start,
-                             bool has_error) -> void;
+  auto AddFunctionDefinition(Lex::TokenIndex token, bool has_error) -> void;
 
   // Prints information for a stack dump.
   auto PrintForStackDump(llvm::raw_ostream& output) const -> void;

+ 24 - 19
toolchain/parse/extract.cpp

@@ -9,6 +9,7 @@
 #include "common/error.h"
 #include "common/struct_reflection.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/parse/typed_nodes.h"
 
 namespace Carbon::Parse {
@@ -20,12 +21,12 @@ namespace {
 class NodeExtractor {
  public:
   struct CheckpointState {
-    Tree::SiblingIterator it;
+    TreeAndSubtrees::SiblingIterator it;
   };
 
-  NodeExtractor(const Tree* tree, Lex::TokenizedBuffer* tokens,
+  NodeExtractor(const TreeAndSubtrees* tree, const Lex::TokenizedBuffer* tokens,
                 ErrorBuilder* trace, NodeId node_id,
-                llvm::iterator_range<Tree::SiblingIterator> children)
+                llvm::iterator_range<TreeAndSubtrees::SiblingIterator> children)
       : tree_(tree),
         tokens_(tokens),
         trace_(trace),
@@ -34,9 +35,11 @@ class NodeExtractor {
         end_(children.end()) {}
 
   auto at_end() const -> bool { return it_ == end_; }
-  auto kind() const -> NodeKind { return tree_->node_kind(*it_); }
+  auto kind() const -> NodeKind { return tree_->tree().node_kind(*it_); }
   auto has_token() const -> bool { return node_id_.is_valid(); }
-  auto token() const -> Lex::TokenIndex { return tree_->node_token(node_id_); }
+  auto token() const -> Lex::TokenIndex {
+    return tree_->tree().node_token(node_id_);
+  }
   auto token_kind() const -> Lex::TokenKind {
     return tokens_->GetKind(token());
   }
@@ -73,12 +76,12 @@ class NodeExtractor {
                             std::tuple<U...>* /*type*/) -> std::optional<T>;
 
  private:
-  const Tree* tree_;
-  Lex::TokenizedBuffer* tokens_;
+  const TreeAndSubtrees* tree_;
+  const Lex::TokenizedBuffer* tokens_;
   ErrorBuilder* trace_;
   NodeId node_id_;
-  Tree::SiblingIterator it_;
-  Tree::SiblingIterator end_;
+  TreeAndSubtrees::SiblingIterator it_;
+  TreeAndSubtrees::SiblingIterator end_;
 };
 }  // namespace
 
@@ -97,8 +100,8 @@ namespace {
 // };
 // ```
 //
-// Note that `Tree::SiblingIterator`s iterate in reverse order through the
-// children of a node.
+// Note that `TreeAndSubtrees::SiblingIterator`s iterate in reverse order
+// through the children of a node.
 //
 // This class is only in this file.
 template <typename T>
@@ -320,7 +323,7 @@ auto NodeExtractor::MatchesTokenKind(Lex::TokenKind expected_kind) const
   if (token_kind() != expected_kind) {
     if (trace_) {
       *trace_ << "Token " << expected_kind << " expected for "
-              << tree_->node_kind(node_id_) << ", found " << token_kind()
+              << tree_->tree().node_kind(node_id_) << ", found " << token_kind()
               << "\n";
     }
     return false;
@@ -405,14 +408,15 @@ struct Extractable {
 }  // namespace
 
 template <typename T>
-auto Tree::TryExtractNodeFromChildren(
-    NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children,
+auto TreeAndSubtrees::TryExtractNodeFromChildren(
+    NodeId node_id,
+    llvm::iterator_range<TreeAndSubtrees::SiblingIterator> children,
     ErrorBuilder* trace) const -> std::optional<T> {
   NodeExtractor extractor(this, tokens_, trace, node_id, children);
   auto result = Extractable<T>::ExtractImpl(extractor);
   if (!extractor.at_end()) {
     if (trace) {
-      *trace << "Error: " << node_kind(extractor.ExtractNode())
+      *trace << "Error: " << tree_->node_kind(extractor.ExtractNode())
              << " node left unconsumed.";
     }
     return std::nullopt;
@@ -421,16 +425,17 @@ auto Tree::TryExtractNodeFromChildren(
 }
 
 // Manually instantiate Tree::TryExtractNodeFromChildren
-#define CARBON_PARSE_NODE_KIND(KindName)                                    \
-  template auto Tree::TryExtractNodeFromChildren<KindName>(                 \
-      NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children, \
+#define CARBON_PARSE_NODE_KIND(KindName)                               \
+  template auto TreeAndSubtrees::TryExtractNodeFromChildren<KindName>( \
+      NodeId node_id,                                                  \
+      llvm::iterator_range<TreeAndSubtrees::SiblingIterator> children, \
       ErrorBuilder * trace) const -> std::optional<KindName>;
 
 // Also instantiate for `File`, even though it isn't a parse node.
 CARBON_PARSE_NODE_KIND(File)
 #include "toolchain/parse/node_kind.def"
 
-auto Tree::ExtractFile() const -> File {
+auto TreeAndSubtrees::ExtractFile() const -> File {
   return ExtractNodeFromChildren<File>(NodeId::Invalid, roots());
 }
 

+ 2 - 4
toolchain/parse/handle_array_expr.cpp

@@ -24,14 +24,12 @@ auto HandleArrayExprSemi(Context& context) -> void {
   auto state = context.PopState();
   auto semi = context.ConsumeIf(Lex::TokenKind::Semi);
   if (!semi) {
-    context.AddNode(NodeKind::ArrayExprSemi, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::ArrayExprSemi, *context.position(), true);
     CARBON_DIAGNOSTIC(ExpectedArraySemi, Error, "Expected `;` in array type.");
     context.emitter().Emit(*context.position(), ExpectedArraySemi);
     state.has_error = true;
   } else {
-    context.AddNode(NodeKind::ArrayExprSemi, *semi, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::ArrayExprSemi, *semi, state.has_error);
   }
   context.PushState(state, State::ArrayExprFinish);
   if (!context.PositionIs(Lex::TokenKind::CloseSquareBracket)) {

+ 3 - 5
toolchain/parse/handle_binding_pattern.cpp

@@ -76,7 +76,7 @@ static auto HandleBindingPatternFinish(Context& context, NodeKind node_kind)
     -> void {
   auto state = context.PopState();
 
-  context.AddNode(node_kind, state.token, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, state.token, state.has_error);
 
   // Propagate errors to the parent state so that they can take different
   // actions on invalid patterns.
@@ -96,8 +96,7 @@ auto HandleBindingPatternFinishAsRegular(Context& context) -> void {
 auto HandleBindingPatternAddr(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::Addr, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::Addr, state.token, state.has_error);
 
   // If an error was encountered, propagate it while adding a node.
   if (state.has_error) {
@@ -108,8 +107,7 @@ auto HandleBindingPatternAddr(Context& context) -> void {
 auto HandleBindingPatternTemplate(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::Template, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::Template, state.token, state.has_error);
 
   // If an error was encountered, propagate it while adding a node.
   if (state.has_error) {

+ 2 - 4
toolchain/parse/handle_brace_expr.cpp

@@ -151,8 +151,7 @@ static auto HandleBraceExprParamFinish(Context& context, NodeKind node_kind,
                         /*has_error=*/true);
     context.ReturnErrorOnState();
   } else {
-    context.AddNode(node_kind, state.token, state.subtree_start,
-                    /*has_error=*/false);
+    context.AddNode(node_kind, state.token, /*has_error=*/false);
   }
 
   if (context.ConsumeListToken(
@@ -183,8 +182,7 @@ static auto HandleBraceExprFinish(Context& context, NodeKind start_kind,
   auto state = context.PopState();
 
   context.ReplacePlaceholderNode(state.subtree_start, start_kind, state.token);
-  context.AddNode(end_kind, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(end_kind, context.Consume(), state.has_error);
 }
 
 auto HandleBraceExprFinishAsType(Context& context) -> void {

+ 2 - 4
toolchain/parse/handle_call_expr.cpp

@@ -11,8 +11,7 @@ auto HandleCallExpr(Context& context) -> void {
   auto state = context.PopState();
   context.PushState(state, State::CallExprFinish);
 
-  context.AddNode(NodeKind::CallExprStart, context.Consume(),
-                  state.subtree_start, state.has_error);
+  context.AddNode(NodeKind::CallExprStart, context.Consume(), state.has_error);
   if (!context.PositionIs(Lex::TokenKind::CloseParen)) {
     context.PushState(State::CallExprParamFinish);
     context.PushState(State::Expr);
@@ -37,8 +36,7 @@ auto HandleCallExprParamFinish(Context& context) -> void {
 auto HandleCallExprFinish(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::CallExpr, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::CallExpr, context.Consume(), state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 4 - 4
toolchain/parse/handle_choice.cpp

@@ -24,17 +24,17 @@ auto HandleChoiceDefinitionStart(Context& context) -> void {
     }
 
     context.AddNode(NodeKind::ChoiceDefinitionStart, *context.position(),
-                    state.subtree_start, /*has_error=*/true);
+                    /*has_error=*/true);
 
     context.AddNode(NodeKind::ChoiceDefinition, *context.position(),
-                    state.subtree_start, /*has_error=*/true);
+                    /*has_error=*/true);
 
     context.SkipPastLikelyEnd(*context.position());
     return;
   }
 
   context.AddNode(NodeKind::ChoiceDefinitionStart, context.Consume(),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 
   state.has_error = false;
   state.state = State::ChoiceDefinitionFinish;
@@ -94,6 +94,6 @@ auto HandleChoiceDefinitionFinish(Context& context) -> void {
 
   context.AddNode(NodeKind::ChoiceDefinition,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_code_block.cpp

@@ -31,11 +31,9 @@ auto HandleCodeBlockFinish(Context& context) -> void {
 
   // If the block started with an open curly, this is a close curly.
   if (context.tokens().GetKind(state.token) == Lex::TokenKind::OpenCurlyBrace) {
-    context.AddNode(NodeKind::CodeBlock, context.Consume(), state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::CodeBlock, context.Consume(), state.has_error);
   } else {
-    context.AddNode(NodeKind::CodeBlock, state.token, state.subtree_start,
-                    /*has_error=*/true);
+    context.AddNode(NodeKind::CodeBlock, state.token, /*has_error=*/true);
   }
 }
 

+ 2 - 4
toolchain/parse/handle_decl_definition.cpp

@@ -23,8 +23,7 @@ static auto HandleDeclOrDefinition(Context& context, NodeKind decl_kind,
 
   context.PushState(state, definition_finish_state);
   context.PushState(State::DeclScopeLoop);
-  context.AddNode(definition_start_kind, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(definition_start_kind, context.Consume(), state.has_error);
 }
 
 auto HandleDeclOrDefinitionAsClass(Context& context) -> void {
@@ -56,8 +55,7 @@ static auto HandleDeclDefinitionFinish(Context& context,
                                        NodeKind definition_kind) -> void {
   auto state = context.PopState();
 
-  context.AddNode(definition_kind, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(definition_kind, context.Consume(), state.has_error);
 }
 
 auto HandleDeclDefinitionFinishAsClass(Context& context) -> void {

+ 2 - 3
toolchain/parse/handle_decl_name_and_params.cpp

@@ -41,7 +41,7 @@ auto HandleDeclNameAndParams(Context& context) -> void {
     case Lex::TokenKind::Period:
       context.AddNode(NodeKind::NameQualifier,
                       context.ConsumeChecked(Lex::TokenKind::Period),
-                      state.subtree_start, state.has_error);
+                      state.has_error);
       context.PushState(State::DeclNameAndParams);
       break;
 
@@ -83,8 +83,7 @@ auto HandleDeclNameAndParamsAfterParams(Context& context) -> void {
   auto state = context.PopState();
 
   if (auto period = context.ConsumeIf(Lex::TokenKind::Period)) {
-    context.AddNode(NodeKind::NameQualifier, *period, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::NameQualifier, *period, state.has_error);
     context.PushState(State::DeclNameAndParams);
   }
 }

+ 2 - 4
toolchain/parse/handle_decl_scope_loop.cpp

@@ -18,8 +18,7 @@ static auto FinishAndSkipInvalidDecl(Context& context, int32_t subtree_start)
   context.ReplacePlaceholderNode(subtree_start, NodeKind::InvalidParseStart,
                                  cursor, /*has_error=*/true);
   context.AddNode(NodeKind::InvalidParseSubtree,
-                  context.SkipPastLikelyEnd(cursor), subtree_start,
-                  /*has_error=*/true);
+                  context.SkipPastLikelyEnd(cursor), /*has_error=*/true);
 }
 
 // Prints a diagnostic and calls FinishAndSkipInvalidDecl.
@@ -226,12 +225,11 @@ static auto TryHandleAsModifier(Context& context) -> bool {
       auto extern_token = context.Consume();
       if (context.PositionIs(Lex::TokenKind::Library)) {
         // `extern library <owning_library>` syntax.
-        auto subtree_start = context.tree().size();
         context.ParseLibrarySpecifier(/*accept_default=*/true);
         // TODO: Consider error recovery when a non-declaration token is next,
         // like a typo of the library name.
         context.AddNode(NodeKind::ExternModifierWithLibrary, extern_token,
-                        subtree_start, /*has_error=*/false);
+                        /*has_error=*/false);
       } else {
         // `extern` syntax without a library.
         context.AddLeafNode(NodeKind::ExternModifier, extern_token);

+ 9 - 15
toolchain/parse/handle_expr.cpp

@@ -276,12 +276,12 @@ auto HandleExprLoop(Context& context) -> void {
       // node so that checking can insert control flow here.
       case Lex::TokenKind::And:
         context.AddNode(NodeKind::ShortCircuitOperandAnd, state.token,
-                        state.subtree_start, state.has_error);
+                        state.has_error);
         state.state = State::ExprLoopForShortCircuitOperatorAsAnd;
         break;
       case Lex::TokenKind::Or:
         context.AddNode(NodeKind::ShortCircuitOperandOr, state.token,
-                        state.subtree_start, state.has_error);
+                        state.has_error);
         state.state = State::ExprLoopForShortCircuitOperatorAsOr;
         break;
 
@@ -307,8 +307,7 @@ auto HandleExprLoop(Context& context) -> void {
                        << operator_kind;
     }
 
-    context.AddNode(node_kind, state.token, state.subtree_start,
-                    state.has_error);
+    context.AddNode(node_kind, state.token, state.has_error);
     state.has_error = false;
     context.PushState(state);
   }
@@ -318,7 +317,7 @@ auto HandleExprLoop(Context& context) -> void {
 static auto HandleExprLoopForOperator(Context& context,
                                       Context::StateStackEntry state,
                                       NodeKind node_kind) -> void {
-  context.AddNode(node_kind, state.token, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, state.token, state.has_error);
   state.has_error = false;
   context.PushState(state, State::ExprLoop);
 }
@@ -371,8 +370,7 @@ auto HandleExprLoopForShortCircuitOperatorAsOr(Context& context) -> void {
 auto HandleIfExprFinishCondition(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::IfExprIf, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfExprIf, state.token, state.has_error);
 
   if (context.PositionIs(Lex::TokenKind::Then)) {
     context.PushState(State::IfExprFinishThen);
@@ -397,8 +395,7 @@ auto HandleIfExprFinishCondition(Context& context) -> void {
 auto HandleIfExprFinishThen(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::IfExprThen, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfExprThen, state.token, state.has_error);
 
   if (context.PositionIs(Lex::TokenKind::Else)) {
     context.PushState(State::IfExprFinishElse);
@@ -431,16 +428,14 @@ auto HandleIfExprFinishElse(Context& context) -> void {
 auto HandleIfExprFinish(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::IfExprElse, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfExprElse, state.token, state.has_error);
 }
 
 auto HandleExprStatementFinish(Context& context) -> void {
   auto state = context.PopState();
 
   if (auto semi = context.ConsumeIf(Lex::TokenKind::Semi)) {
-    context.AddNode(NodeKind::ExprStatement, *semi, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::ExprStatement, *semi, state.has_error);
     return;
   }
 
@@ -451,8 +446,7 @@ auto HandleExprStatementFinish(Context& context) -> void {
   }
 
   context.AddNode(NodeKind::ExprStatement,
-                  context.SkipPastLikelyEnd(state.token), state.subtree_start,
-                  /*has_error=*/true);
+                  context.SkipPastLikelyEnd(state.token), /*has_error=*/true);
 }
 
 }  // namespace Carbon::Parse

+ 6 - 9
toolchain/parse/handle_function.cpp

@@ -31,8 +31,7 @@ auto HandleFunctionAfterParams(Context& context) -> void {
 auto HandleFunctionReturnTypeFinish(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::ReturnType, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ReturnType, state.token, state.has_error);
 }
 
 auto HandleFunctionSignatureFinish(Context& context) -> void {
@@ -41,12 +40,11 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
   switch (context.PositionKind()) {
     case Lex::TokenKind::Semi: {
       context.AddNode(NodeKind::FunctionDecl, context.Consume(),
-                      state.subtree_start, state.has_error);
+                      state.has_error);
       break;
     }
     case Lex::TokenKind::OpenCurlyBrace: {
-      context.AddFunctionDefinitionStart(context.Consume(), state.subtree_start,
-                                         state.has_error);
+      context.AddFunctionDefinitionStart(context.Consume(), state.has_error);
       // Any error is recorded on the FunctionDefinitionStart.
       state.has_error = false;
       context.PushState(state, State::FunctionDefinitionFinish);
@@ -55,7 +53,7 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
     }
     case Lex::TokenKind::Equal: {
       context.AddNode(NodeKind::BuiltinFunctionDefinitionStart,
-                      context.Consume(), state.subtree_start, state.has_error);
+                      context.Consume(), state.has_error);
       if (!context.ConsumeAndAddLeafNodeIf(Lex::TokenKind::StringLiteral,
                                            NodeKind::BuiltinName)) {
         CARBON_DIAGNOSTIC(ExpectedBuiltinName, Error,
@@ -73,7 +71,7 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
                                      /*skip_past_likely_end=*/true);
       } else {
         context.AddNode(NodeKind::BuiltinFunctionDefinition, *semi,
-                        state.subtree_start, state.has_error);
+                        state.has_error);
       }
       break;
     }
@@ -94,8 +92,7 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
 
 auto HandleFunctionDefinitionFinish(Context& context) -> void {
   auto state = context.PopState();
-  context.AddFunctionDefinition(context.Consume(), state.subtree_start,
-                                state.has_error);
+  context.AddFunctionDefinition(context.Consume(), state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_impl.cpp

@@ -54,8 +54,7 @@ auto HandleImplAfterForall(Context& context) -> void {
   if (state.has_error) {
     context.ReturnErrorOnState();
   }
-  context.AddNode(NodeKind::ImplForall, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ImplForall, state.token, state.has_error);
   // One of:
   //   as <expression> ...
   //   <expression> as <expression>...
@@ -65,8 +64,7 @@ auto HandleImplAfterForall(Context& context) -> void {
 auto HandleImplBeforeAs(Context& context) -> void {
   auto state = context.PopState();
   if (auto as = context.ConsumeIf(Lex::TokenKind::As)) {
-    context.AddNode(NodeKind::TypeImplAs, *as, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::TypeImplAs, *as, state.has_error);
     context.PushState(State::Expr);
   } else {
     if (!state.has_error) {

+ 2 - 2
toolchain/parse/handle_import_and_package.cpp

@@ -16,7 +16,7 @@ namespace Carbon::Parse {
 static auto OnParseError(Context& context, Context::StateStackEntry state,
                          NodeKind declaration) -> void {
   return context.AddNode(declaration, context.SkipPastLikelyEnd(state.token),
-                         state.subtree_start, /*has_error=*/true);
+                         /*has_error=*/true);
 }
 
 // Determines whether the specified modifier appears within the introducer of
@@ -109,7 +109,7 @@ static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
       context.set_packaging_decl(names, is_impl);
     }
 
-    context.AddNode(declaration, *semi, state.subtree_start, state.has_error);
+    context.AddNode(declaration, *semi, state.has_error);
   } else {
     context.DiagnoseExpectedDeclSemi(context.tokens().GetKind(state.token));
     on_parse_error();

+ 1 - 1
toolchain/parse/handle_index_expr.cpp

@@ -13,7 +13,7 @@ auto HandleIndexExpr(Context& context) -> void {
   context.PushState(state, State::IndexExprFinish);
   context.AddNode(NodeKind::IndexExprStart,
                   context.ConsumeChecked(Lex::TokenKind::OpenSquareBracket),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
   context.PushState(State::Expr);
 }
 

+ 1 - 2
toolchain/parse/handle_let.cpp

@@ -45,8 +45,7 @@ auto HandleLetFinish(Context& context) -> void {
     state.has_error = true;
     end_token = context.SkipPastLikelyEnd(state.token);
   }
-  context.AddNode(NodeKind::LetDecl, end_token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::LetDecl, end_token, state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 18 - 28
toolchain/parse/handle_match.cpp

@@ -20,8 +20,8 @@ static auto HandleStatementsBlockStart(Context& context, State finish,
     }
 
     context.AddLeafNode(equal_greater, *context.position(), true);
-    context.AddNode(starter, *context.position(), state.subtree_start, true);
-    context.AddNode(complete, *context.position(), state.subtree_start, true);
+    context.AddNode(starter, *context.position(), true);
+    context.AddNode(complete, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     return;
   }
@@ -35,14 +35,13 @@ static auto HandleStatementsBlockStart(Context& context, State finish,
       context.emitter().Emit(*context.position(), ExpectedMatchCaseBlock);
     }
 
-    context.AddNode(starter, *context.position(), state.subtree_start, true);
-    context.AddNode(complete, *context.position(), state.subtree_start, true);
+    context.AddNode(starter, *context.position(), true);
+    context.AddNode(complete, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     return;
   }
 
-  context.AddNode(starter, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(starter, context.Consume(), state.has_error);
   context.PushState(state, finish);
   context.PushState(State::StatementScopeLoop);
 }
@@ -77,16 +76,14 @@ auto HandleMatchConditionFinish(Context& context) -> void {
       context.emitter().Emit(*context.position(), ExpectedMatchCasesBlock);
     }
 
-    context.AddNode(NodeKind::MatchStatementStart, *context.position(),
-                    state.subtree_start, true);
-    context.AddNode(NodeKind::MatchStatement, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::MatchStatementStart, *context.position(), true);
+    context.AddNode(NodeKind::MatchStatement, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     return;
   }
 
   context.AddNode(NodeKind::MatchStatementStart, context.Consume(),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 
   state.has_error = false;
   if (context.PositionIs(Lex::TokenKind::CloseCurlyBrace)) {
@@ -145,10 +142,8 @@ auto HandleMatchCaseIntroducer(Context& context) -> void {
 auto HandleMatchCaseAfterPattern(Context& context) -> void {
   auto state = context.PopState();
   if (state.has_error) {
-    context.AddNode(NodeKind::MatchCaseStart, *context.position(),
-                    state.subtree_start, true);
-    context.AddNode(NodeKind::MatchCase, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::MatchCaseStart, *context.position(), true);
+    context.AddNode(NodeKind::MatchCase, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     return;
   }
@@ -166,13 +161,10 @@ auto HandleMatchCaseAfterPattern(Context& context) -> void {
                           true);
       context.AddLeafNode(NodeKind::InvalidParse, *context.position(), true);
       state = context.PopState();
-      context.AddNode(NodeKind::MatchCaseGuard, *context.position(),
-                      state.subtree_start, true);
+      context.AddNode(NodeKind::MatchCaseGuard, *context.position(), true);
       state = context.PopState();
-      context.AddNode(NodeKind::MatchCaseStart, *context.position(),
-                      state.subtree_start, true);
-      context.AddNode(NodeKind::MatchCase, *context.position(),
-                      state.subtree_start, true);
+      context.AddNode(NodeKind::MatchCaseStart, *context.position(), true);
+      context.AddNode(NodeKind::MatchCase, *context.position(), true);
       context.SkipPastLikelyEnd(*context.position());
       return;
     }
@@ -184,11 +176,9 @@ auto HandleMatchCaseGuardFinish(Context& context) -> void {
 
   auto close_paren = context.ConsumeIf(Lex::TokenKind::CloseParen);
   if (close_paren) {
-    context.AddNode(NodeKind::MatchCaseGuard, *close_paren, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::MatchCaseGuard, *close_paren, state.has_error);
   } else {
-    context.AddNode(NodeKind::MatchCaseGuard, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::MatchCaseGuard, *context.position(), true);
     context.ReturnErrorOnState();
     context.SkipPastLikelyEnd(*context.position());
     return;
@@ -205,7 +195,7 @@ auto HandleMatchCaseFinish(Context& context) -> void {
   auto state = context.PopState();
   context.AddNode(NodeKind::MatchCase,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 
 auto HandleMatchDefaultIntroducer(Context& context) -> void {
@@ -220,14 +210,14 @@ auto HandleMatchDefaultFinish(Context& context) -> void {
   auto state = context.PopState();
   context.AddNode(NodeKind::MatchDefault,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 
 auto HandleMatchStatementFinish(Context& context) -> void {
   auto state = context.PopState();
   context.AddNode(NodeKind::MatchStatement,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_paren_expr.cpp

@@ -21,8 +21,7 @@ auto HandleOnlyParenExpr(Context& context) -> void {
 
 static auto FinishParenExpr(Context& context,
                             const Context::StateStackEntry& state) -> void {
-  context.AddNode(NodeKind::ParenExpr, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ParenExpr, context.Consume(), state.has_error);
 }
 
 auto HandleOnlyParenExprFinish(Context& context) -> void {
@@ -108,8 +107,7 @@ auto HandleTupleLiteralFinish(Context& context) -> void {
 
   context.ReplacePlaceholderNode(state.subtree_start,
                                  NodeKind::TupleLiteralStart, state.token);
-  context.AddNode(NodeKind::TupleLiteral, context.Consume(),
-                  state.subtree_start, state.has_error);
+  context.AddNode(NodeKind::TupleLiteral, context.Consume(), state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 1 - 1
toolchain/parse/handle_pattern_list.cpp

@@ -88,7 +88,7 @@ static auto HandlePatternListFinish(Context& context, NodeKind node_kind,
   auto state = context.PopState();
 
   context.AddNode(node_kind, context.ConsumeChecked(token_kind),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 
 auto HandlePatternListFinishAsImplicit(Context& context) -> void {

+ 3 - 4
toolchain/parse/handle_period.cpp

@@ -50,7 +50,7 @@ static auto HandlePeriodOrArrow(Context& context, NodeKind node_kind,
     }
   }
 
-  context.AddNode(node_kind, dot, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, dot, state.has_error);
 }
 
 auto HandlePeriodAsExpr(Context& context) -> void {
@@ -72,14 +72,13 @@ auto HandleArrowExpr(Context& context) -> void {
 
 auto HandleCompoundMemberAccess(Context& context) -> void {
   auto state = context.PopState();
-  context.AddNode(NodeKind::MemberAccessExpr, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::MemberAccessExpr, state.token, state.has_error);
 }
 
 auto HandleCompoundPointerMemberAccess(Context& context) -> void {
   auto state = context.PopState();
   context.AddNode(NodeKind::PointerMemberAccessExpr, state.token,
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 5 - 9
toolchain/parse/handle_statement.cpp

@@ -85,7 +85,7 @@ static auto HandleStatementKeywordFinish(Context& context, NodeKind node_kind)
     // Recover to the next semicolon if possible.
     semi = context.SkipPastLikelyEnd(state.token);
   }
-  context.AddNode(node_kind, *semi, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, *semi, state.has_error);
 }
 
 auto HandleStatementBreakFinish(Context& context) -> void {
@@ -141,8 +141,7 @@ auto HandleStatementForHeaderFinish(Context& context) -> void {
 auto HandleStatementForFinish(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::ForStatement, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ForStatement, state.token, state.has_error);
 }
 
 auto HandleStatementIf(Context& context) -> void {
@@ -170,15 +169,13 @@ auto HandleStatementIfThenBlockFinish(Context& context) -> void {
                           ? State::StatementIf
                           : State::CodeBlock);
   } else {
-    context.AddNode(NodeKind::IfStatement, state.token, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::IfStatement, state.token, state.has_error);
   }
 }
 
 auto HandleStatementIfElseBlockFinish(Context& context) -> void {
   auto state = context.PopState();
-  context.AddNode(NodeKind::IfStatement, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfStatement, state.token, state.has_error);
 }
 
 auto HandleStatementReturn(Context& context) -> void {
@@ -234,8 +231,7 @@ auto HandleStatementWhileConditionFinish(Context& context) -> void {
 auto HandleStatementWhileBlockFinish(Context& context) -> void {
   auto state = context.PopState();
 
-  context.AddNode(NodeKind::WhileStatement, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::WhileStatement, state.token, state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_var.cpp

@@ -85,8 +85,7 @@ auto HandleVarFinishAsDecl(Context& context) -> void {
     state.has_error = true;
     end_token = context.SkipPastLikelyEnd(state.token);
   }
-  context.AddNode(NodeKind::VariableDecl, end_token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::VariableDecl, end_token, state.has_error);
 }
 
 auto HandleVarFinishAsFor(Context& context) -> void {
@@ -108,8 +107,7 @@ auto HandleVarFinishAsFor(Context& context) -> void {
     state.has_error = true;
   }
 
-  context.AddNode(NodeKind::ForIn, end_token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ForIn, end_token, state.has_error);
 }
 
 }  // namespace Carbon::Parse

+ 22 - 245
toolchain/parse/tree.cpp

@@ -10,6 +10,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/node_kind.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/parse/typed_nodes.h"
 
 namespace Carbon::Parse {
@@ -20,28 +21,6 @@ auto Tree::postorder() const -> llvm::iterator_range<PostorderIterator> {
       PostorderIterator(NodeId(node_impls_.size())));
 }
 
-auto Tree::postorder(NodeId n) const
-    -> llvm::iterator_range<PostorderIterator> {
-  // The postorder ends after this node, the root, and begins at the start of
-  // its subtree.
-  int start_index = n.index - node_impls_[n.index].subtree_size + 1;
-  return PostorderIterator::MakeRange(NodeId(start_index), n);
-}
-
-auto Tree::children(NodeId n) const -> llvm::iterator_range<SiblingIterator> {
-  CARBON_CHECK(n.is_valid());
-  int end_index = n.index - node_impls_[n.index].subtree_size;
-  return llvm::iterator_range<SiblingIterator>(
-      SiblingIterator(*this, NodeId(n.index - 1)),
-      SiblingIterator(*this, NodeId(end_index)));
-}
-
-auto Tree::roots() const -> llvm::iterator_range<SiblingIterator> {
-  return llvm::iterator_range<SiblingIterator>(
-      SiblingIterator(*this, NodeId(static_cast<int>(node_impls_.size()) - 1)),
-      SiblingIterator(*this, NodeId(-1)));
-}
-
 auto Tree::node_has_error(NodeId n) const -> bool {
   CARBON_CHECK(n.is_valid());
   return node_impls_[n.index].has_error;
@@ -57,245 +36,47 @@ auto Tree::node_token(NodeId n) const -> Lex::TokenIndex {
   return node_impls_[n.index].token;
 }
 
-auto Tree::node_subtree_size(NodeId n) const -> int32_t {
-  CARBON_CHECK(n.is_valid());
-  return node_impls_[n.index].subtree_size;
-}
-
-auto Tree::PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
-                     bool preorder) const -> bool {
-  const auto& n_impl = node_impls_[n.index];
-  output.indent(2 * (depth + 2));
-  output << "{";
-  // If children are being added, include node_index in order to disambiguate
-  // nodes.
-  if (preorder) {
-    output << "node_index: " << n << ", ";
-  }
-  output << "kind: '" << n_impl.kind << "', text: '"
-         << tokens_->GetTokenText(n_impl.token) << "'";
-
-  if (n_impl.has_error) {
-    output << ", has_error: yes";
-  }
-
-  if (n_impl.subtree_size > 1) {
-    output << ", subtree_size: " << n_impl.subtree_size;
-    if (preorder) {
-      output << ", children: [\n";
-      return true;
-    }
-  }
-  output << "}";
-  return false;
-}
-
 auto Tree::Print(llvm::raw_ostream& output) const -> void {
-  output << "- filename: " << tokens_->source().filename() << "\n"
-         << "  parse_tree: [\n";
-
-  // Walk the tree just to calculate depths for each node.
-  llvm::SmallVector<int> indents;
-  indents.append(size(), 0);
-
-  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
-  for (NodeId n : roots()) {
-    node_stack.push_back({n, 0});
-  }
-
-  while (!node_stack.empty()) {
-    NodeId n = NodeId::Invalid;
-    int depth;
-    std::tie(n, depth) = node_stack.pop_back_val();
-    for (NodeId sibling_n : children(n)) {
-      indents[sibling_n.index] = depth + 1;
-      node_stack.push_back({sibling_n, depth + 1});
-    }
-  }
-
-  for (NodeId n : postorder()) {
-    PrintNode(output, n, indents[n.index], /*preorder=*/false);
-    output << ",\n";
-  }
-  output << "  ]\n";
-}
-
-auto Tree::Print(llvm::raw_ostream& output, bool preorder) const -> void {
-  if (!preorder) {
-    Print(output);
-    return;
-  }
-
-  output << "- filename: " << tokens_->source().filename() << "\n"
-         << "  parse_tree: [\n";
-
-  // The parse tree is stored in postorder. The preorder can be constructed
-  // by reversing the order of each level of siblings within an RPO. The
-  // sibling iterators are directly built around RPO and so can be used with a
-  // stack to produce preorder.
-
-  // The roots, like siblings, are in RPO (so reversed), but we add them in
-  // order here because we'll pop off the stack effectively reversing then.
-  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
-  for (NodeId n : roots()) {
-    node_stack.push_back({n, 0});
-  }
-
-  while (!node_stack.empty()) {
-    NodeId n = NodeId::Invalid;
-    int depth;
-    std::tie(n, depth) = node_stack.pop_back_val();
-
-    if (PrintNode(output, n, depth, /*preorder=*/true)) {
-      // Has children, so we descend. We append the children in order here as
-      // well because they will get reversed when popped off the stack.
-      for (NodeId sibling_n : children(n)) {
-        node_stack.push_back({sibling_n, depth + 1});
-      }
-      continue;
-    }
-
-    int next_depth = node_stack.empty() ? 0 : node_stack.back().second;
-    CARBON_CHECK(next_depth <= depth) << "Cannot have the next depth increase!";
-    for (int close_children_count : llvm::seq(0, depth - next_depth)) {
-      (void)close_children_count;
-      output << "]}";
-    }
-
-    // We always end with a comma and a new line as we'll move to the next
-    // node at whatever the current level ends up being.
-    output << "  ,\n";
-  }
-  output << "  ]\n";
-}
-
-auto Tree::CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
-    -> void {
-  mem_usage.Add(MemUsage::ConcatLabel(label, "node_impls_"), node_impls_);
-  mem_usage.Add(MemUsage::ConcatLabel(label, "imports_"), imports_);
-}
-
-auto Tree::VerifyExtract(NodeId node_id, NodeKind kind,
-                         ErrorBuilder* trace) const -> bool {
-  switch (kind) {
-#define CARBON_PARSE_NODE_KIND(Name) \
-  case NodeKind::Name:               \
-    return VerifyExtractAs<Name>(node_id, trace).has_value();
-#include "toolchain/parse/node_kind.def"
-  }
+  TreeAndSubtrees(*tokens_, *this).Print(output);
 }
 
 auto Tree::Verify() const -> ErrorOr<Success> {
   llvm::SmallVector<NodeId> nodes;
   // Traverse the tree in postorder.
   for (NodeId n : postorder()) {
-    const auto& n_impl = node_impls_[n.index];
-
-    if (n_impl.has_error && !has_errors_) {
+    if (node_has_error(n) && !has_errors()) {
       return Error(llvm::formatv(
-          "NodeId #{0} has errors, but the tree is not marked as having any.",
-          n.index));
+          "Node {0} has errors, but the tree is not marked as having any.", n));
     }
 
-    if (n_impl.kind == NodeKind::Placeholder) {
+    if (node_kind(n) == NodeKind::Placeholder) {
       return Error(llvm::formatv(
-          "Node #{0} is a placeholder node that wasn't replaced.", n.index));
-    }
-    // Should extract successfully if node not marked as having an error.
-    // Without this code, a 10 mloc test case of lex & parse takes
-    // 4.129 s ± 0.041 s. With this additional verification, it takes
-    // 5.768 s ± 0.036 s.
-    if (!n_impl.has_error && !VerifyExtract(n, n_impl.kind, nullptr)) {
-      ErrorBuilder trace;
-      trace << llvm::formatv(
-          "NodeId #{0} couldn't be extracted as a {1}. Trace:\n", n,
-          n_impl.kind);
-      VerifyExtract(n, n_impl.kind, &trace);
-      return trace;
+          "Node {0} is a placeholder node that wasn't replaced.", n));
     }
-
-    int subtree_size = 1;
-    if (n_impl.kind.has_bracket()) {
-      int child_count = 0;
-      while (true) {
-        if (nodes.empty()) {
-          return Error(
-              llvm::formatv("NodeId #{0} is a {1} with bracket {2}, but didn't "
-                            "find the bracket.",
-                            n, n_impl.kind, n_impl.kind.bracket()));
-        }
-        auto child_impl = node_impls_[nodes.pop_back_val().index];
-        subtree_size += child_impl.subtree_size;
-        ++child_count;
-        if (n_impl.kind.bracket() == child_impl.kind) {
-          // If there's a bracketing node and a child count, verify the child
-          // count too.
-          if (n_impl.kind.has_child_count() &&
-              child_count != n_impl.kind.child_count()) {
-            return Error(llvm::formatv(
-                "NodeId #{0} is a {1} with child_count {2}, but encountered "
-                "{3} nodes before we reached the bracketing node.",
-                n, n_impl.kind, n_impl.kind.child_count(), child_count));
-          }
-          break;
-        }
-      }
-    } else {
-      for (int i : llvm::seq(n_impl.kind.child_count())) {
-        if (nodes.empty()) {
-          return Error(llvm::formatv(
-              "NodeId #{0} is a {1} with child_count {2}, but only had {3} "
-              "nodes to consume.",
-              n, n_impl.kind, n_impl.kind.child_count(), i));
-        }
-        auto child_impl = node_impls_[nodes.pop_back_val().index];
-        subtree_size += child_impl.subtree_size;
-      }
-    }
-    if (n_impl.subtree_size != subtree_size) {
-      return Error(llvm::formatv(
-          "NodeId #{0} is a {1} with subtree_size of {2}, but calculated {3}.",
-          n, n_impl.kind, n_impl.subtree_size, subtree_size));
-    }
-    nodes.push_back(n);
   }
 
-  // Remaining nodes should all be roots in the tree; make sure they line up.
-  CARBON_CHECK(nodes.back().index ==
-               static_cast<int32_t>(node_impls_.size()) - 1)
-      << nodes.back() << " " << node_impls_.size() - 1;
-  int prev_index = -1;
-  for (const auto& n : nodes) {
-    const auto& n_impl = node_impls_[n.index];
-
-    if (n.index - n_impl.subtree_size != prev_index) {
-      return Error(
-          llvm::formatv("NodeId #{0} is a root {1} with subtree_size {2}, but "
-                        "previous root was at #{3}.",
-                        n, n_impl.kind, n_impl.subtree_size, prev_index));
-    }
-    prev_index = n.index;
+  if (!has_errors() &&
+      static_cast<int32_t>(size()) != tokens_->expected_parse_tree_size()) {
+    return Error(llvm::formatv(
+        "Tree has {0} nodes and no errors, but "
+        "Lex::TokenizedBuffer expected {1} nodes for {2} tokens.",
+        size(), tokens_->expected_parse_tree_size(), tokens_->size()));
   }
 
-  // Validate the roots, ensures Tree::ExtractFile() doesn't CHECK-fail.
-  if (!TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), nullptr)) {
-    ErrorBuilder trace;
-    trace << "Roots of tree couldn't be extracted as a `File`. Trace:\n";
-    TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), &trace);
-    return trace;
-  }
+#ifndef NDEBUG
+  TreeAndSubtrees subtrees(*tokens_, *this);
+  CARBON_RETURN_IF_ERROR(subtrees.Verify());
+#endif  // NDEBUG
 
-  if (!has_errors_ && static_cast<int32_t>(node_impls_.size()) !=
-                          tokens_->expected_parse_tree_size()) {
-    return Error(
-        llvm::formatv("Tree has {0} nodes and no errors, but "
-                      "Lex::TokenizedBuffer expected {1} nodes for {2} tokens.",
-                      node_impls_.size(), tokens_->expected_parse_tree_size(),
-                      tokens_->size()));
-  }
   return Success();
 }
 
+auto Tree::CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
+    -> void {
+  mem_usage.Add(MemUsage::ConcatLabel(label, "node_impls_"), node_impls_);
+  mem_usage.Add(MemUsage::ConcatLabel(label, "imports_"), imports_);
+}
+
 auto Tree::PostorderIterator::MakeRange(NodeId begin, NodeId end)
     -> llvm::iterator_range<PostorderIterator> {
   CARBON_CHECK(begin.is_valid() && end.is_valid());
@@ -307,8 +88,4 @@ auto Tree::PostorderIterator::Print(llvm::raw_ostream& output) const -> void {
   output << node_;
 }
 
-auto Tree::SiblingIterator::Print(llvm::raw_ostream& output) const -> void {
-  output << node_;
-}
-
 }  // namespace Carbon::Parse

+ 7 - 246
toolchain/parse/tree.h

@@ -78,7 +78,6 @@ struct File;
 class Tree : public Printable<Tree> {
  public:
   class PostorderIterator;
-  class SiblingIterator;
 
   // Names in packaging, whether the file's packaging or an import. Links back
   // to the node for diagnostics.
@@ -114,20 +113,6 @@ class Tree : public Printable<Tree> {
   // postorder.
   auto postorder() const -> llvm::iterator_range<PostorderIterator>;
 
-  // Returns an iterable range over the parse tree node and all of its
-  // descendants in depth-first postorder.
-  auto postorder(NodeId n) const -> llvm::iterator_range<PostorderIterator>;
-
-  // Returns an iterable range over the direct children of a node in the parse
-  // tree. This is a forward range, but is constant time to increment. The order
-  // of children is the same as would be found in a reverse postorder traversal.
-  auto children(NodeId n) const -> llvm::iterator_range<SiblingIterator>;
-
-  // Returns an iterable range over the roots of the parse tree. This is a
-  // forward range, but is constant time to increment. The order of roots is the
-  // same as would be found in a reverse postorder traversal.
-  auto roots() const -> llvm::iterator_range<SiblingIterator>;
-
   // Tests whether a particular node contains an error and may not match the
   // full expected structure of the grammar.
   auto node_has_error(NodeId n) const -> bool;
@@ -183,96 +168,19 @@ class Tree : public Printable<Tree> {
     return deferred_definitions_;
   }
 
-  // See the other Print comments.
+  // Builds TreeAndSubtrees to print the tree.
   auto Print(llvm::raw_ostream& output) const -> void;
 
-  // Prints a description of the parse tree to the provided `raw_ostream`.
-  //
-  // The tree may be printed in either preorder or postorder. Output represents
-  // each node as a YAML record; in preorder, children are nested.
-  //
-  // In both, a node is formatted as:
-  //   ```
-  //   {kind: 'foo', text: '...'}
-  //   ```
-  //
-  // The top level is formatted as an array of these nodes.
-  //   ```
-  //   [
-  //   {kind: 'foo', text: '...'},
-  //   {kind: 'foo', text: '...'},
-  //   ...
-  //   ]
-  //   ```
-  //
-  // In postorder, nodes are indented in order to indicate depth. For example, a
-  // node with two children, one of them with an error:
-  //   ```
-  //     {kind: 'bar', text: '...', has_error: yes},
-  //     {kind: 'baz', text: '...'}
-  //   {kind: 'foo', text: '...', subtree_size: 2}
-  //   ```
-  //
-  // In preorder, nodes are marked as children with postorder (storage) index.
-  // For example, a node with two children, one of them with an error:
-  //   ```
-  //   {node_index: 2, kind: 'foo', text: '...', subtree_size: 2, children: [
-  //     {node_index: 0, kind: 'bar', text: '...', has_error: yes},
-  //     {node_index: 1, kind: 'baz', text: '...'}]}
-  //   ```
-  //
-  // This can be parsed as YAML using tools like `python-yq` combined with `jq`
-  // on the command line. The format is also reasonably amenable to other
-  // line-oriented shell tools from `grep` to `awk`.
-  auto Print(llvm::raw_ostream& output, bool preorder) const -> void;
-
   // Collects memory usage of members.
   auto CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
       -> void;
 
-  // The following `Extract*` function provide an alternative way of accessing
-  // the nodes of a tree. It is intended to be more convenient and type-safe,
-  // but slower and can't be used on nodes that are marked as having an error.
-  // It is appropriate for uses that are less performance sensitive, like
-  // diagnostics. Example usage:
-  // ```
-  // auto file = tree->ExtractFile();
-  // for (AnyDeclId decl_id : file.decls) {
-  //   // `decl_id` is convertible to a `NodeId`.
-  //   if (std::optional<FunctionDecl> fn_decl =
-  //       tree->ExtractAs<FunctionDecl>(decl_id)) {
-  //     // fn_decl->params is a `TuplePatternId` (which extends `NodeId`)
-  //     // that is guaranteed to reference a `TuplePattern`.
-  //     std::optional<TuplePattern> params = tree->Extract(fn_decl->params);
-  //     // `params` has a value unless there was an error in that node.
-  //   } else if (auto class_def = tree->ExtractAs<ClassDefinition>(decl_id)) {
-  //     // ...
-  //   }
-  // }
-  // ```
-
-  // Extract a `File` object representing the parse tree for the whole file.
-  // #include "toolchain/parse/typed_nodes.h" to get the definition of `File`
-  // and the types representing its children nodes. This is implemented in
-  // extract.cpp.
-  auto ExtractFile() const -> File;
-
-  // Converts this node_id to a typed node of a specified type, if it is a valid
-  // node of that kind.
-  template <typename T>
-  auto ExtractAs(NodeId node_id) const -> std::optional<T>;
-
-  // Converts to a typed node, if it is not an error.
-  template <typename IdT>
-  auto Extract(IdT id) const
-      -> std::optional<typename NodeForId<IdT>::TypedNode>;
-
   // Verifies the parse tree structure. Checks invariants of the parse tree
   // structure and returns verification errors.
   //
-  // This is fairly slow, and is primarily intended to be used as a debugging
-  // aid. This routine doesn't directly CHECK so that it can be used within a
-  // debugger.
+  // In opt builds, this does some minimal checking. In debug builds, it'll
+  // build a TreeAndSubtrees and run further verification. This doesn't directly
+  // CHECK so that it can be used within a debugger.
   auto Verify() const -> ErrorOr<Success>;
 
  private:
@@ -285,12 +193,8 @@ class Tree : public Printable<Tree> {
   // The in-memory representation of data used for a particular node in the
   // tree.
   struct NodeImpl {
-    explicit NodeImpl(NodeKind kind, bool has_error, Lex::TokenIndex token,
-                      int subtree_size)
-        : kind(kind),
-          has_error(has_error),
-          token(token),
-          subtree_size(subtree_size) {}
+    explicit NodeImpl(NodeKind kind, bool has_error, Lex::TokenIndex token)
+        : kind(kind), has_error(has_error), token(token) {}
 
     // The kind of this node. Note that this is only a single byte.
     NodeKind kind;
@@ -315,38 +219,11 @@ class Tree : public Printable<Tree> {
 
     // The token root of this node.
     Lex::TokenIndex token;
-
-    // The size of this node's subtree of the parse tree. This is the number of
-    // nodes (and thus tokens) that are covered by this node (and its
-    // descendents) in the parse tree.
-    //
-    // During a *reverse* postorder (RPO) traversal of the parse tree, this can
-    // also be thought of as the offset to the next non-descendant node. When
-    // this node is not the first child of its parent (which is the last child
-    // visited in RPO), that is the offset to the next sibling. When this node
-    // *is* the first child of its parent, this will be an offset to the node's
-    // parent's next sibling, or if it the parent is also a first child, the
-    // grandparent's next sibling, and so on.
-    //
-    // This field should always be a positive integer as at least this node is
-    // part of its subtree.
-    int32_t subtree_size;
   };
 
-  static_assert(sizeof(NodeImpl) == 12,
+  static_assert(sizeof(NodeImpl) == 8,
                 "Unexpected size of node implementation!");
 
-  // Like ExtractAs(), but malformed tree errors are not fatal. Should only be
-  // used by `Verify()` or by tests.
-  template <typename T>
-  auto VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
-      -> std::optional<T>;
-
-  // Wrapper around `VerifyExtractAs` to dispatch based on a runtime node kind.
-  // Returns true if extraction was successful.
-  auto VerifyExtract(NodeId node_id, NodeKind kind, ErrorBuilder* trace) const
-      -> bool;
-
   // Sets the kind of a node. This is intended to allow putting the tree into a
   // state where verification can fail, in order to make the failure path of
   // `Verify` testable.
@@ -354,26 +231,6 @@ class Tree : public Printable<Tree> {
     node_impls_[node_id.index].kind = kind;
   }
 
-  // Prints a single node for Print(). Returns true when preorder and there are
-  // children.
-  auto PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
-                 bool preorder) const -> bool;
-
-  // Extract a node of type `T` from a sibling range. This is expected to
-  // consume the complete sibling range. Malformed tree errors are written
-  // to `*trace`, if `trace != nullptr`. This is implemented in extract.cpp.
-  template <typename T>
-  auto TryExtractNodeFromChildren(
-      NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children,
-      ErrorBuilder* trace) const -> std::optional<T>;
-
-  // Extract a node of type `T` from a sibling range. This is expected to
-  // consume the complete sibling range. Malformed tree errors are fatal.
-  template <typename T>
-  auto ExtractNodeFromChildren(
-      NodeId node_id,
-      llvm::iterator_range<Tree::SiblingIterator> children) const -> T;
-
   // Depth-first postorder sequence of node implementation data.
   llvm::SmallVector<NodeImpl> node_impls_;
 
@@ -449,102 +306,6 @@ class Tree::PostorderIterator
   NodeId node_;
 };
 
-// A forward iterator across the siblings at a particular level in the parse
-// tree. It produces `Tree::NodeId` objects which are opaque handles and must
-// be used in conjunction with the `Tree` itself.
-//
-// While this is a forward iterator and may not have good locality within the
-// `Tree` data structure, it is still constant time to increment and
-// suitable for algorithms relying on that property.
-//
-// The siblings are discovered through a reverse postorder (RPO) tree traversal
-// (which is made constant time through cached distance information), and so the
-// relative order of siblings matches their RPO order.
-class Tree::SiblingIterator
-    : public llvm::iterator_facade_base<SiblingIterator,
-                                        std::forward_iterator_tag, NodeId, int,
-                                        const NodeId*, NodeId>,
-      public Printable<Tree::SiblingIterator> {
- public:
-  explicit SiblingIterator() = delete;
-
-  auto operator==(const SiblingIterator& rhs) const -> bool {
-    return node_ == rhs.node_;
-  }
-
-  auto operator*() const -> NodeId { return node_; }
-
-  using iterator_facade_base::operator++;
-  auto operator++() -> SiblingIterator& {
-    node_.index -= std::abs(tree_->node_impls_[node_.index].subtree_size);
-    return *this;
-  }
-
-  // Prints the underlying node index.
-  auto Print(llvm::raw_ostream& output) const -> void;
-
- private:
-  friend class Tree;
-
-  explicit SiblingIterator(const Tree& tree_arg, NodeId n)
-      : tree_(&tree_arg), node_(n) {}
-
-  const Tree* tree_;
-
-  NodeId node_;
-};
-
-template <typename T>
-auto Tree::ExtractNodeFromChildren(
-    NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children) const
-    -> T {
-  auto result = TryExtractNodeFromChildren<T>(node_id, children, nullptr);
-  if (!result.has_value()) {
-    // On error try again, this time capturing a trace.
-    ErrorBuilder trace;
-    TryExtractNodeFromChildren<T>(node_id, children, &trace);
-    CARBON_FATAL() << "Malformed parse node:\n"
-                   << static_cast<Error>(trace).message();
-  }
-  return *result;
-}
-
-template <typename T>
-auto Tree::ExtractAs(NodeId node_id) const -> std::optional<T> {
-  static_assert(HasKindMember<T>, "Not a parse node type");
-  if (!IsValid<T>(node_id)) {
-    return std::nullopt;
-  }
-
-  return ExtractNodeFromChildren<T>(node_id, children(node_id));
-}
-
-template <typename T>
-auto Tree::VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
-    -> std::optional<T> {
-  static_assert(HasKindMember<T>, "Not a parse node type");
-  if (!IsValid<T>(node_id)) {
-    if (trace) {
-      *trace << "VerifyExtractAs error: wrong kind " << node_kind(node_id)
-             << ", expected " << T::Kind << "\n";
-    }
-    return std::nullopt;
-  }
-
-  return TryExtractNodeFromChildren<T>(node_id, children(node_id), trace);
-}
-
-template <typename IdT>
-auto Tree::Extract(IdT id) const
-    -> std::optional<typename NodeForId<IdT>::TypedNode> {
-  if (!IsValid(id)) {
-    return std::nullopt;
-  }
-
-  using T = typename NodeForId<IdT>::TypedNode;
-  return ExtractNodeFromChildren<T>(id, children(id));
-}
-
 template <const NodeKind& K>
 struct Tree::ConvertTo<NodeIdForKind<K>> {
   static auto AllowedFor(NodeKind kind) -> bool { return kind == K; }

+ 244 - 0
toolchain/parse/tree_and_subtrees.cpp

@@ -0,0 +1,244 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "toolchain/parse/tree_and_subtrees.h"
+
+namespace Carbon::Parse {
+
+// Computes the subtree size of every node in `tree` with a single postorder
+// walk, storing the result in `subtree_sizes_` (indexed by node index).
+//
+// A node's children are found using its kind: either a fixed child count, or
+// everything back to the matching bracket node. CHECK-fails if the nodes
+// don't form a consistent tree under those rules.
+TreeAndSubtrees::TreeAndSubtrees(const Lex::TokenizedBuffer& tokens,
+                                 const Tree& tree)
+    : tokens_(&tokens), tree_(&tree) {
+  subtree_sizes_.reserve(tree_->size());
+
+  // A stack of nodes which haven't yet been used as children.
+  llvm::SmallVector<NodeId> size_stack;
+  for (auto n : tree.postorder()) {
+    // Nodes always include themselves.
+    int32_t size = 1;
+    auto kind = tree.node_kind(n);
+    if (kind.has_child_count()) {
+      // When the child count is set, remove the specific number from the stack.
+      CARBON_CHECK(static_cast<int32_t>(size_stack.size()) >=
+                   kind.child_count())
+          << "Need " << kind.child_count() << " children for " << kind
+          << ", have " << size_stack.size() << " available";
+      for (auto i : llvm::seq(kind.child_count())) {
+        auto child = size_stack.pop_back_val();
+        CARBON_CHECK(static_cast<size_t>(child.index) < subtree_sizes_.size());
+        size += subtree_sizes_[child.index];
+        // Children pop in reverse order, so the last one popped is the first
+        // child, which is where a bracket must be.
+        if (kind.has_bracket() && i == kind.child_count() - 1) {
+          CARBON_CHECK(kind.bracket() == tree.node_kind(child))
+              << "Node " << kind << " needs bracket " << kind.bracket()
+              << ", found wrong bracket " << tree.node_kind(child);
+        }
+      }
+    } else {
+      // Without a child count, children are everything back to (and including)
+      // the matching bracket node.
+      while (true) {
+        CARBON_CHECK(!size_stack.empty())
+            << "Node " << kind << " is missing bracket " << kind.bracket();
+        auto child = size_stack.pop_back_val();
+        CARBON_CHECK(static_cast<size_t>(child.index) < subtree_sizes_.size());
+        size += subtree_sizes_[child.index];
+        if (kind.bracket() == tree.node_kind(child)) {
+          break;
+        }
+      }
+    }
+    size_stack.push_back(n);
+    subtree_sizes_.push_back(size);
+  }
+
+  CARBON_CHECK(static_cast<int>(subtree_sizes_.size()) == tree_->size());
+
+  // Remaining nodes should all be roots in the tree; make sure they line up.
+  // An empty tree leaves no roots, so only inspect the last root when one
+  // exists; calling `back()` on an empty stack would be undefined behavior.
+  if (!size_stack.empty()) {
+    CARBON_CHECK(size_stack.back().index ==
+                 static_cast<int32_t>(tree_->size()) - 1)
+        << size_stack.back() << " " << tree_->size() - 1;
+  }
+  int prev_index = -1;
+  for (const auto& n : size_stack) {
+    CARBON_CHECK(n.index - subtree_sizes_[n.index] == prev_index)
+        << "NodeId " << n << " is a root " << tree_->node_kind(n)
+        << " with subtree_size " << subtree_sizes_[n.index]
+        << ", but previous root was at " << prev_index << ".";
+    prev_index = n.index;
+  }
+}
+
+// Dispatches on the runtime `kind` to the matching `VerifyExtractAs<Name>`
+// instantiation; one case is generated per entry in node_kind.def. Returns
+// true when the extraction produced a value. `trace` is forwarded unchanged to
+// `VerifyExtractAs`.
+auto TreeAndSubtrees::VerifyExtract(NodeId node_id, NodeKind kind,
+                                    ErrorBuilder* trace) const -> bool {
+  switch (kind) {
+#define CARBON_PARSE_NODE_KIND(Name) \
+  case NodeKind::Name:               \
+    return VerifyExtractAs<Name>(node_id, trace).has_value();
+#include "toolchain/parse/node_kind.def"
+  }
+}
+
+auto TreeAndSubtrees::Verify() const -> ErrorOr<Success> {
+  // Every node that isn't marked as having an error must extract cleanly.
+  //
+  // Timing note: a 10 mloc lex & parse test case takes 4.129 s ± 0.041 s
+  // without this verification, and 5.768 s ± 0.036 s with it.
+  for (NodeId node_id : tree_->postorder()) {
+    if (!tree_->node_has_error(node_id)) {
+      auto kind = tree_->node_kind(node_id);
+      if (!VerifyExtract(node_id, kind, nullptr)) {
+        // Extraction failed; repeat it with a trace to explain the failure.
+        ErrorBuilder failure;
+        failure << llvm::formatv(
+            "NodeId #{0} couldn't be extracted as a {1}. Trace:\n", node_id,
+            kind);
+        VerifyExtract(node_id, kind, &failure);
+        return failure;
+      }
+    }
+  }
+
+  // The roots must form a `File`; this also ensures Tree::ExtractFile()
+  // doesn't error.
+  if (TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), nullptr)) {
+    return Success();
+  }
+
+  ErrorBuilder failure;
+  failure << "Roots of tree couldn't be extracted as a `File`. Trace:\n";
+  TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), &failure);
+  return failure;
+}
+
+auto TreeAndSubtrees::postorder(NodeId n) const
+    -> llvm::iterator_range<Tree::PostorderIterator> {
+  // `n` is the root of the subtree and therefore its last node in postorder;
+  // the first node of the subtree sits `subtree size - 1` entries earlier.
+  NodeId first(n.index - subtree_sizes_[n.index] + 1);
+  return Tree::PostorderIterator::MakeRange(first, n);
+}
+
+auto TreeAndSubtrees::children(NodeId n) const
+    -> llvm::iterator_range<SiblingIterator> {
+  CARBON_CHECK(n.is_valid());
+  // Iteration starts at the node stored just before `n` and stops at the node
+  // stored just before the start of `n`'s subtree.
+  SiblingIterator first_child(*this, NodeId(n.index - 1));
+  SiblingIterator past_last_child(*this,
+                                  NodeId(n.index - subtree_sizes_[n.index]));
+  return llvm::iterator_range<SiblingIterator>(first_child, past_last_child);
+}
+
+auto TreeAndSubtrees::roots() const -> llvm::iterator_range<SiblingIterator> {
+  // Roots are walked from the last stored node back to the index -1 sentinel.
+  const int last_index = static_cast<int>(subtree_sizes_.size()) - 1;
+  SiblingIterator first_root(*this, NodeId(last_index));
+  SiblingIterator past_last_root(*this, NodeId(-1));
+  return llvm::iterator_range<SiblingIterator>(first_root, past_last_root);
+}
+
+// Writes the YAML record for node `n`, indented according to `depth`. Returns
+// true only when printing in preorder and the node has children; in that case
+// the record is left open after `children: [` and no closing brace is written.
+auto TreeAndSubtrees::PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
+                                bool preorder) const -> bool {
+  output.indent(2 * (depth + 2)) << "{";
+
+  // Preorder output nests children, so include node_index in order to
+  // disambiguate nodes.
+  if (preorder) {
+    output << "node_index: " << n << ", ";
+  }
+  output << "kind: '" << tree_->node_kind(n) << "', text: '"
+         << tokens_->GetTokenText(tree_->node_token(n)) << "'";
+
+  if (tree_->node_has_error(n)) {
+    output << ", has_error: yes";
+  }
+
+  // A subtree size of one means the node has no children.
+  const int32_t subtree_size = subtree_sizes_[n.index];
+  if (subtree_size <= 1) {
+    output << "}";
+    return false;
+  }
+
+  output << ", subtree_size: " << subtree_size;
+  if (!preorder) {
+    output << "}";
+    return false;
+  }
+
+  output << ", children: [\n";
+  return true;
+}
+
+auto TreeAndSubtrees::Print(llvm::raw_ostream& output) const -> void {
+  output << "- filename: " << tokens_->source().filename() << "\n"
+         << "  parse_tree: [\n";
+
+  // Depth of each node, indexed by node index. Roots stay at depth 0.
+  llvm::SmallVector<int> depths;
+  depths.resize(subtree_sizes_.size(), 0);
+
+  // Worklist of (node, depth) pairs whose children still need a depth.
+  llvm::SmallVector<std::pair<NodeId, int>, 16> worklist;
+  for (NodeId root : roots()) {
+    worklist.push_back({root, 0});
+  }
+
+  while (!worklist.empty()) {
+    auto [parent, parent_depth] = worklist.pop_back_val();
+    const int child_depth = parent_depth + 1;
+    for (NodeId child : children(parent)) {
+      depths[child.index] = child_depth;
+      worklist.push_back({child, child_depth});
+    }
+  }
+
+  // Emit nodes in storage (postorder) order, indented by the computed depth.
+  for (NodeId node_id : tree_->postorder()) {
+    PrintNode(output, node_id, depths[node_id.index], /*preorder=*/false);
+    output << ",\n";
+  }
+  output << "  ]\n";
+}
+
+// Prints the tree in preorder, nesting each node's children inside a
+// `children: [...]` list on the parent's record.
+auto TreeAndSubtrees::PrintPreorder(llvm::raw_ostream& output) const -> void {
+  output << "- filename: " << tokens_->source().filename() << "\n"
+         << "  parse_tree: [\n";
+
+  // The parse tree is stored in postorder. The preorder can be constructed
+  // by reversing the order of each level of siblings within an RPO. The
+  // sibling iterators are directly built around RPO and so can be used with a
+  // stack to produce preorder.
+
+  // The roots, like siblings, are in RPO (so reversed), but we add them in
+  // order here because we'll pop off the stack effectively reversing them.
+  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
+  for (NodeId n : roots()) {
+    node_stack.push_back({n, 0});
+  }
+
+  while (!node_stack.empty()) {
+    NodeId n = NodeId::Invalid;
+    int depth;
+    std::tie(n, depth) = node_stack.pop_back_val();
+
+    if (PrintNode(output, n, depth, /*preorder=*/true)) {
+      // Has children, so we descend. We append the children in order here as
+      // well because they will get reversed when popped off the stack.
+      for (NodeId sibling_n : children(n)) {
+        node_stack.push_back({sibling_n, depth + 1});
+      }
+      continue;
+    }
+
+    // This node has no children. Close one `children: [` list for every level
+    // we ascend before reaching the next node to print.
+    int next_depth = node_stack.empty() ? 0 : node_stack.back().second;
+    CARBON_CHECK(next_depth <= depth) << "Cannot have the next depth increase!";
+    for (int close_children_count : llvm::seq(0, depth - next_depth)) {
+      (void)close_children_count;
+      output << "]}";
+    }
+
+    // We always end with a comma and a new line as we'll move to the next
+    // node at whatever the current level ends up being. The comma directly
+    // follows the record, matching `Print` and the documented format.
+    output << ",\n";
+  }
+  output << "  ]\n";
+}
+
+// Reports the memory used by `subtree_sizes_` to `mem_usage`, labeled with
+// `label` combined with the member name through `MemUsage::ConcatLabel`.
+auto TreeAndSubtrees::CollectMemUsage(MemUsage& mem_usage,
+                                      llvm::StringRef label) const -> void {
+  mem_usage.Add(MemUsage::ConcatLabel(label, "subtree_sizes_"), subtree_sizes_);
+}
+
+// Writes the node this iterator currently refers to.
+auto TreeAndSubtrees::SiblingIterator::Print(llvm::raw_ostream& output) const
+    -> void {
+  output << node_;
+}
+
+}  // namespace Carbon::Parse

+ 277 - 0
toolchain/parse/tree_and_subtrees.h

@@ -0,0 +1,277 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_PARSE_TREE_AND_SUBTREES_H_
+#define CARBON_TOOLCHAIN_PARSE_TREE_AND_SUBTREES_H_
+
+#include "llvm/ADT/SmallVector.h"
+#include "toolchain/parse/tree.h"
+
+namespace Carbon::Parse {
+
+// Calculates and stores subtree data for a parse tree. Supports APIs that
+// require subtree knowledge.
+//
+// This requires a complete tree.
+class TreeAndSubtrees {
+ public:
+  class SiblingIterator;
+
+  explicit TreeAndSubtrees(const Lex::TokenizedBuffer& tokens,
+                           const Tree& tree);
+
+  // The following `Extract*` functions provide an alternative way of accessing
+  // the nodes of a tree. It is intended to be more convenient and type-safe,
+  // but slower and can't be used on nodes that are marked as having an error.
+  // It is appropriate for uses that are less performance sensitive, like
+  // diagnostics. Example usage:
+  // ```
+  // auto file = tree->ExtractFile();
+  // for (AnyDeclId decl_id : file.decls) {
+  //   // `decl_id` is convertible to a `NodeId`.
+  //   if (std::optional<FunctionDecl> fn_decl =
+  //       tree->ExtractAs<FunctionDecl>(decl_id)) {
+  //     // fn_decl->params is a `TuplePatternId` (which extends `NodeId`)
+  //     // that is guaranteed to reference a `TuplePattern`.
+  //     std::optional<TuplePattern> params = tree->Extract(fn_decl->params);
+  //     // `params` has a value unless there was an error in that node.
+  //   } else if (auto class_def = tree->ExtractAs<ClassDefinition>(decl_id)) {
+  //     // ...
+  //   }
+  // }
+  // ```
+
+  // Extract a `File` object representing the parse tree for the whole file.
+  // #include "toolchain/parse/typed_nodes.h" to get the definition of `File`
+  // and the types representing its children nodes. This is implemented in
+  // extract.cpp.
+  auto ExtractFile() const -> File;
+
+  // Converts this node_id to a typed node of a specified type, if it is a valid
+  // node of that kind.
+  template <typename T>
+  auto ExtractAs(NodeId node_id) const -> std::optional<T>;
+
+  // Converts to a typed node, if it is not an error.
+  template <typename IdT>
+  auto Extract(IdT id) const
+      -> std::optional<typename NodeForId<IdT>::TypedNode>;
+
+  // Verifies that each node in the tree can be successfully extracted.
+  //
+  // This is fairly slow, and is primarily intended to be used as a debugging
+  // aid. This doesn't directly CHECK so that it can be used within a debugger.
+  auto Verify() const -> ErrorOr<Success>;
+
+  // Prints the parse tree in postorder format. See also PrintPreorder.
+  //
+  // Output represents each node as a YAML record. A node is formatted as:
+  //   ```
+  //   {kind: 'foo', text: '...'}
+  //   ```
+  //
+  // The top level is formatted as an array of these nodes.
+  //   ```
+  //   [
+  //   {kind: 'foo', text: '...'},
+  //   {kind: 'foo', text: '...'},
+  //   ...
+  //   ]
+  //   ```
+  //
+  // Nodes are indented in order to indicate depth. For example, a node with two
+  // children, one of them with an error:
+  //   ```
+  //     {kind: 'bar', text: '...', has_error: yes},
+  //     {kind: 'baz', text: '...'}
+  //   {kind: 'foo', text: '...', subtree_size: 2}
+  //   ```
+  //
+  // This can be parsed as YAML using tools like `python-yq` combined with `jq`
+  // on the command line. The format is also reasonably amenable to other
+  // line-oriented shell tools from `grep` to `awk`.
+  auto Print(llvm::raw_ostream& output) const -> void;
+
+  // Prints the parse tree in preorder. The format is YAML, and similar to
+  // Print. However, nodes are marked as children with postorder (storage)
+  // index. For example, a node with two children, one of them with an error:
+  //   ```
+  //   {node_index: 2, kind: 'foo', text: '...', subtree_size: 2, children: [
+  //     {node_index: 0, kind: 'bar', text: '...', has_error: yes},
+  //     {node_index: 1, kind: 'baz', text: '...'}]}
+  //   ```
+  auto PrintPreorder(llvm::raw_ostream& output) const -> void;
+
+  // Collects memory usage of members.
+  auto CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
+      -> void;
+
+  // Returns an iterable range over the parse tree node and all of its
+  // descendants in depth-first postorder.
+  auto postorder(NodeId n) const
+      -> llvm::iterator_range<Tree::PostorderIterator>;
+
+  // Returns an iterable range over the direct children of a node in the parse
+  // tree. This is a forward range, but is constant time to increment. The order
+  // of children is the same as would be found in a reverse postorder traversal.
+  auto children(NodeId n) const -> llvm::iterator_range<SiblingIterator>;
+
+  // Returns an iterable range over the roots of the parse tree. This is a
+  // forward range, but is constant time to increment. The order of roots is the
+  // same as would be found in a reverse postorder traversal.
+  auto roots() const -> llvm::iterator_range<SiblingIterator>;
+
+  auto tree() const -> const Tree& { return *tree_; }
+
+ private:
+  friend class TypedNodesTestPeer;
+
+  // Extract a node of type `T` from a sibling range. This is expected to
+  // consume the complete sibling range. Malformed tree errors are written
+  // to `*trace`, if `trace != nullptr`. This is implemented in extract.cpp.
+  template <typename T>
+  auto TryExtractNodeFromChildren(
+      NodeId node_id, llvm::iterator_range<SiblingIterator> children,
+      ErrorBuilder* trace) const -> std::optional<T>;
+
+  // Extract a node of type `T` from a sibling range. This is expected to
+  // consume the complete sibling range. Malformed tree errors are fatal.
+  template <typename T>
+  auto ExtractNodeFromChildren(
+      NodeId node_id, llvm::iterator_range<SiblingIterator> children) const
+      -> T;
+
+  // Like ExtractAs(), but malformed tree errors are not fatal. Should only be
+  // used by `Verify()` or by tests.
+  template <typename T>
+  auto VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
+      -> std::optional<T>;
+
+  // Wrapper around `VerifyExtractAs` to dispatch based on a runtime node kind.
+  // Returns true if extraction was successful.
+  auto VerifyExtract(NodeId node_id, NodeKind kind, ErrorBuilder* trace) const
+      -> bool;
+
+  // Prints a single node for Print(). Returns true when preorder and there are
+  // children.
+  auto PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
+                 bool preorder) const -> bool;
+
+  // The associated tokens.
+  const Lex::TokenizedBuffer* tokens_;
+
+  // The associated tree.
+  const Tree* tree_;
+
+  // For each node in the tree, the size of the node's subtree. This is the
+  // number of nodes (and thus tokens) that are covered by the node (and its
+  // descendants) in the parse tree. It's one for nodes with no children.
+  //
+  // During a *reverse* postorder (RPO) traversal of the parse tree, this can
+  // also be thought of as the offset to the next non-descendant node. When the
+  // node is not the first child of its parent (which is the last child visited
+  // in RPO), that is the offset to the next sibling. When the node *is* the
+  // first child of its parent, this will be an offset to the node's parent's
+  // next sibling, or if the parent is also a first child, the grandparent's
+  // next sibling, and so on.
+  llvm::SmallVector<int32_t> subtree_sizes_;
+};
+
+// A forward iterator across the siblings at a particular level in the parse
+// tree. It produces `Tree::NodeId` objects which are opaque handles and must
+// be used in conjunction with the `Tree` itself.
+//
+// While this is a forward iterator and may not have good locality within the
+// `Tree` data structure, it is still constant time to increment and
+// suitable for algorithms relying on that property.
+//
+// The siblings are discovered through a reverse postorder (RPO) tree traversal
+// (which is made constant time through cached distance information), and so the
+// relative order of siblings matches their RPO order.
+class TreeAndSubtrees::SiblingIterator
+    : public llvm::iterator_facade_base<SiblingIterator,
+                                        std::forward_iterator_tag, NodeId, int,
+                                        const NodeId*, NodeId>,
+      public Printable<SiblingIterator> {
+ public:
+  // Iterators are only created by `TreeAndSubtrees` through the private
+  // constructor below.
+  explicit SiblingIterator() = delete;
+
+  // Compares only the current node; the owning `TreeAndSubtrees` is not
+  // compared.
+  auto operator==(const SiblingIterator& rhs) const -> bool {
+    return node_ == rhs.node_;
+  }
+
+  auto operator*() const -> NodeId { return node_; }
+
+  using iterator_facade_base::operator++;
+  // Steps to the next sibling by moving the index back over the current
+  // node's entire subtree.
+  auto operator++() -> SiblingIterator& {
+    node_.index -= std::abs(tree_->subtree_sizes_[node_.index]);
+    return *this;
+  }
+
+  // Prints the underlying node index.
+  auto Print(llvm::raw_ostream& output) const -> void;
+
+ private:
+  friend class TreeAndSubtrees;
+
+  explicit SiblingIterator(const TreeAndSubtrees& tree, NodeId node)
+      : tree_(&tree), node_(node) {}
+
+  // The subtree data used to find the next sibling.
+  const TreeAndSubtrees* tree_;
+  // The node the iterator currently refers to.
+  NodeId node_;
+};
+
+// Extracts a `T` from `children` using `TryExtractNodeFromChildren`. If that
+// yields no value, the extraction is repeated with a trace and the trace
+// message is reported through `CARBON_FATAL`.
+template <typename T>
+auto TreeAndSubtrees::ExtractNodeFromChildren(
+    NodeId node_id, llvm::iterator_range<SiblingIterator> children) const -> T {
+  // First attempt runs without a trace.
+  auto result = TryExtractNodeFromChildren<T>(node_id, children, nullptr);
+  if (!result.has_value()) {
+    // On error try again, this time capturing a trace.
+    ErrorBuilder trace;
+    TryExtractNodeFromChildren<T>(node_id, children, &trace);
+    CARBON_FATAL() << "Malformed parse node:\n"
+                   << static_cast<Error>(trace).message();
+  }
+  return *result;
+}
+
+template <typename T>
+auto TreeAndSubtrees::ExtractAs(NodeId node_id) const -> std::optional<T> {
+  static_assert(HasKindMember<T>, "Not a parse node type");
+  // Only extract when the tree reports `node_id` as valid for `T`.
+  if (tree_->IsValid<T>(node_id)) {
+    return ExtractNodeFromChildren<T>(node_id, children(node_id));
+  }
+
+  return std::nullopt;
+}
+
+template <typename T>
+auto TreeAndSubtrees::VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
+    -> std::optional<T> {
+  static_assert(HasKindMember<T>, "Not a parse node type");
+  if (tree_->IsValid<T>(node_id)) {
+    return TryExtractNodeFromChildren<T>(node_id, children(node_id), trace);
+  }
+
+  // Wrong kind: describe the mismatch when the caller supplied a trace.
+  if (trace != nullptr) {
+    *trace << "VerifyExtractAs error: wrong kind "
+           << tree_->node_kind(node_id) << ", expected " << T::Kind << "\n";
+  }
+  return std::nullopt;
+}
+
+template <typename IdT>
+auto TreeAndSubtrees::Extract(IdT id) const
+    -> std::optional<typename NodeForId<IdT>::TypedNode> {
+  using TypedNodeT = typename NodeForId<IdT>::TypedNode;
+
+  // Only extract when the tree reports `id` as valid.
+  if (tree_->IsValid(id)) {
+    return ExtractNodeFromChildren<TypedNodeT>(id, children(id));
+  }
+
+  return std::nullopt;
+}
+
+}  // namespace Carbon::Parse
+
+#endif  // CARBON_TOOLCHAIN_PARSE_TREE_AND_SUBTREES_H_

+ 16 - 8
toolchain/parse/tree_node_diagnostic_converter.h

@@ -5,9 +5,12 @@
 #ifndef CARBON_TOOLCHAIN_PARSE_TREE_NODE_DIAGNOSTIC_CONVERTER_H_
 #define CARBON_TOOLCHAIN_PARSE_TREE_NODE_DIAGNOSTIC_CONVERTER_H_
 
+#include <utility>
+
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 
 namespace Carbon::Parse {
 
@@ -32,11 +35,12 @@ class NodeLoc {
 
 class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
  public:
-  explicit NodeLocConverter(const Lex::TokenizedBuffer* tokens,
-                            llvm::StringRef filename, const Tree* parse_tree)
+  explicit NodeLocConverter(
+      const Lex::TokenizedBuffer* tokens, llvm::StringRef filename,
+      llvm::function_ref<const Parse::TreeAndSubtrees&()> get_tree_and_subtrees)
       : token_converter_(tokens),
         filename_(filename),
-        parse_tree_(parse_tree) {}
+        get_tree_and_subtrees_(get_tree_and_subtrees) {}
 
   // Map the given token into a diagnostic location.
   auto ConvertLoc(NodeLoc node_loc, ContextFnT context_fn) const
@@ -47,17 +51,19 @@ class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
       return {.filename = filename_};
     }
 
+    const auto& tree = get_tree_and_subtrees_();
+
     if (node_loc.token_only()) {
       return token_converter_.ConvertLoc(
-          parse_tree_->node_token(node_loc.node_id()), context_fn);
+          tree.tree().node_token(node_loc.node_id()), context_fn);
     }
 
     // Construct a location that encompasses all tokens that descend from this
     // node (including the root).
-    Lex::TokenIndex start_token = parse_tree_->node_token(node_loc.node_id());
+    Lex::TokenIndex start_token = tree.tree().node_token(node_loc.node_id());
     Lex::TokenIndex end_token = start_token;
-    for (NodeId desc : parse_tree_->postorder(node_loc.node_id())) {
-      Lex::TokenIndex desc_token = parse_tree_->node_token(desc);
+    for (NodeId desc : tree.postorder(node_loc.node_id())) {
+      Lex::TokenIndex desc_token = tree.tree().node_token(desc);
       if (!desc_token.is_valid()) {
         continue;
       }
@@ -89,7 +95,9 @@ class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
  private:
   Lex::TokenDiagnosticConverter token_converter_;
   llvm::StringRef filename_;
-  const Tree* parse_tree_;
+
+  // Returns a lazily constructed TreeAndSubtrees.
+  llvm::function_ref<const Parse::TreeAndSubtrees&()> get_tree_and_subtrees_;
 };
 
 }  // namespace Carbon::Parse

+ 5 - 2
toolchain/parse/tree_test.cpp

@@ -16,6 +16,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/testing/yaml_test_helpers.h"
 
 namespace Carbon::Parse {
@@ -60,7 +61,8 @@ TEST_F(TreeTest, AsAndTryAs) {
   Lex::TokenizedBuffer& tokens = GetTokenizedBuffer("fn F();");
   Tree tree = Parse(tokens, consumer_, /*vlog_stream=*/nullptr);
   ASSERT_FALSE(tree.has_errors());
-  auto it = tree.roots().begin();
+  TreeAndSubtrees tree_and_subtrees(tokens, tree);
+  auto it = tree_and_subtrees.roots().begin();
   // A FileEnd node, so won't match.
   NodeId n = *it;
 
@@ -134,8 +136,9 @@ TEST_F(TreeTest, PrintPreorderAsYAML) {
   Lex::TokenizedBuffer& tokens = GetTokenizedBuffer("fn F();");
   Tree tree = Parse(tokens, consumer_, /*vlog_stream=*/nullptr);
   EXPECT_FALSE(tree.has_errors());
+  TreeAndSubtrees tree_and_subtrees(tokens, tree);
   TestRawOstream print_stream;
-  tree.Print(print_stream, /*preorder=*/true);
+  tree_and_subtrees.PrintPreorder(print_stream);
 
   auto param_list = Yaml::Sequence(ElementsAre(Yaml::Mapping(
       ElementsAre(Pair("node_index", "3"), Pair("kind", "TuplePatternStart"),

+ 12 - 8
toolchain/parse/typed_nodes_test.cpp

@@ -12,6 +12,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 
 namespace Carbon::Parse {
 
@@ -20,7 +21,7 @@ namespace Carbon::Parse {
 class TypedNodesTestPeer {
  public:
   template <typename T>
-  static auto VerifyExtractAs(const Tree* tree, NodeId node_id,
+  static auto VerifyExtractAs(const TreeAndSubtrees* tree, NodeId node_id,
                               ErrorBuilder* trace) -> std::optional<T> {
     return tree->VerifyExtractAs<T>(node_id, trace);
   }
@@ -57,14 +58,16 @@ class TypedNodeTest : public ::testing::Test {
     return token_storage_.front();
   }
 
-  auto GetTree(llvm::StringRef t) -> Tree& {
+  auto GetTree(llvm::StringRef t) -> TreeAndSubtrees& {
     tree_storage_.push_front(Parse(GetTokenizedBuffer(t), consumer_,
                                    /*vlog_stream=*/nullptr));
-    return tree_storage_.front();
+    tree_and_subtrees_storage_.push_front(
+        TreeAndSubtrees(token_storage_.front(), tree_storage_.front()));
+    return tree_and_subtrees_storage_.front();
   }
 
   auto GetTokenizedBufferAndTree(llvm::StringRef t)
-      -> std::pair<Lex::TokenizedBuffer*, Tree*> {
+      -> std::pair<Lex::TokenizedBuffer*, TreeAndSubtrees*> {
     auto* tree = &GetTree(t);
     return {&token_storage_.front(), tree};
   }
@@ -74,6 +77,7 @@ class TypedNodeTest : public ::testing::Test {
   std::forward_list<SourceBuffer> source_storage_;
   std::forward_list<Lex::TokenizedBuffer> token_storage_;
   std::forward_list<Tree> tree_storage_;
+  std::forward_list<TreeAndSubtrees> tree_and_subtrees_storage_;
   DiagnosticConsumer& consumer_ = ConsoleDiagnosticConsumer();
 };
 
@@ -81,15 +85,15 @@ TEST_F(TypedNodeTest, Empty) {
   auto* tree = &GetTree("");
   auto file = tree->ExtractFile();
 
-  EXPECT_TRUE(tree->IsValid(file.start));
+  EXPECT_TRUE(tree->tree().IsValid(file.start));
   EXPECT_TRUE(tree->ExtractAs<FileStart>(file.start).has_value());
   EXPECT_TRUE(tree->Extract(file.start).has_value());
 
-  EXPECT_TRUE(tree->IsValid(file.end));
+  EXPECT_TRUE(tree->tree().IsValid(file.end));
   EXPECT_TRUE(tree->ExtractAs<FileEnd>(file.end).has_value());
   EXPECT_TRUE(tree->Extract(file.end).has_value());
 
-  EXPECT_FALSE(tree->IsValid<FileEnd>(file.start));
+  EXPECT_FALSE(tree->tree().IsValid<FileEnd>(file.start));
   EXPECT_FALSE(tree->ExtractAs<FileEnd>(file.start).has_value());
 }
 
@@ -342,7 +346,7 @@ TEST_F(TypedNodeTest, VerifyInvalid) {
   ASSERT_TRUE(f_intro.has_value());
 
   // Change the kind of the introducer and check we get a good trace log.
-  TypedNodesTestPeer::SetNodeKind(tree, f_sig->introducer,
+  TypedNodesTestPeer::SetNodeKind(&tree_storage_.front(), f_sig->introducer,
                                   NodeKind::ClassIntroducer);
 
   // The introducer should not extract as a FunctionIntroducer any more because