Forráskód Böngészése

Separate subtree size information from parse nodes. (#4174)

Move subtree sizes over to TreeAndSubtrees, using the different
structure to represent the additional parse work that occurs, as well as
making it clear which functions require the extra information. My intent
is to make it hard to use this by accident.

The subtree size is still tracked during Parse::Tree construction. I
think a lot of that can be cleaned up, although we use it during
placeholder assignment so it may take some work. I wanted to see what
people thought about this before taking action on such a change.

I'm using a 1m line source file generated by #4124 for testing. Command
is `time bazel-bin/toolchain/install/prefix_root/bin/carbon compile
--phase=check --dump-mem-usage ~/tmp/data.carbon`

At head, what I'm seeing is:

```
...
parse_tree_.node_impls_:
  used_bytes:      61516116
  reserved_bytes:  61516116
...
Total:
  used_bytes:      447814230
  reserved_bytes:  551663894
...
1.43s user 0.14s system 99% cpu 1.565 total
```

With `Tree::Verify` disabled completely, it looks like:
```
parse_tree_.node_impls_:
  used_bytes:      41010744
  reserved_bytes:  41010744
...
Total:
  used_bytes:      427308858
  reserved_bytes:  531158522
...
1.20s user 0.13s system 99% cpu 1.332 total
```

Re-enabling just the basic verification (what is now `Tree::Verify`),
I'm seeing maybe 0.05s slower, but that's within noise for my system. I
do see variability in my timing results, and overall I think this is a
0.2s +/- 0.1s improvement versus the earlier (always testing `Extract`
code) implementation. That's opt; debug builds will be unaffected,
because the same checking occurs as before.

Note, the subtree size is a third of the node representation, which is
why I'm showing the decrease in memory usage here.
Jon Ross-Perkins 1 éve
szülő
commit
f67791cfee
40 módosított fájl, 767 hozzáadás és 692 törlés
  1. 8 4
      language_server/language_server.cpp
  2. 3 2
      toolchain/check/check.cpp
  3. 3 0
      toolchain/check/check.h
  4. 5 2
      toolchain/check/context.cpp
  5. 11 0
      toolchain/check/context.h
  6. 2 1
      toolchain/check/handle_impl.cpp
  7. 1 0
      toolchain/driver/BUILD
  8. 30 6
      toolchain/driver/driver.cpp
  9. 5 1
      toolchain/parse/BUILD
  10. 15 23
      toolchain/parse/context.cpp
  11. 4 6
      toolchain/parse/context.h
  12. 24 19
      toolchain/parse/extract.cpp
  13. 2 4
      toolchain/parse/handle_array_expr.cpp
  14. 3 5
      toolchain/parse/handle_binding_pattern.cpp
  15. 2 4
      toolchain/parse/handle_brace_expr.cpp
  16. 2 4
      toolchain/parse/handle_call_expr.cpp
  17. 4 4
      toolchain/parse/handle_choice.cpp
  18. 2 4
      toolchain/parse/handle_code_block.cpp
  19. 2 4
      toolchain/parse/handle_decl_definition.cpp
  20. 2 3
      toolchain/parse/handle_decl_name_and_params.cpp
  21. 2 4
      toolchain/parse/handle_decl_scope_loop.cpp
  22. 9 15
      toolchain/parse/handle_expr.cpp
  23. 6 9
      toolchain/parse/handle_function.cpp
  24. 2 4
      toolchain/parse/handle_impl.cpp
  25. 2 2
      toolchain/parse/handle_import_and_package.cpp
  26. 1 1
      toolchain/parse/handle_index_expr.cpp
  27. 1 2
      toolchain/parse/handle_let.cpp
  28. 18 28
      toolchain/parse/handle_match.cpp
  29. 2 4
      toolchain/parse/handle_paren_expr.cpp
  30. 1 1
      toolchain/parse/handle_pattern_list.cpp
  31. 3 4
      toolchain/parse/handle_period.cpp
  32. 5 9
      toolchain/parse/handle_statement.cpp
  33. 2 4
      toolchain/parse/handle_var.cpp
  34. 22 245
      toolchain/parse/tree.cpp
  35. 7 246
      toolchain/parse/tree.h
  36. 244 0
      toolchain/parse/tree_and_subtrees.cpp
  37. 277 0
      toolchain/parse/tree_and_subtrees.h
  38. 16 8
      toolchain/parse/tree_node_diagnostic_converter.h
  39. 5 2
      toolchain/parse/tree_test.cpp
  40. 12 8
      toolchain/parse/typed_nodes_test.cpp

+ 8 - 4
language_server/language_server.cpp

@@ -10,6 +10,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/lex.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/parse.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/source/source_buffer.h"
 #include "toolchain/source/source_buffer.h"
 
 
 namespace Carbon::LS {
 namespace Carbon::LS {
@@ -79,11 +80,12 @@ auto LanguageServer::onReply(llvm::json::Value /*id*/,
 // Returns the text of first child of kind Parse::NodeKind::IdentifierName.
 // Returns the text of first child of kind Parse::NodeKind::IdentifierName.
 static auto GetIdentifierName(const SharedValueStores& value_stores,
 static auto GetIdentifierName(const SharedValueStores& value_stores,
                               const Lex::TokenizedBuffer& tokens,
                               const Lex::TokenizedBuffer& tokens,
-                              const Parse::Tree& p, Parse::NodeId node)
+                              const Parse::TreeAndSubtrees& p,
+                              Parse::NodeId node)
     -> std::optional<llvm::StringRef> {
     -> std::optional<llvm::StringRef> {
   for (auto ch : p.children(node)) {
   for (auto ch : p.children(node)) {
-    if (p.node_kind(ch) == Parse::NodeKind::IdentifierName) {
-      auto token = p.node_token(ch);
+    if (p.tree().node_kind(ch) == Parse::NodeKind::IdentifierName) {
+      auto token = p.tree().node_token(ch);
       if (tokens.GetKind(token) == Lex::TokenKind::Identifier) {
       if (tokens.GetKind(token) == Lex::TokenKind::Identifier) {
         return value_stores.identifiers().Get(tokens.GetIdentifier(token));
         return value_stores.identifiers().Get(tokens.GetIdentifier(token));
       }
       }
@@ -104,6 +106,7 @@ void LanguageServer::OnDocumentSymbol(
   auto buf = SourceBuffer::MakeFromFile(vfs, file, NullDiagnosticConsumer());
   auto buf = SourceBuffer::MakeFromFile(vfs, file, NullDiagnosticConsumer());
   auto lexed = Lex::Lex(value_stores, *buf, NullDiagnosticConsumer());
   auto lexed = Lex::Lex(value_stores, *buf, NullDiagnosticConsumer());
   auto parsed = Parse::Parse(lexed, NullDiagnosticConsumer(), nullptr);
   auto parsed = Parse::Parse(lexed, NullDiagnosticConsumer(), nullptr);
+  Parse::TreeAndSubtrees tree_and_subtrees(lexed, parsed);
   std::vector<clang::clangd::DocumentSymbol> result;
   std::vector<clang::clangd::DocumentSymbol> result;
   for (const auto& node : parsed.postorder()) {
   for (const auto& node : parsed.postorder()) {
     clang::clangd::SymbolKind symbol_kind;
     clang::clangd::SymbolKind symbol_kind;
@@ -126,7 +129,8 @@ void LanguageServer::OnDocumentSymbol(
         continue;
         continue;
     }
     }
 
 
-    if (auto name = GetIdentifierName(value_stores, lexed, parsed, node)) {
+    if (auto name =
+            GetIdentifierName(value_stores, lexed, tree_and_subtrees, node)) {
       auto tok = parsed.node_token(node);
       auto tok = parsed.node_token(node);
       clang::clangd::Position pos{lexed.GetLineNumber(tok) - 1,
       clang::clangd::Position pos{lexed.GetLineNumber(tok) - 1,
                                   lexed.GetColumnNumber(tok) - 1};
                                   lexed.GetColumnNumber(tok) - 1};

+ 3 - 2
toolchain/check/check.cpp

@@ -65,7 +65,7 @@ struct UnitInfo {
       : check_ir_id(check_ir_id),
       : check_ir_id(check_ir_id),
         unit(&unit),
         unit(&unit),
         converter(unit.tokens, unit.tokens->source().filename(),
         converter(unit.tokens, unit.tokens->source().filename(),
-                  unit.parse_tree),
+                  unit.get_parse_tree_and_subtrees),
         err_tracker(*unit.consumer),
         err_tracker(*unit.consumer),
         emitter(converter, err_tracker) {}
         emitter(converter, err_tracker) {}
 
 
@@ -891,7 +891,8 @@ static auto CheckParseTree(
   SemIRDiagnosticConverter converter(node_converters, &sem_ir);
   SemIRDiagnosticConverter converter(node_converters, &sem_ir);
   Context::DiagnosticEmitter emitter(converter, unit_info.err_tracker);
   Context::DiagnosticEmitter emitter(converter, unit_info.err_tracker);
   Context context(*unit_info.unit->tokens, emitter, *unit_info.unit->parse_tree,
   Context context(*unit_info.unit->tokens, emitter, *unit_info.unit->parse_tree,
-                  sem_ir, vlog_stream);
+                  unit_info.unit->get_parse_tree_and_subtrees, sem_ir,
+                  vlog_stream);
   PrettyStackTraceFunction context_dumper(
   PrettyStackTraceFunction context_dumper(
       [&](llvm::raw_ostream& output) { context.PrintForStackDump(output); });
       [&](llvm::raw_ostream& output) { context.PrintForStackDump(output); });
 
 

+ 3 - 0
toolchain/check/check.h

@@ -10,6 +10,7 @@
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/sem_ir/file.h"
 #include "toolchain/sem_ir/file.h"
 
 
 namespace Carbon::Check {
 namespace Carbon::Check {
@@ -20,6 +21,8 @@ struct Unit {
   const Lex::TokenizedBuffer* tokens;
   const Lex::TokenizedBuffer* tokens;
   const Parse::Tree* parse_tree;
   const Parse::Tree* parse_tree;
   DiagnosticConsumer* consumer;
   DiagnosticConsumer* consumer;
+  // Returns a lazily constructed TreeAndSubtrees.
+  std::function<const Parse::TreeAndSubtrees&()> get_parse_tree_and_subtrees;
   // The generated IR. Unset on input, set on output.
   // The generated IR. Unset on input, set on output.
   std::optional<SemIR::File>* sem_ir;
   std::optional<SemIR::File>* sem_ir;
 };
 };

+ 5 - 2
toolchain/check/context.cpp

@@ -36,11 +36,14 @@
 namespace Carbon::Check {
 namespace Carbon::Check {
 
 
 Context::Context(const Lex::TokenizedBuffer& tokens, DiagnosticEmitter& emitter,
 Context::Context(const Lex::TokenizedBuffer& tokens, DiagnosticEmitter& emitter,
-                 const Parse::Tree& parse_tree, SemIR::File& sem_ir,
-                 llvm::raw_ostream* vlog_stream)
+                 const Parse::Tree& parse_tree,
+                 llvm::function_ref<const Parse::TreeAndSubtrees&()>
+                     get_parse_tree_and_subtrees,
+                 SemIR::File& sem_ir, llvm::raw_ostream* vlog_stream)
     : tokens_(&tokens),
     : tokens_(&tokens),
       emitter_(&emitter),
       emitter_(&emitter),
       parse_tree_(&parse_tree),
       parse_tree_(&parse_tree),
+      get_parse_tree_and_subtrees_(get_parse_tree_and_subtrees),
       sem_ir_(&sem_ir),
       sem_ir_(&sem_ir),
       vlog_stream_(vlog_stream),
       vlog_stream_(vlog_stream),
       node_stack_(parse_tree, vlog_stream),
       node_stack_(parse_tree, vlog_stream),

+ 11 - 0
toolchain/check/context.h

@@ -19,6 +19,7 @@
 #include "toolchain/check/scope_stack.h"
 #include "toolchain/check/scope_stack.h"
 #include "toolchain/parse/node_ids.h"
 #include "toolchain/parse/node_ids.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/sem_ir/file.h"
 #include "toolchain/sem_ir/file.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/import_ir.h"
 #include "toolchain/sem_ir/import_ir.h"
@@ -53,6 +54,8 @@ class Context {
   // Stores references for work.
   // Stores references for work.
   explicit Context(const Lex::TokenizedBuffer& tokens,
   explicit Context(const Lex::TokenizedBuffer& tokens,
                    DiagnosticEmitter& emitter, const Parse::Tree& parse_tree,
                    DiagnosticEmitter& emitter, const Parse::Tree& parse_tree,
+                   llvm::function_ref<const Parse::TreeAndSubtrees&()>
+                       get_parse_tree_and_subtrees,
                    SemIR::File& sem_ir, llvm::raw_ostream* vlog_stream);
                    SemIR::File& sem_ir, llvm::raw_ostream* vlog_stream);
 
 
   // Marks an implementation TODO. Always returns false.
   // Marks an implementation TODO. Always returns false.
@@ -360,6 +363,10 @@ class Context {
 
 
   auto parse_tree() -> const Parse::Tree& { return *parse_tree_; }
   auto parse_tree() -> const Parse::Tree& { return *parse_tree_; }
 
 
+  auto parse_tree_and_subtrees() -> const Parse::TreeAndSubtrees& {
+    return get_parse_tree_and_subtrees_();
+  }
+
   auto sem_ir() -> SemIR::File& { return *sem_ir_; }
   auto sem_ir() -> SemIR::File& { return *sem_ir_; }
 
 
   auto node_stack() -> NodeStack& { return node_stack_; }
   auto node_stack() -> NodeStack& { return node_stack_; }
@@ -486,6 +493,10 @@ class Context {
   // The file's parse tree.
   // The file's parse tree.
   const Parse::Tree* parse_tree_;
   const Parse::Tree* parse_tree_;
 
 
+  // Returns a lazily constructed TreeAndSubtrees.
+  llvm::function_ref<const Parse::TreeAndSubtrees&()>
+      get_parse_tree_and_subtrees_;
+
   // The SemIR::File being added to.
   // The SemIR::File being added to.
   SemIR::File* sem_ir_;
   SemIR::File* sem_ir_;
 
 

+ 2 - 1
toolchain/check/handle_impl.cpp

@@ -141,7 +141,8 @@ static auto ExtendImpl(Context& context, Parse::NodeId extend_node,
     // The explicit self type is the same as the default self type, so suggest
     // The explicit self type is the same as the default self type, so suggest
     // removing it and recover as if it were not present.
     // removing it and recover as if it were not present.
     if (auto self_as =
     if (auto self_as =
-            context.parse_tree().ExtractAs<Parse::TypeImplAs>(self_type_node)) {
+            context.parse_tree_and_subtrees().ExtractAs<Parse::TypeImplAs>(
+                self_type_node)) {
       CARBON_DIAGNOSTIC(ExtendImplSelfAsDefault, Note,
       CARBON_DIAGNOSTIC(ExtendImplSelfAsDefault, Note,
                         "Remove the explicit `Self` type here.");
                         "Remove the explicit `Self` type here.");
       diag.Note(self_as->type_expr, ExtendImplSelfAsDefault);
       diag.Note(self_as->type_expr, ExtendImplSelfAsDefault);

+ 1 - 0
toolchain/driver/BUILD

@@ -74,6 +74,7 @@ cc_library(
         "//toolchain/lex",
         "//toolchain/lex",
         "//toolchain/lower",
         "//toolchain/lower",
         "//toolchain/parse",
         "//toolchain/parse",
+        "//toolchain/parse:tree",
         "//toolchain/sem_ir:file",
         "//toolchain/sem_ir:file",
         "//toolchain/sem_ir:formatter",
         "//toolchain/sem_ir:formatter",
         "//toolchain/sem_ir:inst_namer",
         "//toolchain/sem_ir:inst_namer",

+ 30 - 6
toolchain/driver/driver.cpp

@@ -27,6 +27,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/lex.h"
 #include "toolchain/lower/lower.h"
 #include "toolchain/lower/lower.h"
 #include "toolchain/parse/parse.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/sem_ir/formatter.h"
 #include "toolchain/sem_ir/formatter.h"
 #include "toolchain/sem_ir/inst_namer.h"
 #include "toolchain/sem_ir/inst_namer.h"
 #include "toolchain/source/source_buffer.h"
 #include "toolchain/source/source_buffer.h"
@@ -599,7 +600,12 @@ class Driver::CompilationUnit {
     });
     });
     if (options_.dump_parse_tree && IncludeInDumps()) {
     if (options_.dump_parse_tree && IncludeInDumps()) {
       consumer_->Flush();
       consumer_->Flush();
-      parse_tree_->Print(driver_->output_stream_, options_.preorder_parse_tree);
+      const auto& tree_and_subtrees = GetParseTreeAndSubtrees();
+      if (options_.preorder_parse_tree) {
+        tree_and_subtrees.PrintPreorder(driver_->output_stream_);
+      } else {
+        tree_and_subtrees.Print(driver_->output_stream_);
+      }
     }
     }
     if (mem_usage_) {
     if (mem_usage_) {
       mem_usage_->Collect("parse_tree_", *parse_tree_);
       mem_usage_->Collect("parse_tree_", *parse_tree_);
@@ -613,11 +619,15 @@ class Driver::CompilationUnit {
   // Returns information needed to check this unit.
   // Returns information needed to check this unit.
   auto GetCheckUnit() -> Check::Unit {
   auto GetCheckUnit() -> Check::Unit {
     CARBON_CHECK(parse_tree_);
     CARBON_CHECK(parse_tree_);
-    return {.value_stores = &value_stores_,
-            .tokens = &*tokens_,
-            .parse_tree = &*parse_tree_,
-            .consumer = consumer_,
-            .sem_ir = &sem_ir_};
+    return {
+        .value_stores = &value_stores_,
+        .tokens = &*tokens_,
+        .parse_tree = &*parse_tree_,
+        .consumer = consumer_,
+        .get_parse_tree_and_subtrees = [&]() -> const Parse::TreeAndSubtrees& {
+          return GetParseTreeAndSubtrees();
+        },
+        .sem_ir = &sem_ir_};
   }
   }
 
 
   // Runs post-check logic. Returns true if checking succeeded for the IR.
   // Runs post-check logic. Returns true if checking succeeded for the IR.
@@ -778,6 +788,19 @@ class Driver::CompilationUnit {
     return true;
     return true;
   }
   }
 
 
+  // The TreeAndSubtrees is mainly used for debugging and diagnostics, and has
+  // significant overhead. Avoid constructing it when unused.
+  auto GetParseTreeAndSubtrees() -> const Parse::TreeAndSubtrees& {
+    if (!parse_tree_and_subtrees_) {
+      parse_tree_and_subtrees_ = Parse::TreeAndSubtrees(*tokens_, *parse_tree_);
+      if (mem_usage_) {
+        mem_usage_->Collect("parse_tree_and_subtrees_",
+                            *parse_tree_and_subtrees_);
+      }
+    }
+    return *parse_tree_and_subtrees_;
+  }
+
   // Wraps a call with log statements to indicate start and end.
   // Wraps a call with log statements to indicate start and end.
   auto LogCall(llvm::StringLiteral label, llvm::function_ref<void()> fn)
   auto LogCall(llvm::StringLiteral label, llvm::function_ref<void()> fn)
       -> void {
       -> void {
@@ -814,6 +837,7 @@ class Driver::CompilationUnit {
   std::optional<SourceBuffer> source_;
   std::optional<SourceBuffer> source_;
   std::optional<Lex::TokenizedBuffer> tokens_;
   std::optional<Lex::TokenizedBuffer> tokens_;
   std::optional<Parse::Tree> parse_tree_;
   std::optional<Parse::Tree> parse_tree_;
+  std::optional<Parse::TreeAndSubtrees> parse_tree_and_subtrees_;
   std::optional<SemIR::File> sem_ir_;
   std::optional<SemIR::File> sem_ir_;
   std::unique_ptr<llvm::LLVMContext> llvm_context_;
   std::unique_ptr<llvm::LLVMContext> llvm_context_;
   std::unique_ptr<llvm::Module> module_;
   std::unique_ptr<llvm::Module> module_;

+ 5 - 1
toolchain/parse/BUILD

@@ -106,8 +106,12 @@ cc_library(
     srcs = [
     srcs = [
         "extract.cpp",
         "extract.cpp",
         "tree.cpp",
         "tree.cpp",
+        "tree_and_subtrees.cpp",
+    ],
+    hdrs = [
+        "tree.h",
+        "tree_and_subtrees.h",
     ],
     ],
-    hdrs = ["tree.h"],
     deps = [
     deps = [
         ":node_kind",
         ":node_kind",
         "//common:check",
         "//common:check",

+ 15 - 23
toolchain/parse/context.cpp

@@ -68,18 +68,15 @@ Context::Context(Tree& tree, Lex::TokenizedBuffer& tokens,
 
 
 auto Context::AddLeafNode(NodeKind kind, Lex::TokenIndex token, bool has_error)
 auto Context::AddLeafNode(NodeKind kind, Lex::TokenIndex token, bool has_error)
     -> void {
     -> void {
-  tree_->node_impls_.push_back(
-      Tree::NodeImpl(kind, has_error, token, /*subtree_size=*/1));
+  tree_->node_impls_.push_back(Tree::NodeImpl(kind, has_error, token));
   if (has_error) {
   if (has_error) {
     tree_->has_errors_ = true;
     tree_->has_errors_ = true;
   }
   }
 }
 }
 
 
-auto Context::AddNode(NodeKind kind, Lex::TokenIndex token, int subtree_start,
-                      bool has_error) -> void {
-  int subtree_size = tree_->size() - subtree_start + 1;
-  tree_->node_impls_.push_back(
-      Tree::NodeImpl(kind, has_error, token, subtree_size));
+auto Context::AddNode(NodeKind kind, Lex::TokenIndex token, bool has_error)
+    -> void {
+  tree_->node_impls_.push_back(Tree::NodeImpl(kind, has_error, token));
   if (has_error) {
   if (has_error) {
     tree_->has_errors_ = true;
     tree_->has_errors_ = true;
   }
   }
@@ -91,7 +88,6 @@ auto Context::ReplacePlaceholderNode(int32_t position, NodeKind kind,
   CARBON_CHECK(position >= 0 && position < tree_->size())
   CARBON_CHECK(position >= 0 && position < tree_->size())
       << "position: " << position << " size: " << tree_->size();
       << "position: " << position << " size: " << tree_->size();
   auto* node_impl = &tree_->node_impls_[position];
   auto* node_impl = &tree_->node_impls_[position];
-  CARBON_CHECK(node_impl->subtree_size == 1);
   CARBON_CHECK(node_impl->kind == NodeKind::Placeholder);
   CARBON_CHECK(node_impl->kind == NodeKind::Placeholder);
   node_impl->kind = kind;
   node_impl->kind = kind;
   node_impl->has_error = has_error;
   node_impl->has_error = has_error;
@@ -123,9 +119,9 @@ auto Context::ConsumeAndAddCloseSymbol(Lex::TokenIndex expected_open,
   Lex::TokenKind open_token_kind = tokens().GetKind(expected_open);
   Lex::TokenKind open_token_kind = tokens().GetKind(expected_open);
 
 
   if (!open_token_kind.is_opening_symbol()) {
   if (!open_token_kind.is_opening_symbol()) {
-    AddNode(close_kind, state.token, state.subtree_start, /*has_error=*/true);
+    AddNode(close_kind, state.token, /*has_error=*/true);
   } else if (auto close_token = ConsumeIf(open_token_kind.closing_symbol())) {
   } else if (auto close_token = ConsumeIf(open_token_kind.closing_symbol())) {
-    AddNode(close_kind, *close_token, state.subtree_start, state.has_error);
+    AddNode(close_kind, *close_token, state.has_error);
   } else {
   } else {
     // TODO: Include the location of the matching opening delimiter in the
     // TODO: Include the location of the matching opening delimiter in the
     // diagnostic.
     // diagnostic.
@@ -135,7 +131,7 @@ auto Context::ConsumeAndAddCloseSymbol(Lex::TokenIndex expected_open,
                    open_token_kind.closing_symbol().fixed_spelling());
                    open_token_kind.closing_symbol().fixed_spelling());
 
 
     SkipTo(tokens().GetMatchedClosingToken(expected_open));
     SkipTo(tokens().GetMatchedClosingToken(expected_open));
-    AddNode(close_kind, Consume(), state.subtree_start, /*has_error=*/true);
+    AddNode(close_kind, Consume(), /*has_error=*/true);
   }
   }
 }
 }
 
 
@@ -415,7 +411,7 @@ auto Context::AddNodeExpectingDeclSemi(StateStackEntry state,
   }
   }
 
 
   if (auto semi = ConsumeIf(Lex::TokenKind::Semi)) {
   if (auto semi = ConsumeIf(Lex::TokenKind::Semi)) {
-    AddNode(node_kind, *semi, state.subtree_start, /*has_error=*/false);
+    AddNode(node_kind, *semi, /*has_error=*/false);
   } else {
   } else {
     if (is_def_allowed) {
     if (is_def_allowed) {
       DiagnoseExpectedDeclSemiOrDefinition(decl_kind);
       DiagnoseExpectedDeclSemiOrDefinition(decl_kind);
@@ -433,8 +429,7 @@ auto Context::RecoverFromDeclError(StateStackEntry state, NodeKind node_kind,
   if (skip_past_likely_end) {
   if (skip_past_likely_end) {
     token = SkipPastLikelyEnd(token);
     token = SkipPastLikelyEnd(token);
   }
   }
-  AddNode(node_kind, token, state.subtree_start,
-          /*has_error=*/true);
+  AddNode(node_kind, token, /*has_error=*/true);
 }
 }
 
 
 auto Context::ParseLibraryName(bool accept_default)
 auto Context::ParseLibraryName(bool accept_default)
@@ -464,13 +459,11 @@ auto Context::ParseLibraryName(bool accept_default)
 auto Context::ParseLibrarySpecifier(bool accept_default)
 auto Context::ParseLibrarySpecifier(bool accept_default)
     -> std::optional<StringLiteralValueId> {
     -> std::optional<StringLiteralValueId> {
   auto library_token = ConsumeChecked(Lex::TokenKind::Library);
   auto library_token = ConsumeChecked(Lex::TokenKind::Library);
-  auto library_subtree_start = tree().size();
   auto library_id = ParseLibraryName(accept_default);
   auto library_id = ParseLibraryName(accept_default);
   if (!library_id) {
   if (!library_id) {
     AddLeafNode(NodeKind::LibraryName, *position_, /*has_error=*/true);
     AddLeafNode(NodeKind::LibraryName, *position_, /*has_error=*/true);
   }
   }
-  AddNode(NodeKind::LibrarySpecifier, library_token, library_subtree_start,
-          /*has_error=*/false);
+  AddNode(NodeKind::LibrarySpecifier, library_token, /*has_error=*/false);
   return library_id;
   return library_id;
 }
 }
 
 
@@ -503,8 +496,7 @@ static auto ParsingInDeferredDefinitionScope(Context& context) -> bool {
          state == State::DeclDefinitionFinishAsNamedConstraint;
          state == State::DeclDefinitionFinishAsNamedConstraint;
 }
 }
 
 
-auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token,
-                                         int subtree_start, bool has_error)
+auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token, bool has_error)
     -> void {
     -> void {
   if (ParsingInDeferredDefinitionScope(*this)) {
   if (ParsingInDeferredDefinitionScope(*this)) {
     deferred_definition_stack_.push_back(tree_->deferred_definitions_.Add(
     deferred_definition_stack_.push_back(tree_->deferred_definitions_.Add(
@@ -512,11 +504,11 @@ auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token,
              FunctionDefinitionStartId(NodeId(tree_->node_impls_.size()))}));
              FunctionDefinitionStartId(NodeId(tree_->node_impls_.size()))}));
   }
   }
 
 
-  AddNode(NodeKind::FunctionDefinitionStart, token, subtree_start, has_error);
+  AddNode(NodeKind::FunctionDefinitionStart, token, has_error);
 }
 }
 
 
-auto Context::AddFunctionDefinition(Lex::TokenIndex token, int subtree_start,
-                                    bool has_error) -> void {
+auto Context::AddFunctionDefinition(Lex::TokenIndex token, bool has_error)
+    -> void {
   if (ParsingInDeferredDefinitionScope(*this)) {
   if (ParsingInDeferredDefinitionScope(*this)) {
     auto definition_index = deferred_definition_stack_.pop_back_val();
     auto definition_index = deferred_definition_stack_.pop_back_val();
     auto& definition = tree_->deferred_definitions_.Get(definition_index);
     auto& definition = tree_->deferred_definitions_.Get(definition_index);
@@ -526,7 +518,7 @@ auto Context::AddFunctionDefinition(Lex::TokenIndex token, int subtree_start,
         DeferredDefinitionIndex(tree_->deferred_definitions().size());
         DeferredDefinitionIndex(tree_->deferred_definitions().size());
   }
   }
 
 
-  AddNode(NodeKind::FunctionDefinition, token, subtree_start, has_error);
+  AddNode(NodeKind::FunctionDefinition, token, has_error);
 }
 }
 
 
 auto Context::PrintForStackDump(llvm::raw_ostream& output) const -> void {
 auto Context::PrintForStackDump(llvm::raw_ostream& output) const -> void {

+ 4 - 6
toolchain/parse/context.h

@@ -100,8 +100,7 @@ class Context {
       -> void;
       -> void;
 
 
   // Adds a node to the parse tree that has children.
   // Adds a node to the parse tree that has children.
-  auto AddNode(NodeKind kind, Lex::TokenIndex token, int subtree_start,
-               bool has_error) -> void;
+  auto AddNode(NodeKind kind, Lex::TokenIndex token, bool has_error) -> void;
 
 
   // Replaces the placeholder node at the indicated position with a leaf node.
   // Replaces the placeholder node at the indicated position with a leaf node.
   //
   //
@@ -328,12 +327,11 @@ class Context {
 
 
   // Adds a function definition start node, and begins tracking a deferred
   // Adds a function definition start node, and begins tracking a deferred
   // definition if necessary.
   // definition if necessary.
-  auto AddFunctionDefinitionStart(Lex::TokenIndex token, int subtree_start,
-                                  bool has_error) -> void;
+  auto AddFunctionDefinitionStart(Lex::TokenIndex token, bool has_error)
+      -> void;
   // Adds a function definition node, and ends tracking a deferred definition if
   // Adds a function definition node, and ends tracking a deferred definition if
   // necessary.
   // necessary.
-  auto AddFunctionDefinition(Lex::TokenIndex token, int subtree_start,
-                             bool has_error) -> void;
+  auto AddFunctionDefinition(Lex::TokenIndex token, bool has_error) -> void;
 
 
   // Prints information for a stack dump.
   // Prints information for a stack dump.
   auto PrintForStackDump(llvm::raw_ostream& output) const -> void;
   auto PrintForStackDump(llvm::raw_ostream& output) const -> void;

+ 24 - 19
toolchain/parse/extract.cpp

@@ -9,6 +9,7 @@
 #include "common/error.h"
 #include "common/error.h"
 #include "common/struct_reflection.h"
 #include "common/struct_reflection.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/parse/typed_nodes.h"
 #include "toolchain/parse/typed_nodes.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
@@ -20,12 +21,12 @@ namespace {
 class NodeExtractor {
 class NodeExtractor {
  public:
  public:
   struct CheckpointState {
   struct CheckpointState {
-    Tree::SiblingIterator it;
+    TreeAndSubtrees::SiblingIterator it;
   };
   };
 
 
-  NodeExtractor(const Tree* tree, Lex::TokenizedBuffer* tokens,
+  NodeExtractor(const TreeAndSubtrees* tree, const Lex::TokenizedBuffer* tokens,
                 ErrorBuilder* trace, NodeId node_id,
                 ErrorBuilder* trace, NodeId node_id,
-                llvm::iterator_range<Tree::SiblingIterator> children)
+                llvm::iterator_range<TreeAndSubtrees::SiblingIterator> children)
       : tree_(tree),
       : tree_(tree),
         tokens_(tokens),
         tokens_(tokens),
         trace_(trace),
         trace_(trace),
@@ -34,9 +35,11 @@ class NodeExtractor {
         end_(children.end()) {}
         end_(children.end()) {}
 
 
   auto at_end() const -> bool { return it_ == end_; }
   auto at_end() const -> bool { return it_ == end_; }
-  auto kind() const -> NodeKind { return tree_->node_kind(*it_); }
+  auto kind() const -> NodeKind { return tree_->tree().node_kind(*it_); }
   auto has_token() const -> bool { return node_id_.is_valid(); }
   auto has_token() const -> bool { return node_id_.is_valid(); }
-  auto token() const -> Lex::TokenIndex { return tree_->node_token(node_id_); }
+  auto token() const -> Lex::TokenIndex {
+    return tree_->tree().node_token(node_id_);
+  }
   auto token_kind() const -> Lex::TokenKind {
   auto token_kind() const -> Lex::TokenKind {
     return tokens_->GetKind(token());
     return tokens_->GetKind(token());
   }
   }
@@ -73,12 +76,12 @@ class NodeExtractor {
                             std::tuple<U...>* /*type*/) -> std::optional<T>;
                             std::tuple<U...>* /*type*/) -> std::optional<T>;
 
 
  private:
  private:
-  const Tree* tree_;
-  Lex::TokenizedBuffer* tokens_;
+  const TreeAndSubtrees* tree_;
+  const Lex::TokenizedBuffer* tokens_;
   ErrorBuilder* trace_;
   ErrorBuilder* trace_;
   NodeId node_id_;
   NodeId node_id_;
-  Tree::SiblingIterator it_;
-  Tree::SiblingIterator end_;
+  TreeAndSubtrees::SiblingIterator it_;
+  TreeAndSubtrees::SiblingIterator end_;
 };
 };
 }  // namespace
 }  // namespace
 
 
@@ -97,8 +100,8 @@ namespace {
 // };
 // };
 // ```
 // ```
 //
 //
-// Note that `Tree::SiblingIterator`s iterate in reverse order through the
-// children of a node.
+// Note that `TreeAndSubtrees::SiblingIterator`s iterate in reverse order
+// through the children of a node.
 //
 //
 // This class is only in this file.
 // This class is only in this file.
 template <typename T>
 template <typename T>
@@ -320,7 +323,7 @@ auto NodeExtractor::MatchesTokenKind(Lex::TokenKind expected_kind) const
   if (token_kind() != expected_kind) {
   if (token_kind() != expected_kind) {
     if (trace_) {
     if (trace_) {
       *trace_ << "Token " << expected_kind << " expected for "
       *trace_ << "Token " << expected_kind << " expected for "
-              << tree_->node_kind(node_id_) << ", found " << token_kind()
+              << tree_->tree().node_kind(node_id_) << ", found " << token_kind()
               << "\n";
               << "\n";
     }
     }
     return false;
     return false;
@@ -405,14 +408,15 @@ struct Extractable {
 }  // namespace
 }  // namespace
 
 
 template <typename T>
 template <typename T>
-auto Tree::TryExtractNodeFromChildren(
-    NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children,
+auto TreeAndSubtrees::TryExtractNodeFromChildren(
+    NodeId node_id,
+    llvm::iterator_range<TreeAndSubtrees::SiblingIterator> children,
     ErrorBuilder* trace) const -> std::optional<T> {
     ErrorBuilder* trace) const -> std::optional<T> {
   NodeExtractor extractor(this, tokens_, trace, node_id, children);
   NodeExtractor extractor(this, tokens_, trace, node_id, children);
   auto result = Extractable<T>::ExtractImpl(extractor);
   auto result = Extractable<T>::ExtractImpl(extractor);
   if (!extractor.at_end()) {
   if (!extractor.at_end()) {
     if (trace) {
     if (trace) {
-      *trace << "Error: " << node_kind(extractor.ExtractNode())
+      *trace << "Error: " << tree_->node_kind(extractor.ExtractNode())
              << " node left unconsumed.";
              << " node left unconsumed.";
     }
     }
     return std::nullopt;
     return std::nullopt;
@@ -421,16 +425,17 @@ auto Tree::TryExtractNodeFromChildren(
 }
 }
 
 
 // Manually instantiate Tree::TryExtractNodeFromChildren
 // Manually instantiate Tree::TryExtractNodeFromChildren
-#define CARBON_PARSE_NODE_KIND(KindName)                                    \
-  template auto Tree::TryExtractNodeFromChildren<KindName>(                 \
-      NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children, \
+#define CARBON_PARSE_NODE_KIND(KindName)                               \
+  template auto TreeAndSubtrees::TryExtractNodeFromChildren<KindName>( \
+      NodeId node_id,                                                  \
+      llvm::iterator_range<TreeAndSubtrees::SiblingIterator> children, \
       ErrorBuilder * trace) const -> std::optional<KindName>;
       ErrorBuilder * trace) const -> std::optional<KindName>;
 
 
 // Also instantiate for `File`, even though it isn't a parse node.
 // Also instantiate for `File`, even though it isn't a parse node.
 CARBON_PARSE_NODE_KIND(File)
 CARBON_PARSE_NODE_KIND(File)
 #include "toolchain/parse/node_kind.def"
 #include "toolchain/parse/node_kind.def"
 
 
-auto Tree::ExtractFile() const -> File {
+auto TreeAndSubtrees::ExtractFile() const -> File {
   return ExtractNodeFromChildren<File>(NodeId::Invalid, roots());
   return ExtractNodeFromChildren<File>(NodeId::Invalid, roots());
 }
 }
 
 

+ 2 - 4
toolchain/parse/handle_array_expr.cpp

@@ -24,14 +24,12 @@ auto HandleArrayExprSemi(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   auto semi = context.ConsumeIf(Lex::TokenKind::Semi);
   auto semi = context.ConsumeIf(Lex::TokenKind::Semi);
   if (!semi) {
   if (!semi) {
-    context.AddNode(NodeKind::ArrayExprSemi, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::ArrayExprSemi, *context.position(), true);
     CARBON_DIAGNOSTIC(ExpectedArraySemi, Error, "Expected `;` in array type.");
     CARBON_DIAGNOSTIC(ExpectedArraySemi, Error, "Expected `;` in array type.");
     context.emitter().Emit(*context.position(), ExpectedArraySemi);
     context.emitter().Emit(*context.position(), ExpectedArraySemi);
     state.has_error = true;
     state.has_error = true;
   } else {
   } else {
-    context.AddNode(NodeKind::ArrayExprSemi, *semi, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::ArrayExprSemi, *semi, state.has_error);
   }
   }
   context.PushState(state, State::ArrayExprFinish);
   context.PushState(state, State::ArrayExprFinish);
   if (!context.PositionIs(Lex::TokenKind::CloseSquareBracket)) {
   if (!context.PositionIs(Lex::TokenKind::CloseSquareBracket)) {

+ 3 - 5
toolchain/parse/handle_binding_pattern.cpp

@@ -76,7 +76,7 @@ static auto HandleBindingPatternFinish(Context& context, NodeKind node_kind)
     -> void {
     -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(node_kind, state.token, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, state.token, state.has_error);
 
 
   // Propagate errors to the parent state so that they can take different
   // Propagate errors to the parent state so that they can take different
   // actions on invalid patterns.
   // actions on invalid patterns.
@@ -96,8 +96,7 @@ auto HandleBindingPatternFinishAsRegular(Context& context) -> void {
 auto HandleBindingPatternAddr(Context& context) -> void {
 auto HandleBindingPatternAddr(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::Addr, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::Addr, state.token, state.has_error);
 
 
   // If an error was encountered, propagate it while adding a node.
   // If an error was encountered, propagate it while adding a node.
   if (state.has_error) {
   if (state.has_error) {
@@ -108,8 +107,7 @@ auto HandleBindingPatternAddr(Context& context) -> void {
 auto HandleBindingPatternTemplate(Context& context) -> void {
 auto HandleBindingPatternTemplate(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::Template, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::Template, state.token, state.has_error);
 
 
   // If an error was encountered, propagate it while adding a node.
   // If an error was encountered, propagate it while adding a node.
   if (state.has_error) {
   if (state.has_error) {

+ 2 - 4
toolchain/parse/handle_brace_expr.cpp

@@ -151,8 +151,7 @@ static auto HandleBraceExprParamFinish(Context& context, NodeKind node_kind,
                         /*has_error=*/true);
                         /*has_error=*/true);
     context.ReturnErrorOnState();
     context.ReturnErrorOnState();
   } else {
   } else {
-    context.AddNode(node_kind, state.token, state.subtree_start,
-                    /*has_error=*/false);
+    context.AddNode(node_kind, state.token, /*has_error=*/false);
   }
   }
 
 
   if (context.ConsumeListToken(
   if (context.ConsumeListToken(
@@ -183,8 +182,7 @@ static auto HandleBraceExprFinish(Context& context, NodeKind start_kind,
   auto state = context.PopState();
   auto state = context.PopState();
 
 
   context.ReplacePlaceholderNode(state.subtree_start, start_kind, state.token);
   context.ReplacePlaceholderNode(state.subtree_start, start_kind, state.token);
-  context.AddNode(end_kind, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(end_kind, context.Consume(), state.has_error);
 }
 }
 
 
 auto HandleBraceExprFinishAsType(Context& context) -> void {
 auto HandleBraceExprFinishAsType(Context& context) -> void {

+ 2 - 4
toolchain/parse/handle_call_expr.cpp

@@ -11,8 +11,7 @@ auto HandleCallExpr(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   context.PushState(state, State::CallExprFinish);
   context.PushState(state, State::CallExprFinish);
 
 
-  context.AddNode(NodeKind::CallExprStart, context.Consume(),
-                  state.subtree_start, state.has_error);
+  context.AddNode(NodeKind::CallExprStart, context.Consume(), state.has_error);
   if (!context.PositionIs(Lex::TokenKind::CloseParen)) {
   if (!context.PositionIs(Lex::TokenKind::CloseParen)) {
     context.PushState(State::CallExprParamFinish);
     context.PushState(State::CallExprParamFinish);
     context.PushState(State::Expr);
     context.PushState(State::Expr);
@@ -37,8 +36,7 @@ auto HandleCallExprParamFinish(Context& context) -> void {
 auto HandleCallExprFinish(Context& context) -> void {
 auto HandleCallExprFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::CallExpr, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::CallExpr, context.Consume(), state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 4 - 4
toolchain/parse/handle_choice.cpp

@@ -24,17 +24,17 @@ auto HandleChoiceDefinitionStart(Context& context) -> void {
     }
     }
 
 
     context.AddNode(NodeKind::ChoiceDefinitionStart, *context.position(),
     context.AddNode(NodeKind::ChoiceDefinitionStart, *context.position(),
-                    state.subtree_start, /*has_error=*/true);
+                    /*has_error=*/true);
 
 
     context.AddNode(NodeKind::ChoiceDefinition, *context.position(),
     context.AddNode(NodeKind::ChoiceDefinition, *context.position(),
-                    state.subtree_start, /*has_error=*/true);
+                    /*has_error=*/true);
 
 
     context.SkipPastLikelyEnd(*context.position());
     context.SkipPastLikelyEnd(*context.position());
     return;
     return;
   }
   }
 
 
   context.AddNode(NodeKind::ChoiceDefinitionStart, context.Consume(),
   context.AddNode(NodeKind::ChoiceDefinitionStart, context.Consume(),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 
 
   state.has_error = false;
   state.has_error = false;
   state.state = State::ChoiceDefinitionFinish;
   state.state = State::ChoiceDefinitionFinish;
@@ -94,6 +94,6 @@ auto HandleChoiceDefinitionFinish(Context& context) -> void {
 
 
   context.AddNode(NodeKind::ChoiceDefinition,
   context.AddNode(NodeKind::ChoiceDefinition,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_code_block.cpp

@@ -31,11 +31,9 @@ auto HandleCodeBlockFinish(Context& context) -> void {
 
 
   // If the block started with an open curly, this is a close curly.
   // If the block started with an open curly, this is a close curly.
   if (context.tokens().GetKind(state.token) == Lex::TokenKind::OpenCurlyBrace) {
   if (context.tokens().GetKind(state.token) == Lex::TokenKind::OpenCurlyBrace) {
-    context.AddNode(NodeKind::CodeBlock, context.Consume(), state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::CodeBlock, context.Consume(), state.has_error);
   } else {
   } else {
-    context.AddNode(NodeKind::CodeBlock, state.token, state.subtree_start,
-                    /*has_error=*/true);
+    context.AddNode(NodeKind::CodeBlock, state.token, /*has_error=*/true);
   }
   }
 }
 }
 
 

+ 2 - 4
toolchain/parse/handle_decl_definition.cpp

@@ -23,8 +23,7 @@ static auto HandleDeclOrDefinition(Context& context, NodeKind decl_kind,
 
 
   context.PushState(state, definition_finish_state);
   context.PushState(state, definition_finish_state);
   context.PushState(State::DeclScopeLoop);
   context.PushState(State::DeclScopeLoop);
-  context.AddNode(definition_start_kind, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(definition_start_kind, context.Consume(), state.has_error);
 }
 }
 
 
 auto HandleDeclOrDefinitionAsClass(Context& context) -> void {
 auto HandleDeclOrDefinitionAsClass(Context& context) -> void {
@@ -56,8 +55,7 @@ static auto HandleDeclDefinitionFinish(Context& context,
                                        NodeKind definition_kind) -> void {
                                        NodeKind definition_kind) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(definition_kind, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(definition_kind, context.Consume(), state.has_error);
 }
 }
 
 
 auto HandleDeclDefinitionFinishAsClass(Context& context) -> void {
 auto HandleDeclDefinitionFinishAsClass(Context& context) -> void {

+ 2 - 3
toolchain/parse/handle_decl_name_and_params.cpp

@@ -41,7 +41,7 @@ auto HandleDeclNameAndParams(Context& context) -> void {
     case Lex::TokenKind::Period:
     case Lex::TokenKind::Period:
       context.AddNode(NodeKind::NameQualifier,
       context.AddNode(NodeKind::NameQualifier,
                       context.ConsumeChecked(Lex::TokenKind::Period),
                       context.ConsumeChecked(Lex::TokenKind::Period),
-                      state.subtree_start, state.has_error);
+                      state.has_error);
       context.PushState(State::DeclNameAndParams);
       context.PushState(State::DeclNameAndParams);
       break;
       break;
 
 
@@ -83,8 +83,7 @@ auto HandleDeclNameAndParamsAfterParams(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
   if (auto period = context.ConsumeIf(Lex::TokenKind::Period)) {
   if (auto period = context.ConsumeIf(Lex::TokenKind::Period)) {
-    context.AddNode(NodeKind::NameQualifier, *period, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::NameQualifier, *period, state.has_error);
     context.PushState(State::DeclNameAndParams);
     context.PushState(State::DeclNameAndParams);
   }
   }
 }
 }

+ 2 - 4
toolchain/parse/handle_decl_scope_loop.cpp

@@ -18,8 +18,7 @@ static auto FinishAndSkipInvalidDecl(Context& context, int32_t subtree_start)
   context.ReplacePlaceholderNode(subtree_start, NodeKind::InvalidParseStart,
   context.ReplacePlaceholderNode(subtree_start, NodeKind::InvalidParseStart,
                                  cursor, /*has_error=*/true);
                                  cursor, /*has_error=*/true);
   context.AddNode(NodeKind::InvalidParseSubtree,
   context.AddNode(NodeKind::InvalidParseSubtree,
-                  context.SkipPastLikelyEnd(cursor), subtree_start,
-                  /*has_error=*/true);
+                  context.SkipPastLikelyEnd(cursor), /*has_error=*/true);
 }
 }
 
 
 // Prints a diagnostic and calls FinishAndSkipInvalidDecl.
 // Prints a diagnostic and calls FinishAndSkipInvalidDecl.
@@ -226,12 +225,11 @@ static auto TryHandleAsModifier(Context& context) -> bool {
       auto extern_token = context.Consume();
       auto extern_token = context.Consume();
       if (context.PositionIs(Lex::TokenKind::Library)) {
       if (context.PositionIs(Lex::TokenKind::Library)) {
         // `extern library <owning_library>` syntax.
         // `extern library <owning_library>` syntax.
-        auto subtree_start = context.tree().size();
         context.ParseLibrarySpecifier(/*accept_default=*/true);
         context.ParseLibrarySpecifier(/*accept_default=*/true);
         // TODO: Consider error recovery when a non-declaration token is next,
         // TODO: Consider error recovery when a non-declaration token is next,
         // like a typo of the library name.
         // like a typo of the library name.
         context.AddNode(NodeKind::ExternModifierWithLibrary, extern_token,
         context.AddNode(NodeKind::ExternModifierWithLibrary, extern_token,
-                        subtree_start, /*has_error=*/false);
+                        /*has_error=*/false);
       } else {
       } else {
         // `extern` syntax without a library.
         // `extern` syntax without a library.
         context.AddLeafNode(NodeKind::ExternModifier, extern_token);
         context.AddLeafNode(NodeKind::ExternModifier, extern_token);

+ 9 - 15
toolchain/parse/handle_expr.cpp

@@ -276,12 +276,12 @@ auto HandleExprLoop(Context& context) -> void {
       // node so that checking can insert control flow here.
       // node so that checking can insert control flow here.
       case Lex::TokenKind::And:
       case Lex::TokenKind::And:
         context.AddNode(NodeKind::ShortCircuitOperandAnd, state.token,
         context.AddNode(NodeKind::ShortCircuitOperandAnd, state.token,
-                        state.subtree_start, state.has_error);
+                        state.has_error);
         state.state = State::ExprLoopForShortCircuitOperatorAsAnd;
         state.state = State::ExprLoopForShortCircuitOperatorAsAnd;
         break;
         break;
       case Lex::TokenKind::Or:
       case Lex::TokenKind::Or:
         context.AddNode(NodeKind::ShortCircuitOperandOr, state.token,
         context.AddNode(NodeKind::ShortCircuitOperandOr, state.token,
-                        state.subtree_start, state.has_error);
+                        state.has_error);
         state.state = State::ExprLoopForShortCircuitOperatorAsOr;
         state.state = State::ExprLoopForShortCircuitOperatorAsOr;
         break;
         break;
 
 
@@ -307,8 +307,7 @@ auto HandleExprLoop(Context& context) -> void {
                        << operator_kind;
                        << operator_kind;
     }
     }
 
 
-    context.AddNode(node_kind, state.token, state.subtree_start,
-                    state.has_error);
+    context.AddNode(node_kind, state.token, state.has_error);
     state.has_error = false;
     state.has_error = false;
     context.PushState(state);
     context.PushState(state);
   }
   }
@@ -318,7 +317,7 @@ auto HandleExprLoop(Context& context) -> void {
 static auto HandleExprLoopForOperator(Context& context,
 static auto HandleExprLoopForOperator(Context& context,
                                       Context::StateStackEntry state,
                                       Context::StateStackEntry state,
                                       NodeKind node_kind) -> void {
                                       NodeKind node_kind) -> void {
-  context.AddNode(node_kind, state.token, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, state.token, state.has_error);
   state.has_error = false;
   state.has_error = false;
   context.PushState(state, State::ExprLoop);
   context.PushState(state, State::ExprLoop);
 }
 }
@@ -371,8 +370,7 @@ auto HandleExprLoopForShortCircuitOperatorAsOr(Context& context) -> void {
 auto HandleIfExprFinishCondition(Context& context) -> void {
 auto HandleIfExprFinishCondition(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::IfExprIf, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfExprIf, state.token, state.has_error);
 
 
   if (context.PositionIs(Lex::TokenKind::Then)) {
   if (context.PositionIs(Lex::TokenKind::Then)) {
     context.PushState(State::IfExprFinishThen);
     context.PushState(State::IfExprFinishThen);
@@ -397,8 +395,7 @@ auto HandleIfExprFinishCondition(Context& context) -> void {
 auto HandleIfExprFinishThen(Context& context) -> void {
 auto HandleIfExprFinishThen(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::IfExprThen, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfExprThen, state.token, state.has_error);
 
 
   if (context.PositionIs(Lex::TokenKind::Else)) {
   if (context.PositionIs(Lex::TokenKind::Else)) {
     context.PushState(State::IfExprFinishElse);
     context.PushState(State::IfExprFinishElse);
@@ -431,16 +428,14 @@ auto HandleIfExprFinishElse(Context& context) -> void {
 auto HandleIfExprFinish(Context& context) -> void {
 auto HandleIfExprFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::IfExprElse, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfExprElse, state.token, state.has_error);
 }
 }
 
 
 auto HandleExprStatementFinish(Context& context) -> void {
 auto HandleExprStatementFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
   if (auto semi = context.ConsumeIf(Lex::TokenKind::Semi)) {
   if (auto semi = context.ConsumeIf(Lex::TokenKind::Semi)) {
-    context.AddNode(NodeKind::ExprStatement, *semi, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::ExprStatement, *semi, state.has_error);
     return;
     return;
   }
   }
 
 
@@ -451,8 +446,7 @@ auto HandleExprStatementFinish(Context& context) -> void {
   }
   }
 
 
   context.AddNode(NodeKind::ExprStatement,
   context.AddNode(NodeKind::ExprStatement,
-                  context.SkipPastLikelyEnd(state.token), state.subtree_start,
-                  /*has_error=*/true);
+                  context.SkipPastLikelyEnd(state.token), /*has_error=*/true);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 6 - 9
toolchain/parse/handle_function.cpp

@@ -31,8 +31,7 @@ auto HandleFunctionAfterParams(Context& context) -> void {
 auto HandleFunctionReturnTypeFinish(Context& context) -> void {
 auto HandleFunctionReturnTypeFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::ReturnType, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ReturnType, state.token, state.has_error);
 }
 }
 
 
 auto HandleFunctionSignatureFinish(Context& context) -> void {
 auto HandleFunctionSignatureFinish(Context& context) -> void {
@@ -41,12 +40,11 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
   switch (context.PositionKind()) {
   switch (context.PositionKind()) {
     case Lex::TokenKind::Semi: {
     case Lex::TokenKind::Semi: {
       context.AddNode(NodeKind::FunctionDecl, context.Consume(),
       context.AddNode(NodeKind::FunctionDecl, context.Consume(),
-                      state.subtree_start, state.has_error);
+                      state.has_error);
       break;
       break;
     }
     }
     case Lex::TokenKind::OpenCurlyBrace: {
     case Lex::TokenKind::OpenCurlyBrace: {
-      context.AddFunctionDefinitionStart(context.Consume(), state.subtree_start,
-                                         state.has_error);
+      context.AddFunctionDefinitionStart(context.Consume(), state.has_error);
       // Any error is recorded on the FunctionDefinitionStart.
       // Any error is recorded on the FunctionDefinitionStart.
       state.has_error = false;
       state.has_error = false;
       context.PushState(state, State::FunctionDefinitionFinish);
       context.PushState(state, State::FunctionDefinitionFinish);
@@ -55,7 +53,7 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
     }
     }
     case Lex::TokenKind::Equal: {
     case Lex::TokenKind::Equal: {
       context.AddNode(NodeKind::BuiltinFunctionDefinitionStart,
       context.AddNode(NodeKind::BuiltinFunctionDefinitionStart,
-                      context.Consume(), state.subtree_start, state.has_error);
+                      context.Consume(), state.has_error);
       if (!context.ConsumeAndAddLeafNodeIf(Lex::TokenKind::StringLiteral,
       if (!context.ConsumeAndAddLeafNodeIf(Lex::TokenKind::StringLiteral,
                                            NodeKind::BuiltinName)) {
                                            NodeKind::BuiltinName)) {
         CARBON_DIAGNOSTIC(ExpectedBuiltinName, Error,
         CARBON_DIAGNOSTIC(ExpectedBuiltinName, Error,
@@ -73,7 +71,7 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
                                      /*skip_past_likely_end=*/true);
                                      /*skip_past_likely_end=*/true);
       } else {
       } else {
         context.AddNode(NodeKind::BuiltinFunctionDefinition, *semi,
         context.AddNode(NodeKind::BuiltinFunctionDefinition, *semi,
-                        state.subtree_start, state.has_error);
+                        state.has_error);
       }
       }
       break;
       break;
     }
     }
@@ -94,8 +92,7 @@ auto HandleFunctionSignatureFinish(Context& context) -> void {
 
 
 auto HandleFunctionDefinitionFinish(Context& context) -> void {
 auto HandleFunctionDefinitionFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
-  context.AddFunctionDefinition(context.Consume(), state.subtree_start,
-                                state.has_error);
+  context.AddFunctionDefinition(context.Consume(), state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_impl.cpp

@@ -54,8 +54,7 @@ auto HandleImplAfterForall(Context& context) -> void {
   if (state.has_error) {
   if (state.has_error) {
     context.ReturnErrorOnState();
     context.ReturnErrorOnState();
   }
   }
-  context.AddNode(NodeKind::ImplForall, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ImplForall, state.token, state.has_error);
   // One of:
   // One of:
   //   as <expression> ...
   //   as <expression> ...
   //   <expression> as <expression>...
   //   <expression> as <expression>...
@@ -65,8 +64,7 @@ auto HandleImplAfterForall(Context& context) -> void {
 auto HandleImplBeforeAs(Context& context) -> void {
 auto HandleImplBeforeAs(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   if (auto as = context.ConsumeIf(Lex::TokenKind::As)) {
   if (auto as = context.ConsumeIf(Lex::TokenKind::As)) {
-    context.AddNode(NodeKind::TypeImplAs, *as, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::TypeImplAs, *as, state.has_error);
     context.PushState(State::Expr);
     context.PushState(State::Expr);
   } else {
   } else {
     if (!state.has_error) {
     if (!state.has_error) {

+ 2 - 2
toolchain/parse/handle_import_and_package.cpp

@@ -16,7 +16,7 @@ namespace Carbon::Parse {
 static auto OnParseError(Context& context, Context::StateStackEntry state,
 static auto OnParseError(Context& context, Context::StateStackEntry state,
                          NodeKind declaration) -> void {
                          NodeKind declaration) -> void {
   return context.AddNode(declaration, context.SkipPastLikelyEnd(state.token),
   return context.AddNode(declaration, context.SkipPastLikelyEnd(state.token),
-                         state.subtree_start, /*has_error=*/true);
+                         /*has_error=*/true);
 }
 }
 
 
 // Determines whether the specified modifier appears within the introducer of
 // Determines whether the specified modifier appears within the introducer of
@@ -109,7 +109,7 @@ static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
       context.set_packaging_decl(names, is_impl);
       context.set_packaging_decl(names, is_impl);
     }
     }
 
 
-    context.AddNode(declaration, *semi, state.subtree_start, state.has_error);
+    context.AddNode(declaration, *semi, state.has_error);
   } else {
   } else {
     context.DiagnoseExpectedDeclSemi(context.tokens().GetKind(state.token));
     context.DiagnoseExpectedDeclSemi(context.tokens().GetKind(state.token));
     on_parse_error();
     on_parse_error();

+ 1 - 1
toolchain/parse/handle_index_expr.cpp

@@ -13,7 +13,7 @@ auto HandleIndexExpr(Context& context) -> void {
   context.PushState(state, State::IndexExprFinish);
   context.PushState(state, State::IndexExprFinish);
   context.AddNode(NodeKind::IndexExprStart,
   context.AddNode(NodeKind::IndexExprStart,
                   context.ConsumeChecked(Lex::TokenKind::OpenSquareBracket),
                   context.ConsumeChecked(Lex::TokenKind::OpenSquareBracket),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
   context.PushState(State::Expr);
   context.PushState(State::Expr);
 }
 }
 
 

+ 1 - 2
toolchain/parse/handle_let.cpp

@@ -45,8 +45,7 @@ auto HandleLetFinish(Context& context) -> void {
     state.has_error = true;
     state.has_error = true;
     end_token = context.SkipPastLikelyEnd(state.token);
     end_token = context.SkipPastLikelyEnd(state.token);
   }
   }
-  context.AddNode(NodeKind::LetDecl, end_token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::LetDecl, end_token, state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 18 - 28
toolchain/parse/handle_match.cpp

@@ -20,8 +20,8 @@ static auto HandleStatementsBlockStart(Context& context, State finish,
     }
     }
 
 
     context.AddLeafNode(equal_greater, *context.position(), true);
     context.AddLeafNode(equal_greater, *context.position(), true);
-    context.AddNode(starter, *context.position(), state.subtree_start, true);
-    context.AddNode(complete, *context.position(), state.subtree_start, true);
+    context.AddNode(starter, *context.position(), true);
+    context.AddNode(complete, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     context.SkipPastLikelyEnd(*context.position());
     return;
     return;
   }
   }
@@ -35,14 +35,13 @@ static auto HandleStatementsBlockStart(Context& context, State finish,
       context.emitter().Emit(*context.position(), ExpectedMatchCaseBlock);
       context.emitter().Emit(*context.position(), ExpectedMatchCaseBlock);
     }
     }
 
 
-    context.AddNode(starter, *context.position(), state.subtree_start, true);
-    context.AddNode(complete, *context.position(), state.subtree_start, true);
+    context.AddNode(starter, *context.position(), true);
+    context.AddNode(complete, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     context.SkipPastLikelyEnd(*context.position());
     return;
     return;
   }
   }
 
 
-  context.AddNode(starter, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(starter, context.Consume(), state.has_error);
   context.PushState(state, finish);
   context.PushState(state, finish);
   context.PushState(State::StatementScopeLoop);
   context.PushState(State::StatementScopeLoop);
 }
 }
@@ -77,16 +76,14 @@ auto HandleMatchConditionFinish(Context& context) -> void {
       context.emitter().Emit(*context.position(), ExpectedMatchCasesBlock);
       context.emitter().Emit(*context.position(), ExpectedMatchCasesBlock);
     }
     }
 
 
-    context.AddNode(NodeKind::MatchStatementStart, *context.position(),
-                    state.subtree_start, true);
-    context.AddNode(NodeKind::MatchStatement, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::MatchStatementStart, *context.position(), true);
+    context.AddNode(NodeKind::MatchStatement, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     context.SkipPastLikelyEnd(*context.position());
     return;
     return;
   }
   }
 
 
   context.AddNode(NodeKind::MatchStatementStart, context.Consume(),
   context.AddNode(NodeKind::MatchStatementStart, context.Consume(),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 
 
   state.has_error = false;
   state.has_error = false;
   if (context.PositionIs(Lex::TokenKind::CloseCurlyBrace)) {
   if (context.PositionIs(Lex::TokenKind::CloseCurlyBrace)) {
@@ -145,10 +142,8 @@ auto HandleMatchCaseIntroducer(Context& context) -> void {
 auto HandleMatchCaseAfterPattern(Context& context) -> void {
 auto HandleMatchCaseAfterPattern(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   if (state.has_error) {
   if (state.has_error) {
-    context.AddNode(NodeKind::MatchCaseStart, *context.position(),
-                    state.subtree_start, true);
-    context.AddNode(NodeKind::MatchCase, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::MatchCaseStart, *context.position(), true);
+    context.AddNode(NodeKind::MatchCase, *context.position(), true);
     context.SkipPastLikelyEnd(*context.position());
     context.SkipPastLikelyEnd(*context.position());
     return;
     return;
   }
   }
@@ -166,13 +161,10 @@ auto HandleMatchCaseAfterPattern(Context& context) -> void {
                           true);
                           true);
       context.AddLeafNode(NodeKind::InvalidParse, *context.position(), true);
       context.AddLeafNode(NodeKind::InvalidParse, *context.position(), true);
       state = context.PopState();
       state = context.PopState();
-      context.AddNode(NodeKind::MatchCaseGuard, *context.position(),
-                      state.subtree_start, true);
+      context.AddNode(NodeKind::MatchCaseGuard, *context.position(), true);
       state = context.PopState();
       state = context.PopState();
-      context.AddNode(NodeKind::MatchCaseStart, *context.position(),
-                      state.subtree_start, true);
-      context.AddNode(NodeKind::MatchCase, *context.position(),
-                      state.subtree_start, true);
+      context.AddNode(NodeKind::MatchCaseStart, *context.position(), true);
+      context.AddNode(NodeKind::MatchCase, *context.position(), true);
       context.SkipPastLikelyEnd(*context.position());
       context.SkipPastLikelyEnd(*context.position());
       return;
       return;
     }
     }
@@ -184,11 +176,9 @@ auto HandleMatchCaseGuardFinish(Context& context) -> void {
 
 
   auto close_paren = context.ConsumeIf(Lex::TokenKind::CloseParen);
   auto close_paren = context.ConsumeIf(Lex::TokenKind::CloseParen);
   if (close_paren) {
   if (close_paren) {
-    context.AddNode(NodeKind::MatchCaseGuard, *close_paren, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::MatchCaseGuard, *close_paren, state.has_error);
   } else {
   } else {
-    context.AddNode(NodeKind::MatchCaseGuard, *context.position(),
-                    state.subtree_start, true);
+    context.AddNode(NodeKind::MatchCaseGuard, *context.position(), true);
     context.ReturnErrorOnState();
     context.ReturnErrorOnState();
     context.SkipPastLikelyEnd(*context.position());
     context.SkipPastLikelyEnd(*context.position());
     return;
     return;
@@ -205,7 +195,7 @@ auto HandleMatchCaseFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   context.AddNode(NodeKind::MatchCase,
   context.AddNode(NodeKind::MatchCase,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }
 
 
 auto HandleMatchDefaultIntroducer(Context& context) -> void {
 auto HandleMatchDefaultIntroducer(Context& context) -> void {
@@ -220,14 +210,14 @@ auto HandleMatchDefaultFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   context.AddNode(NodeKind::MatchDefault,
   context.AddNode(NodeKind::MatchDefault,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }
 
 
 auto HandleMatchStatementFinish(Context& context) -> void {
 auto HandleMatchStatementFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   context.AddNode(NodeKind::MatchStatement,
   context.AddNode(NodeKind::MatchStatement,
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
                   context.ConsumeChecked(Lex::TokenKind::CloseCurlyBrace),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_paren_expr.cpp

@@ -21,8 +21,7 @@ auto HandleOnlyParenExpr(Context& context) -> void {
 
 
 static auto FinishParenExpr(Context& context,
 static auto FinishParenExpr(Context& context,
                             const Context::StateStackEntry& state) -> void {
                             const Context::StateStackEntry& state) -> void {
-  context.AddNode(NodeKind::ParenExpr, context.Consume(), state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ParenExpr, context.Consume(), state.has_error);
 }
 }
 
 
 auto HandleOnlyParenExprFinish(Context& context) -> void {
 auto HandleOnlyParenExprFinish(Context& context) -> void {
@@ -108,8 +107,7 @@ auto HandleTupleLiteralFinish(Context& context) -> void {
 
 
   context.ReplacePlaceholderNode(state.subtree_start,
   context.ReplacePlaceholderNode(state.subtree_start,
                                  NodeKind::TupleLiteralStart, state.token);
                                  NodeKind::TupleLiteralStart, state.token);
-  context.AddNode(NodeKind::TupleLiteral, context.Consume(),
-                  state.subtree_start, state.has_error);
+  context.AddNode(NodeKind::TupleLiteral, context.Consume(), state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 1 - 1
toolchain/parse/handle_pattern_list.cpp

@@ -88,7 +88,7 @@ static auto HandlePatternListFinish(Context& context, NodeKind node_kind,
   auto state = context.PopState();
   auto state = context.PopState();
 
 
   context.AddNode(node_kind, context.ConsumeChecked(token_kind),
   context.AddNode(node_kind, context.ConsumeChecked(token_kind),
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }
 
 
 auto HandlePatternListFinishAsImplicit(Context& context) -> void {
 auto HandlePatternListFinishAsImplicit(Context& context) -> void {

+ 3 - 4
toolchain/parse/handle_period.cpp

@@ -50,7 +50,7 @@ static auto HandlePeriodOrArrow(Context& context, NodeKind node_kind,
     }
     }
   }
   }
 
 
-  context.AddNode(node_kind, dot, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, dot, state.has_error);
 }
 }
 
 
 auto HandlePeriodAsExpr(Context& context) -> void {
 auto HandlePeriodAsExpr(Context& context) -> void {
@@ -72,14 +72,13 @@ auto HandleArrowExpr(Context& context) -> void {
 
 
 auto HandleCompoundMemberAccess(Context& context) -> void {
 auto HandleCompoundMemberAccess(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
-  context.AddNode(NodeKind::MemberAccessExpr, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::MemberAccessExpr, state.token, state.has_error);
 }
 }
 
 
 auto HandleCompoundPointerMemberAccess(Context& context) -> void {
 auto HandleCompoundPointerMemberAccess(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
   context.AddNode(NodeKind::PointerMemberAccessExpr, state.token,
   context.AddNode(NodeKind::PointerMemberAccessExpr, state.token,
-                  state.subtree_start, state.has_error);
+                  state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 5 - 9
toolchain/parse/handle_statement.cpp

@@ -85,7 +85,7 @@ static auto HandleStatementKeywordFinish(Context& context, NodeKind node_kind)
     // Recover to the next semicolon if possible.
     // Recover to the next semicolon if possible.
     semi = context.SkipPastLikelyEnd(state.token);
     semi = context.SkipPastLikelyEnd(state.token);
   }
   }
-  context.AddNode(node_kind, *semi, state.subtree_start, state.has_error);
+  context.AddNode(node_kind, *semi, state.has_error);
 }
 }
 
 
 auto HandleStatementBreakFinish(Context& context) -> void {
 auto HandleStatementBreakFinish(Context& context) -> void {
@@ -141,8 +141,7 @@ auto HandleStatementForHeaderFinish(Context& context) -> void {
 auto HandleStatementForFinish(Context& context) -> void {
 auto HandleStatementForFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::ForStatement, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ForStatement, state.token, state.has_error);
 }
 }
 
 
 auto HandleStatementIf(Context& context) -> void {
 auto HandleStatementIf(Context& context) -> void {
@@ -170,15 +169,13 @@ auto HandleStatementIfThenBlockFinish(Context& context) -> void {
                           ? State::StatementIf
                           ? State::StatementIf
                           : State::CodeBlock);
                           : State::CodeBlock);
   } else {
   } else {
-    context.AddNode(NodeKind::IfStatement, state.token, state.subtree_start,
-                    state.has_error);
+    context.AddNode(NodeKind::IfStatement, state.token, state.has_error);
   }
   }
 }
 }
 
 
 auto HandleStatementIfElseBlockFinish(Context& context) -> void {
 auto HandleStatementIfElseBlockFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
-  context.AddNode(NodeKind::IfStatement, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::IfStatement, state.token, state.has_error);
 }
 }
 
 
 auto HandleStatementReturn(Context& context) -> void {
 auto HandleStatementReturn(Context& context) -> void {
@@ -234,8 +231,7 @@ auto HandleStatementWhileConditionFinish(Context& context) -> void {
 auto HandleStatementWhileBlockFinish(Context& context) -> void {
 auto HandleStatementWhileBlockFinish(Context& context) -> void {
   auto state = context.PopState();
   auto state = context.PopState();
 
 
-  context.AddNode(NodeKind::WhileStatement, state.token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::WhileStatement, state.token, state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 2 - 4
toolchain/parse/handle_var.cpp

@@ -85,8 +85,7 @@ auto HandleVarFinishAsDecl(Context& context) -> void {
     state.has_error = true;
     state.has_error = true;
     end_token = context.SkipPastLikelyEnd(state.token);
     end_token = context.SkipPastLikelyEnd(state.token);
   }
   }
-  context.AddNode(NodeKind::VariableDecl, end_token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::VariableDecl, end_token, state.has_error);
 }
 }
 
 
 auto HandleVarFinishAsFor(Context& context) -> void {
 auto HandleVarFinishAsFor(Context& context) -> void {
@@ -108,8 +107,7 @@ auto HandleVarFinishAsFor(Context& context) -> void {
     state.has_error = true;
     state.has_error = true;
   }
   }
 
 
-  context.AddNode(NodeKind::ForIn, end_token, state.subtree_start,
-                  state.has_error);
+  context.AddNode(NodeKind::ForIn, end_token, state.has_error);
 }
 }
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 22 - 245
toolchain/parse/tree.cpp

@@ -10,6 +10,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/node_kind.h"
 #include "toolchain/parse/node_kind.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/parse/typed_nodes.h"
 #include "toolchain/parse/typed_nodes.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
@@ -20,28 +21,6 @@ auto Tree::postorder() const -> llvm::iterator_range<PostorderIterator> {
       PostorderIterator(NodeId(node_impls_.size())));
       PostorderIterator(NodeId(node_impls_.size())));
 }
 }
 
 
-auto Tree::postorder(NodeId n) const
-    -> llvm::iterator_range<PostorderIterator> {
-  // The postorder ends after this node, the root, and begins at the start of
-  // its subtree.
-  int start_index = n.index - node_impls_[n.index].subtree_size + 1;
-  return PostorderIterator::MakeRange(NodeId(start_index), n);
-}
-
-auto Tree::children(NodeId n) const -> llvm::iterator_range<SiblingIterator> {
-  CARBON_CHECK(n.is_valid());
-  int end_index = n.index - node_impls_[n.index].subtree_size;
-  return llvm::iterator_range<SiblingIterator>(
-      SiblingIterator(*this, NodeId(n.index - 1)),
-      SiblingIterator(*this, NodeId(end_index)));
-}
-
-auto Tree::roots() const -> llvm::iterator_range<SiblingIterator> {
-  return llvm::iterator_range<SiblingIterator>(
-      SiblingIterator(*this, NodeId(static_cast<int>(node_impls_.size()) - 1)),
-      SiblingIterator(*this, NodeId(-1)));
-}
-
 auto Tree::node_has_error(NodeId n) const -> bool {
 auto Tree::node_has_error(NodeId n) const -> bool {
   CARBON_CHECK(n.is_valid());
   CARBON_CHECK(n.is_valid());
   return node_impls_[n.index].has_error;
   return node_impls_[n.index].has_error;
@@ -57,245 +36,47 @@ auto Tree::node_token(NodeId n) const -> Lex::TokenIndex {
   return node_impls_[n.index].token;
   return node_impls_[n.index].token;
 }
 }
 
 
-auto Tree::node_subtree_size(NodeId n) const -> int32_t {
-  CARBON_CHECK(n.is_valid());
-  return node_impls_[n.index].subtree_size;
-}
-
-auto Tree::PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
-                     bool preorder) const -> bool {
-  const auto& n_impl = node_impls_[n.index];
-  output.indent(2 * (depth + 2));
-  output << "{";
-  // If children are being added, include node_index in order to disambiguate
-  // nodes.
-  if (preorder) {
-    output << "node_index: " << n << ", ";
-  }
-  output << "kind: '" << n_impl.kind << "', text: '"
-         << tokens_->GetTokenText(n_impl.token) << "'";
-
-  if (n_impl.has_error) {
-    output << ", has_error: yes";
-  }
-
-  if (n_impl.subtree_size > 1) {
-    output << ", subtree_size: " << n_impl.subtree_size;
-    if (preorder) {
-      output << ", children: [\n";
-      return true;
-    }
-  }
-  output << "}";
-  return false;
-}
-
 auto Tree::Print(llvm::raw_ostream& output) const -> void {
 auto Tree::Print(llvm::raw_ostream& output) const -> void {
-  output << "- filename: " << tokens_->source().filename() << "\n"
-         << "  parse_tree: [\n";
-
-  // Walk the tree just to calculate depths for each node.
-  llvm::SmallVector<int> indents;
-  indents.append(size(), 0);
-
-  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
-  for (NodeId n : roots()) {
-    node_stack.push_back({n, 0});
-  }
-
-  while (!node_stack.empty()) {
-    NodeId n = NodeId::Invalid;
-    int depth;
-    std::tie(n, depth) = node_stack.pop_back_val();
-    for (NodeId sibling_n : children(n)) {
-      indents[sibling_n.index] = depth + 1;
-      node_stack.push_back({sibling_n, depth + 1});
-    }
-  }
-
-  for (NodeId n : postorder()) {
-    PrintNode(output, n, indents[n.index], /*preorder=*/false);
-    output << ",\n";
-  }
-  output << "  ]\n";
-}
-
-auto Tree::Print(llvm::raw_ostream& output, bool preorder) const -> void {
-  if (!preorder) {
-    Print(output);
-    return;
-  }
-
-  output << "- filename: " << tokens_->source().filename() << "\n"
-         << "  parse_tree: [\n";
-
-  // The parse tree is stored in postorder. The preorder can be constructed
-  // by reversing the order of each level of siblings within an RPO. The
-  // sibling iterators are directly built around RPO and so can be used with a
-  // stack to produce preorder.
-
-  // The roots, like siblings, are in RPO (so reversed), but we add them in
-  // order here because we'll pop off the stack effectively reversing then.
-  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
-  for (NodeId n : roots()) {
-    node_stack.push_back({n, 0});
-  }
-
-  while (!node_stack.empty()) {
-    NodeId n = NodeId::Invalid;
-    int depth;
-    std::tie(n, depth) = node_stack.pop_back_val();
-
-    if (PrintNode(output, n, depth, /*preorder=*/true)) {
-      // Has children, so we descend. We append the children in order here as
-      // well because they will get reversed when popped off the stack.
-      for (NodeId sibling_n : children(n)) {
-        node_stack.push_back({sibling_n, depth + 1});
-      }
-      continue;
-    }
-
-    int next_depth = node_stack.empty() ? 0 : node_stack.back().second;
-    CARBON_CHECK(next_depth <= depth) << "Cannot have the next depth increase!";
-    for (int close_children_count : llvm::seq(0, depth - next_depth)) {
-      (void)close_children_count;
-      output << "]}";
-    }
-
-    // We always end with a comma and a new line as we'll move to the next
-    // node at whatever the current level ends up being.
-    output << "  ,\n";
-  }
-  output << "  ]\n";
-}
-
-auto Tree::CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
-    -> void {
-  mem_usage.Add(MemUsage::ConcatLabel(label, "node_impls_"), node_impls_);
-  mem_usage.Add(MemUsage::ConcatLabel(label, "imports_"), imports_);
-}
-
-auto Tree::VerifyExtract(NodeId node_id, NodeKind kind,
-                         ErrorBuilder* trace) const -> bool {
-  switch (kind) {
-#define CARBON_PARSE_NODE_KIND(Name) \
-  case NodeKind::Name:               \
-    return VerifyExtractAs<Name>(node_id, trace).has_value();
-#include "toolchain/parse/node_kind.def"
-  }
+  TreeAndSubtrees(*tokens_, *this).Print(output);
 }
 }
 
 
 auto Tree::Verify() const -> ErrorOr<Success> {
 auto Tree::Verify() const -> ErrorOr<Success> {
   llvm::SmallVector<NodeId> nodes;
   llvm::SmallVector<NodeId> nodes;
   // Traverse the tree in postorder.
   // Traverse the tree in postorder.
   for (NodeId n : postorder()) {
   for (NodeId n : postorder()) {
-    const auto& n_impl = node_impls_[n.index];
-
-    if (n_impl.has_error && !has_errors_) {
+    if (node_has_error(n) && !has_errors()) {
       return Error(llvm::formatv(
       return Error(llvm::formatv(
-          "NodeId #{0} has errors, but the tree is not marked as having any.",
-          n.index));
+          "Node {0} has errors, but the tree is not marked as having any.", n));
     }
     }
 
 
-    if (n_impl.kind == NodeKind::Placeholder) {
+    if (node_kind(n) == NodeKind::Placeholder) {
       return Error(llvm::formatv(
       return Error(llvm::formatv(
-          "Node #{0} is a placeholder node that wasn't replaced.", n.index));
-    }
-    // Should extract successfully if node not marked as having an error.
-    // Without this code, a 10 mloc test case of lex & parse takes
-    // 4.129 s ± 0.041 s. With this additional verification, it takes
-    // 5.768 s ± 0.036 s.
-    if (!n_impl.has_error && !VerifyExtract(n, n_impl.kind, nullptr)) {
-      ErrorBuilder trace;
-      trace << llvm::formatv(
-          "NodeId #{0} couldn't be extracted as a {1}. Trace:\n", n,
-          n_impl.kind);
-      VerifyExtract(n, n_impl.kind, &trace);
-      return trace;
+          "Node {0} is a placeholder node that wasn't replaced.", n));
     }
     }
-
-    int subtree_size = 1;
-    if (n_impl.kind.has_bracket()) {
-      int child_count = 0;
-      while (true) {
-        if (nodes.empty()) {
-          return Error(
-              llvm::formatv("NodeId #{0} is a {1} with bracket {2}, but didn't "
-                            "find the bracket.",
-                            n, n_impl.kind, n_impl.kind.bracket()));
-        }
-        auto child_impl = node_impls_[nodes.pop_back_val().index];
-        subtree_size += child_impl.subtree_size;
-        ++child_count;
-        if (n_impl.kind.bracket() == child_impl.kind) {
-          // If there's a bracketing node and a child count, verify the child
-          // count too.
-          if (n_impl.kind.has_child_count() &&
-              child_count != n_impl.kind.child_count()) {
-            return Error(llvm::formatv(
-                "NodeId #{0} is a {1} with child_count {2}, but encountered "
-                "{3} nodes before we reached the bracketing node.",
-                n, n_impl.kind, n_impl.kind.child_count(), child_count));
-          }
-          break;
-        }
-      }
-    } else {
-      for (int i : llvm::seq(n_impl.kind.child_count())) {
-        if (nodes.empty()) {
-          return Error(llvm::formatv(
-              "NodeId #{0} is a {1} with child_count {2}, but only had {3} "
-              "nodes to consume.",
-              n, n_impl.kind, n_impl.kind.child_count(), i));
-        }
-        auto child_impl = node_impls_[nodes.pop_back_val().index];
-        subtree_size += child_impl.subtree_size;
-      }
-    }
-    if (n_impl.subtree_size != subtree_size) {
-      return Error(llvm::formatv(
-          "NodeId #{0} is a {1} with subtree_size of {2}, but calculated {3}.",
-          n, n_impl.kind, n_impl.subtree_size, subtree_size));
-    }
-    nodes.push_back(n);
   }
   }
 
 
-  // Remaining nodes should all be roots in the tree; make sure they line up.
-  CARBON_CHECK(nodes.back().index ==
-               static_cast<int32_t>(node_impls_.size()) - 1)
-      << nodes.back() << " " << node_impls_.size() - 1;
-  int prev_index = -1;
-  for (const auto& n : nodes) {
-    const auto& n_impl = node_impls_[n.index];
-
-    if (n.index - n_impl.subtree_size != prev_index) {
-      return Error(
-          llvm::formatv("NodeId #{0} is a root {1} with subtree_size {2}, but "
-                        "previous root was at #{3}.",
-                        n, n_impl.kind, n_impl.subtree_size, prev_index));
-    }
-    prev_index = n.index;
+  if (!has_errors() &&
+      static_cast<int32_t>(size()) != tokens_->expected_parse_tree_size()) {
+    return Error(llvm::formatv(
+        "Tree has {0} nodes and no errors, but "
+        "Lex::TokenizedBuffer expected {1} nodes for {2} tokens.",
+        size(), tokens_->expected_parse_tree_size(), tokens_->size()));
   }
   }
 
 
-  // Validate the roots, ensures Tree::ExtractFile() doesn't CHECK-fail.
-  if (!TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), nullptr)) {
-    ErrorBuilder trace;
-    trace << "Roots of tree couldn't be extracted as a `File`. Trace:\n";
-    TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), &trace);
-    return trace;
-  }
+#ifndef NDEBUG
+  TreeAndSubtrees subtrees(*tokens_, *this);
+  CARBON_RETURN_IF_ERROR(subtrees.Verify());
+#endif  // NDEBUG
 
 
-  if (!has_errors_ && static_cast<int32_t>(node_impls_.size()) !=
-                          tokens_->expected_parse_tree_size()) {
-    return Error(
-        llvm::formatv("Tree has {0} nodes and no errors, but "
-                      "Lex::TokenizedBuffer expected {1} nodes for {2} tokens.",
-                      node_impls_.size(), tokens_->expected_parse_tree_size(),
-                      tokens_->size()));
-  }
   return Success();
   return Success();
 }
 }
 
 
+auto Tree::CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
+    -> void {
+  mem_usage.Add(MemUsage::ConcatLabel(label, "node_impls_"), node_impls_);
+  mem_usage.Add(MemUsage::ConcatLabel(label, "imports_"), imports_);
+}
+
 auto Tree::PostorderIterator::MakeRange(NodeId begin, NodeId end)
 auto Tree::PostorderIterator::MakeRange(NodeId begin, NodeId end)
     -> llvm::iterator_range<PostorderIterator> {
     -> llvm::iterator_range<PostorderIterator> {
   CARBON_CHECK(begin.is_valid() && end.is_valid());
   CARBON_CHECK(begin.is_valid() && end.is_valid());
@@ -307,8 +88,4 @@ auto Tree::PostorderIterator::Print(llvm::raw_ostream& output) const -> void {
   output << node_;
   output << node_;
 }
 }
 
 
-auto Tree::SiblingIterator::Print(llvm::raw_ostream& output) const -> void {
-  output << node_;
-}
-
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 7 - 246
toolchain/parse/tree.h

@@ -78,7 +78,6 @@ struct File;
 class Tree : public Printable<Tree> {
 class Tree : public Printable<Tree> {
  public:
  public:
   class PostorderIterator;
   class PostorderIterator;
-  class SiblingIterator;
 
 
   // Names in packaging, whether the file's packaging or an import. Links back
   // Names in packaging, whether the file's packaging or an import. Links back
   // to the node for diagnostics.
   // to the node for diagnostics.
@@ -114,20 +113,6 @@ class Tree : public Printable<Tree> {
   // postorder.
   // postorder.
   auto postorder() const -> llvm::iterator_range<PostorderIterator>;
   auto postorder() const -> llvm::iterator_range<PostorderIterator>;
 
 
-  // Returns an iterable range over the parse tree node and all of its
-  // descendants in depth-first postorder.
-  auto postorder(NodeId n) const -> llvm::iterator_range<PostorderIterator>;
-
-  // Returns an iterable range over the direct children of a node in the parse
-  // tree. This is a forward range, but is constant time to increment. The order
-  // of children is the same as would be found in a reverse postorder traversal.
-  auto children(NodeId n) const -> llvm::iterator_range<SiblingIterator>;
-
-  // Returns an iterable range over the roots of the parse tree. This is a
-  // forward range, but is constant time to increment. The order of roots is the
-  // same as would be found in a reverse postorder traversal.
-  auto roots() const -> llvm::iterator_range<SiblingIterator>;
-
   // Tests whether a particular node contains an error and may not match the
   // Tests whether a particular node contains an error and may not match the
   // full expected structure of the grammar.
   // full expected structure of the grammar.
   auto node_has_error(NodeId n) const -> bool;
   auto node_has_error(NodeId n) const -> bool;
@@ -183,96 +168,19 @@ class Tree : public Printable<Tree> {
     return deferred_definitions_;
     return deferred_definitions_;
   }
   }
 
 
-  // See the other Print comments.
+  // Builds TreeAndSubtrees to print the tree.
   auto Print(llvm::raw_ostream& output) const -> void;
   auto Print(llvm::raw_ostream& output) const -> void;
 
 
-  // Prints a description of the parse tree to the provided `raw_ostream`.
-  //
-  // The tree may be printed in either preorder or postorder. Output represents
-  // each node as a YAML record; in preorder, children are nested.
-  //
-  // In both, a node is formatted as:
-  //   ```
-  //   {kind: 'foo', text: '...'}
-  //   ```
-  //
-  // The top level is formatted as an array of these nodes.
-  //   ```
-  //   [
-  //   {kind: 'foo', text: '...'},
-  //   {kind: 'foo', text: '...'},
-  //   ...
-  //   ]
-  //   ```
-  //
-  // In postorder, nodes are indented in order to indicate depth. For example, a
-  // node with two children, one of them with an error:
-  //   ```
-  //     {kind: 'bar', text: '...', has_error: yes},
-  //     {kind: 'baz', text: '...'}
-  //   {kind: 'foo', text: '...', subtree_size: 2}
-  //   ```
-  //
-  // In preorder, nodes are marked as children with postorder (storage) index.
-  // For example, a node with two children, one of them with an error:
-  //   ```
-  //   {node_index: 2, kind: 'foo', text: '...', subtree_size: 2, children: [
-  //     {node_index: 0, kind: 'bar', text: '...', has_error: yes},
-  //     {node_index: 1, kind: 'baz', text: '...'}]}
-  //   ```
-  //
-  // This can be parsed as YAML using tools like `python-yq` combined with `jq`
-  // on the command line. The format is also reasonably amenable to other
-  // line-oriented shell tools from `grep` to `awk`.
-  auto Print(llvm::raw_ostream& output, bool preorder) const -> void;
-
   // Collects memory usage of members.
   // Collects memory usage of members.
   auto CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
   auto CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
       -> void;
       -> void;
 
 
-  // The following `Extract*` function provide an alternative way of accessing
-  // the nodes of a tree. It is intended to be more convenient and type-safe,
-  // but slower and can't be used on nodes that are marked as having an error.
-  // It is appropriate for uses that are less performance sensitive, like
-  // diagnostics. Example usage:
-  // ```
-  // auto file = tree->ExtractFile();
-  // for (AnyDeclId decl_id : file.decls) {
-  //   // `decl_id` is convertible to a `NodeId`.
-  //   if (std::optional<FunctionDecl> fn_decl =
-  //       tree->ExtractAs<FunctionDecl>(decl_id)) {
-  //     // fn_decl->params is a `TuplePatternId` (which extends `NodeId`)
-  //     // that is guaranteed to reference a `TuplePattern`.
-  //     std::optional<TuplePattern> params = tree->Extract(fn_decl->params);
-  //     // `params` has a value unless there was an error in that node.
-  //   } else if (auto class_def = tree->ExtractAs<ClassDefinition>(decl_id)) {
-  //     // ...
-  //   }
-  // }
-  // ```
-
-  // Extract a `File` object representing the parse tree for the whole file.
-  // #include "toolchain/parse/typed_nodes.h" to get the definition of `File`
-  // and the types representing its children nodes. This is implemented in
-  // extract.cpp.
-  auto ExtractFile() const -> File;
-
-  // Converts this node_id to a typed node of a specified type, if it is a valid
-  // node of that kind.
-  template <typename T>
-  auto ExtractAs(NodeId node_id) const -> std::optional<T>;
-
-  // Converts to a typed node, if it is not an error.
-  template <typename IdT>
-  auto Extract(IdT id) const
-      -> std::optional<typename NodeForId<IdT>::TypedNode>;
-
   // Verifies the parse tree structure. Checks invariants of the parse tree
   // Verifies the parse tree structure. Checks invariants of the parse tree
   // structure and returns verification errors.
   // structure and returns verification errors.
   //
   //
-  // This is fairly slow, and is primarily intended to be used as a debugging
-  // aid. This routine doesn't directly CHECK so that it can be used within a
-  // debugger.
+  // In opt builds, this does some minimal checking. In debug builds, it'll
+  // build a TreeAndSubtrees and run further verification. This doesn't directly
+  // CHECK so that it can be used within a debugger.
   auto Verify() const -> ErrorOr<Success>;
   auto Verify() const -> ErrorOr<Success>;
 
 
  private:
  private:
@@ -285,12 +193,8 @@ class Tree : public Printable<Tree> {
   // The in-memory representation of data used for a particular node in the
   // The in-memory representation of data used for a particular node in the
   // tree.
   // tree.
   struct NodeImpl {
   struct NodeImpl {
-    explicit NodeImpl(NodeKind kind, bool has_error, Lex::TokenIndex token,
-                      int subtree_size)
-        : kind(kind),
-          has_error(has_error),
-          token(token),
-          subtree_size(subtree_size) {}
+    explicit NodeImpl(NodeKind kind, bool has_error, Lex::TokenIndex token)
+        : kind(kind), has_error(has_error), token(token) {}
 
 
     // The kind of this node. Note that this is only a single byte.
     // The kind of this node. Note that this is only a single byte.
     NodeKind kind;
     NodeKind kind;
@@ -315,38 +219,11 @@ class Tree : public Printable<Tree> {
 
 
     // The token root of this node.
     // The token root of this node.
     Lex::TokenIndex token;
     Lex::TokenIndex token;
-
-    // The size of this node's subtree of the parse tree. This is the number of
-    // nodes (and thus tokens) that are covered by this node (and its
-    // descendents) in the parse tree.
-    //
-    // During a *reverse* postorder (RPO) traversal of the parse tree, this can
-    // also be thought of as the offset to the next non-descendant node. When
-    // this node is not the first child of its parent (which is the last child
-    // visited in RPO), that is the offset to the next sibling. When this node
-    // *is* the first child of its parent, this will be an offset to the node's
-    // parent's next sibling, or if it the parent is also a first child, the
-    // grandparent's next sibling, and so on.
-    //
-    // This field should always be a positive integer as at least this node is
-    // part of its subtree.
-    int32_t subtree_size;
   };
   };
 
 
-  static_assert(sizeof(NodeImpl) == 12,
+  static_assert(sizeof(NodeImpl) == 8,
                 "Unexpected size of node implementation!");
                 "Unexpected size of node implementation!");
 
 
-  // Like ExtractAs(), but malformed tree errors are not fatal. Should only be
-  // used by `Verify()` or by tests.
-  template <typename T>
-  auto VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
-      -> std::optional<T>;
-
-  // Wrapper around `VerifyExtractAs` to dispatch based on a runtime node kind.
-  // Returns true if extraction was successful.
-  auto VerifyExtract(NodeId node_id, NodeKind kind, ErrorBuilder* trace) const
-      -> bool;
-
   // Sets the kind of a node. This is intended to allow putting the tree into a
   // Sets the kind of a node. This is intended to allow putting the tree into a
   // state where verification can fail, in order to make the failure path of
   // state where verification can fail, in order to make the failure path of
   // `Verify` testable.
   // `Verify` testable.
@@ -354,26 +231,6 @@ class Tree : public Printable<Tree> {
     node_impls_[node_id.index].kind = kind;
     node_impls_[node_id.index].kind = kind;
   }
   }
 
 
-  // Prints a single node for Print(). Returns true when preorder and there are
-  // children.
-  auto PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
-                 bool preorder) const -> bool;
-
-  // Extract a node of type `T` from a sibling range. This is expected to
-  // consume the complete sibling range. Malformed tree errors are written
-  // to `*trace`, if `trace != nullptr`. This is implemented in extract.cpp.
-  template <typename T>
-  auto TryExtractNodeFromChildren(
-      NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children,
-      ErrorBuilder* trace) const -> std::optional<T>;
-
-  // Extract a node of type `T` from a sibling range. This is expected to
-  // consume the complete sibling range. Malformed tree errors are fatal.
-  template <typename T>
-  auto ExtractNodeFromChildren(
-      NodeId node_id,
-      llvm::iterator_range<Tree::SiblingIterator> children) const -> T;
-
   // Depth-first postorder sequence of node implementation data.
   // Depth-first postorder sequence of node implementation data.
   llvm::SmallVector<NodeImpl> node_impls_;
   llvm::SmallVector<NodeImpl> node_impls_;
 
 
@@ -449,102 +306,6 @@ class Tree::PostorderIterator
   NodeId node_;
   NodeId node_;
 };
 };
 
 
-// A forward iterator across the siblings at a particular level in the parse
-// tree. It produces `Tree::NodeId` objects which are opaque handles and must
-// be used in conjunction with the `Tree` itself.
-//
-// While this is a forward iterator and may not have good locality within the
-// `Tree` data structure, it is still constant time to increment and
-// suitable for algorithms relying on that property.
-//
-// The siblings are discovered through a reverse postorder (RPO) tree traversal
-// (which is made constant time through cached distance information), and so the
-// relative order of siblings matches their RPO order.
-class Tree::SiblingIterator
-    : public llvm::iterator_facade_base<SiblingIterator,
-                                        std::forward_iterator_tag, NodeId, int,
-                                        const NodeId*, NodeId>,
-      public Printable<Tree::SiblingIterator> {
- public:
-  explicit SiblingIterator() = delete;
-
-  auto operator==(const SiblingIterator& rhs) const -> bool {
-    return node_ == rhs.node_;
-  }
-
-  auto operator*() const -> NodeId { return node_; }
-
-  using iterator_facade_base::operator++;
-  auto operator++() -> SiblingIterator& {
-    node_.index -= std::abs(tree_->node_impls_[node_.index].subtree_size);
-    return *this;
-  }
-
-  // Prints the underlying node index.
-  auto Print(llvm::raw_ostream& output) const -> void;
-
- private:
-  friend class Tree;
-
-  explicit SiblingIterator(const Tree& tree_arg, NodeId n)
-      : tree_(&tree_arg), node_(n) {}
-
-  const Tree* tree_;
-
-  NodeId node_;
-};
-
-template <typename T>
-auto Tree::ExtractNodeFromChildren(
-    NodeId node_id, llvm::iterator_range<Tree::SiblingIterator> children) const
-    -> T {
-  auto result = TryExtractNodeFromChildren<T>(node_id, children, nullptr);
-  if (!result.has_value()) {
-    // On error try again, this time capturing a trace.
-    ErrorBuilder trace;
-    TryExtractNodeFromChildren<T>(node_id, children, &trace);
-    CARBON_FATAL() << "Malformed parse node:\n"
-                   << static_cast<Error>(trace).message();
-  }
-  return *result;
-}
-
-template <typename T>
-auto Tree::ExtractAs(NodeId node_id) const -> std::optional<T> {
-  static_assert(HasKindMember<T>, "Not a parse node type");
-  if (!IsValid<T>(node_id)) {
-    return std::nullopt;
-  }
-
-  return ExtractNodeFromChildren<T>(node_id, children(node_id));
-}
-
-template <typename T>
-auto Tree::VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
-    -> std::optional<T> {
-  static_assert(HasKindMember<T>, "Not a parse node type");
-  if (!IsValid<T>(node_id)) {
-    if (trace) {
-      *trace << "VerifyExtractAs error: wrong kind " << node_kind(node_id)
-             << ", expected " << T::Kind << "\n";
-    }
-    return std::nullopt;
-  }
-
-  return TryExtractNodeFromChildren<T>(node_id, children(node_id), trace);
-}
-
-template <typename IdT>
-auto Tree::Extract(IdT id) const
-    -> std::optional<typename NodeForId<IdT>::TypedNode> {
-  if (!IsValid(id)) {
-    return std::nullopt;
-  }
-
-  using T = typename NodeForId<IdT>::TypedNode;
-  return ExtractNodeFromChildren<T>(id, children(id));
-}
-
 template <const NodeKind& K>
 template <const NodeKind& K>
 struct Tree::ConvertTo<NodeIdForKind<K>> {
 struct Tree::ConvertTo<NodeIdForKind<K>> {
   static auto AllowedFor(NodeKind kind) -> bool { return kind == K; }
   static auto AllowedFor(NodeKind kind) -> bool { return kind == K; }

+ 244 - 0
toolchain/parse/tree_and_subtrees.cpp

@@ -0,0 +1,244 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "toolchain/parse/tree_and_subtrees.h"
+
+namespace Carbon::Parse {
+
+TreeAndSubtrees::TreeAndSubtrees(const Lex::TokenizedBuffer& tokens,
+                                 const Tree& tree)
+    : tokens_(&tokens), tree_(&tree) {
+  subtree_sizes_.reserve(tree_->size());
+
+  // A stack of nodes which haven't yet been used as children.
+  llvm::SmallVector<NodeId> size_stack;
+  for (auto n : tree.postorder()) {
+    // Nodes always include themselves.
+    int32_t size = 1;
+    auto kind = tree.node_kind(n);
+    if (kind.has_child_count()) {
+      // When the child count is set, remove the specific number from the stack.
+      CARBON_CHECK(static_cast<int32_t>(size_stack.size()) >=
+                   kind.child_count())
+          << "Need " << kind.child_count() << " children for " << kind
+          << ", have " << size_stack.size() << " available";
+      for (auto i : llvm::seq(kind.child_count())) {
+        auto child = size_stack.pop_back_val();
+        CARBON_CHECK((size_t)child.index < subtree_sizes_.size());
+        size += subtree_sizes_[child.index];
+        if (kind.has_bracket() && i == kind.child_count() - 1) {
+          CARBON_CHECK(kind.bracket() == tree.node_kind(child))
+              << "Node " << kind << " needs bracket " << kind.bracket()
+              << ", found wrong bracket " << tree.node_kind(child);
+        }
+      }
+    } else {
+      while (true) {
+        CARBON_CHECK(!size_stack.empty())
+            << "Node " << kind << " is missing bracket " << kind.bracket();
+        auto child = size_stack.pop_back_val();
+        size += subtree_sizes_[child.index];
+        if (kind.bracket() == tree.node_kind(child)) {
+          break;
+        }
+      }
+    }
+    size_stack.push_back(n);
+    subtree_sizes_.push_back(size);
+  }
+
+  CARBON_CHECK(static_cast<int>(subtree_sizes_.size()) == tree_->size());
+
+  // Remaining nodes should all be roots in the tree; make sure they line up.
+  CARBON_CHECK(size_stack.back().index ==
+               static_cast<int32_t>(tree_->size()) - 1)
+      << size_stack.back() << " " << tree_->size() - 1;
+  int prev_index = -1;
+  for (const auto& n : size_stack) {
+    CARBON_CHECK(n.index - subtree_sizes_[n.index] == prev_index)
+        << "NodeId " << n << " is a root " << tree_->node_kind(n)
+        << " with subtree_size " << subtree_sizes_[n.index]
+        << ", but previous root was at " << prev_index << ".";
+    prev_index = n.index;
+  }
+}
+
+auto TreeAndSubtrees::VerifyExtract(NodeId node_id, NodeKind kind,
+                                    ErrorBuilder* trace) const -> bool {
+  switch (kind) {
+#define CARBON_PARSE_NODE_KIND(Name) \
+  case NodeKind::Name:               \
+    return VerifyExtractAs<Name>(node_id, trace).has_value();
+#include "toolchain/parse/node_kind.def"
+  }
+}
+
+auto TreeAndSubtrees::Verify() const -> ErrorOr<Success> {
+  // Validate that each node extracts successfully when not marked as having an
+  // error.
+  //
+  // Without this code, a 10 mloc test case of lex & parse takes 4.129 s ± 0.041
+  // s. With this additional verification, it takes 5.768 s ± 0.036 s.
+  for (NodeId n : tree_->postorder()) {
+    if (tree_->node_has_error(n)) {
+      continue;
+    }
+
+    auto node_kind = tree_->node_kind(n);
+    if (!VerifyExtract(n, node_kind, nullptr)) {
+      ErrorBuilder trace;
+      trace << llvm::formatv(
+          "NodeId #{0} couldn't be extracted as a {1}. Trace:\n", n, node_kind);
+      VerifyExtract(n, node_kind, &trace);
+      return trace;
+    }
+  }
+
+  // Validate the roots. Also ensures Tree::ExtractFile() doesn't error.
+  if (!TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), nullptr)) {
+    ErrorBuilder trace;
+    trace << "Roots of tree couldn't be extracted as a `File`. Trace:\n";
+    TryExtractNodeFromChildren<File>(NodeId::Invalid, roots(), &trace);
+    return trace;
+  }
+
+  return Success();
+}
+
+auto TreeAndSubtrees::postorder(NodeId n) const
+    -> llvm::iterator_range<Tree::PostorderIterator> {
+  // The postorder ends after this node, the root, and begins at the start of
+  // its subtree.
+  int start_index = n.index - subtree_sizes_[n.index] + 1;
+  return Tree::PostorderIterator::MakeRange(NodeId(start_index), n);
+}
+
+auto TreeAndSubtrees::children(NodeId n) const
+    -> llvm::iterator_range<SiblingIterator> {
+  CARBON_CHECK(n.is_valid());
+  int end_index = n.index - subtree_sizes_[n.index];
+  return llvm::iterator_range<SiblingIterator>(
+      SiblingIterator(*this, NodeId(n.index - 1)),
+      SiblingIterator(*this, NodeId(end_index)));
+}
+
+auto TreeAndSubtrees::roots() const -> llvm::iterator_range<SiblingIterator> {
+  return llvm::iterator_range<SiblingIterator>(
+      SiblingIterator(*this,
+                      NodeId(static_cast<int>(subtree_sizes_.size()) - 1)),
+      SiblingIterator(*this, NodeId(-1)));
+}
+
+auto TreeAndSubtrees::PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
+                                bool preorder) const -> bool {
+  output.indent(2 * (depth + 2));
+  output << "{";
+  // If children are being added, include node_index in order to disambiguate
+  // nodes.
+  if (preorder) {
+    output << "node_index: " << n << ", ";
+  }
+  output << "kind: '" << tree_->node_kind(n) << "', text: '"
+         << tokens_->GetTokenText(tree_->node_token(n)) << "'";
+
+  if (tree_->node_has_error(n)) {
+    output << ", has_error: yes";
+  }
+
+  if (subtree_sizes_[n.index] > 1) {
+    output << ", subtree_size: " << subtree_sizes_[n.index];
+    if (preorder) {
+      output << ", children: [\n";
+      return true;
+    }
+  }
+  output << "}";
+  return false;
+}
+
+auto TreeAndSubtrees::Print(llvm::raw_ostream& output) const -> void {
+  output << "- filename: " << tokens_->source().filename() << "\n"
+         << "  parse_tree: [\n";
+
+  // Walk the tree just to calculate depths for each node.
+  llvm::SmallVector<int> indents;
+  indents.resize(subtree_sizes_.size(), 0);
+
+  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
+  for (NodeId n : roots()) {
+    node_stack.push_back({n, 0});
+  }
+
+  while (!node_stack.empty()) {
+    NodeId n = NodeId::Invalid;
+    int depth;
+    std::tie(n, depth) = node_stack.pop_back_val();
+    for (NodeId sibling_n : children(n)) {
+      indents[sibling_n.index] = depth + 1;
+      node_stack.push_back({sibling_n, depth + 1});
+    }
+  }
+
+  for (NodeId n : tree_->postorder()) {
+    PrintNode(output, n, indents[n.index], /*preorder=*/false);
+    output << ",\n";
+  }
+  output << "  ]\n";
+}
+
+auto TreeAndSubtrees::PrintPreorder(llvm::raw_ostream& output) const -> void {
+  output << "- filename: " << tokens_->source().filename() << "\n"
+         << "  parse_tree: [\n";
+
+  // The parse tree is stored in postorder. The preorder can be constructed
+  // by reversing the order of each level of siblings within an RPO. The
+  // sibling iterators are directly built around RPO and so can be used with a
+  // stack to produce preorder.
+
+  // The roots, like siblings, are in RPO (so reversed), but we add them in
+  // order here because we'll pop off the stack effectively reversing then.
+  llvm::SmallVector<std::pair<NodeId, int>, 16> node_stack;
+  for (NodeId n : roots()) {
+    node_stack.push_back({n, 0});
+  }
+
+  while (!node_stack.empty()) {
+    NodeId n = NodeId::Invalid;
+    int depth;
+    std::tie(n, depth) = node_stack.pop_back_val();
+
+    if (PrintNode(output, n, depth, /*preorder=*/true)) {
+      // Has children, so we descend. We append the children in order here as
+      // well because they will get reversed when popped off the stack.
+      for (NodeId sibling_n : children(n)) {
+        node_stack.push_back({sibling_n, depth + 1});
+      }
+      continue;
+    }
+
+    int next_depth = node_stack.empty() ? 0 : node_stack.back().second;
+    CARBON_CHECK(next_depth <= depth) << "Cannot have the next depth increase!";
+    for (int close_children_count : llvm::seq(0, depth - next_depth)) {
+      (void)close_children_count;
+      output << "]}";
+    }
+
+    // We always end with a comma and a new line as we'll move to the next
+    // node at whatever the current level ends up being.
+    output << "  ,\n";
+  }
+  output << "  ]\n";
+}
+
+auto TreeAndSubtrees::CollectMemUsage(MemUsage& mem_usage,
+                                      llvm::StringRef label) const -> void {
+  mem_usage.Add(MemUsage::ConcatLabel(label, "subtree_sizes_"), subtree_sizes_);
+}
+
+auto TreeAndSubtrees::SiblingIterator::Print(llvm::raw_ostream& output) const
+    -> void {
+  output << node_;
+}
+
+}  // namespace Carbon::Parse

+ 277 - 0
toolchain/parse/tree_and_subtrees.h

@@ -0,0 +1,277 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_PARSE_TREE_AND_SUBTREES_H_
+#define CARBON_TOOLCHAIN_PARSE_TREE_AND_SUBTREES_H_
+
+#include "llvm/ADT/SmallVector.h"
+#include "toolchain/parse/tree.h"
+
+namespace Carbon::Parse {
+
+// Calculates and stores subtree data for a parse tree. Supports APIs that
+// require subtree knowledge.
+//
+// This requires a complete tree.
+class TreeAndSubtrees {
+ public:
+  class SiblingIterator;
+
+  explicit TreeAndSubtrees(const Lex::TokenizedBuffer& tokens,
+                           const Tree& tree);
+
+  // The following `Extract*` function provide an alternative way of accessing
+  // the nodes of a tree. It is intended to be more convenient and type-safe,
+  // but slower and can't be used on nodes that are marked as having an error.
+  // It is appropriate for uses that are less performance sensitive, like
+  // diagnostics. Example usage:
+  // ```
+  // auto file = tree->ExtractFile();
+  // for (AnyDeclId decl_id : file.decls) {
+  //   // `decl_id` is convertible to a `NodeId`.
+  //   if (std::optional<FunctionDecl> fn_decl =
+  //       tree->ExtractAs<FunctionDecl>(decl_id)) {
+  //     // fn_decl->params is a `TuplePatternId` (which extends `NodeId`)
+  //     // that is guaranteed to reference a `TuplePattern`.
+  //     std::optional<TuplePattern> params = tree->Extract(fn_decl->params);
+  //     // `params` has a value unless there was an error in that node.
+  //   } else if (auto class_def = tree->ExtractAs<ClassDefinition>(decl_id)) {
+  //     // ...
+  //   }
+  // }
+  // ```
+
+  // Extract a `File` object representing the parse tree for the whole file.
+  // #include "toolchain/parse/typed_nodes.h" to get the definition of `File`
+  // and the types representing its children nodes. This is implemented in
+  // extract.cpp.
+  auto ExtractFile() const -> File;
+
+  // Converts this node_id to a typed node of a specified type, if it is a valid
+  // node of that kind.
+  template <typename T>
+  auto ExtractAs(NodeId node_id) const -> std::optional<T>;
+
+  // Converts to a typed node, if it is not an error.
+  template <typename IdT>
+  auto Extract(IdT id) const
+      -> std::optional<typename NodeForId<IdT>::TypedNode>;
+
+  // Verifies that each node in the tree can be successfully extracted.
+  //
+  // This is fairly slow, and is primarily intended to be used as a debugging
+  // aid. This doesn't directly CHECK so that it can be used within a debugger.
+  auto Verify() const -> ErrorOr<Success>;
+
+  // Prints the parse tree in postorder format. See also use PrintPreorder.
+  //
+  // Output represents each node as a YAML record. A node is formatted as:
+  //   ```
+  //   {kind: 'foo', text: '...'}
+  //   ```
+  //
+  // The top level is formatted as an array of these nodes.
+  //   ```
+  //   [
+  //   {kind: 'foo', text: '...'},
+  //   {kind: 'foo', text: '...'},
+  //   ...
+  //   ]
+  //   ```
+  //
+  // Nodes are indented in order to indicate depth. For example, a node with two
+  // children, one of them with an error:
+  //   ```
+  //     {kind: 'bar', text: '...', has_error: yes},
+  //     {kind: 'baz', text: '...'}
+  //   {kind: 'foo', text: '...', subtree_size: 2}
+  //   ```
+  //
+  // This can be parsed as YAML using tools like `python-yq` combined with `jq`
+  // on the command line. The format is also reasonably amenable to other
+  // line-oriented shell tools from `grep` to `awk`.
+  auto Print(llvm::raw_ostream& output) const -> void;
+
+  // Prints the parse tree in preorder. The format is YAML, and similar to
+  // Print. However, nodes are marked as children with postorder (storage)
+  // index. For example, a node with two children, one of them with an error:
+  //   ```
+  //   {node_index: 2, kind: 'foo', text: '...', subtree_size: 2, children: [
+  //     {node_index: 0, kind: 'bar', text: '...', has_error: yes},
+  //     {node_index: 1, kind: 'baz', text: '...'}]}
+  //   ```
+  auto PrintPreorder(llvm::raw_ostream& output) const -> void;
+
+  // Collects memory usage of members.
+  auto CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
+      -> void;
+
+  // Returns an iterable range over the parse tree node and all of its
+  // descendants in depth-first postorder.
+  auto postorder(NodeId n) const
+      -> llvm::iterator_range<Tree::PostorderIterator>;
+
+  // Returns an iterable range over the direct children of a node in the parse
+  // tree. This is a forward range, but is constant time to increment. The order
+  // of children is the same as would be found in a reverse postorder traversal.
+  auto children(NodeId n) const -> llvm::iterator_range<SiblingIterator>;
+
+  // Returns an iterable range over the roots of the parse tree. This is a
+  // forward range, but is constant time to increment. The order of roots is the
+  // same as would be found in a reverse postorder traversal.
+  auto roots() const -> llvm::iterator_range<SiblingIterator>;
+
+  auto tree() const -> const Tree& { return *tree_; }
+
+ private:
+  friend class TypedNodesTestPeer;
+
+  // Extract a node of type `T` from a sibling range. This is expected to
+  // consume the complete sibling range. Malformed tree errors are written
+  // to `*trace`, if `trace != nullptr`. This is implemented in extract.cpp.
+  template <typename T>
+  auto TryExtractNodeFromChildren(
+      NodeId node_id, llvm::iterator_range<SiblingIterator> children,
+      ErrorBuilder* trace) const -> std::optional<T>;
+
+  // Extract a node of type `T` from a sibling range. This is expected to
+  // consume the complete sibling range. Malformed tree errors are fatal.
+  template <typename T>
+  auto ExtractNodeFromChildren(
+      NodeId node_id, llvm::iterator_range<SiblingIterator> children) const
+      -> T;
+
+  // Like ExtractAs(), but malformed tree errors are not fatal. Should only be
+  // used by `Verify()` or by tests.
+  template <typename T>
+  auto VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
+      -> std::optional<T>;
+
+  // Wrapper around `VerifyExtractAs` to dispatch based on a runtime node kind.
+  // Returns true if extraction was successful.
+  auto VerifyExtract(NodeId node_id, NodeKind kind, ErrorBuilder* trace) const
+      -> bool;
+
+  // Prints a single node for Print(). Returns true when preorder and there are
+  // children.
+  auto PrintNode(llvm::raw_ostream& output, NodeId n, int depth,
+                 bool preorder) const -> bool;
+
+  // The associated tokens.
+  const Lex::TokenizedBuffer* tokens_;
+
+  // The associated tree.
+  const Tree* tree_;
+
+  // For each node in the tree, the size of the node's subtree. This is the
+  // number of nodes (and thus tokens) that are covered by the node (and its
+  // descendents) in the parse tree. It's one for nodes with no children.
+  //
+  // During a *reverse* postorder (RPO) traversal of the parse tree, this can
+  // also be thought of as the offset to the next non-descendant node. When the
+  // node is not the first child of its parent (which is the last child visited
+  // in RPO), that is the offset to the next sibling. When the node *is* the
+  // first child of its parent, this will be an offset to the node's parent's
+  // next sibling, or if it the parent is also a first child, the grandparent's
+  // next sibling, and so on.
+  llvm::SmallVector<int32_t> subtree_sizes_;
+};
+
+// A forward iterator across the siblings at a particular level in the parse
+// tree. It produces `Tree::NodeId` objects which are opaque handles and must
+// be used in conjunction with the `Tree` itself.
+//
+// While this is a forward iterator and may not have good locality within the
+// `Tree` data structure, it is still constant time to increment and
+// suitable for algorithms relying on that property.
+//
+// The siblings are discovered through a reverse postorder (RPO) tree traversal
+// (which is made constant time through cached distance information), and so the
+// relative order of siblings matches their RPO order.
+class TreeAndSubtrees::SiblingIterator
+    : public llvm::iterator_facade_base<SiblingIterator,
+                                        std::forward_iterator_tag, NodeId, int,
+                                        const NodeId*, NodeId>,
+      public Printable<SiblingIterator> {
+ public:
+  explicit SiblingIterator() = delete;
+
+  auto operator==(const SiblingIterator& rhs) const -> bool {
+    return node_ == rhs.node_;
+  }
+
+  auto operator*() const -> NodeId { return node_; }
+
+  using iterator_facade_base::operator++;
+  auto operator++() -> SiblingIterator& {
+    node_.index -= std::abs(tree_->subtree_sizes_[node_.index]);
+    return *this;
+  }
+
+  // Prints the underlying node index.
+  auto Print(llvm::raw_ostream& output) const -> void;
+
+ private:
+  friend class TreeAndSubtrees;
+
+  explicit SiblingIterator(const TreeAndSubtrees& tree, NodeId node)
+      : tree_(&tree), node_(node) {}
+
+  const TreeAndSubtrees* tree_;
+  NodeId node_;
+};
+
+template <typename T>
+auto TreeAndSubtrees::ExtractNodeFromChildren(
+    NodeId node_id, llvm::iterator_range<SiblingIterator> children) const -> T {
+  auto result = TryExtractNodeFromChildren<T>(node_id, children, nullptr);
+  if (!result.has_value()) {
+    // On error try again, this time capturing a trace.
+    ErrorBuilder trace;
+    TryExtractNodeFromChildren<T>(node_id, children, &trace);
+    CARBON_FATAL() << "Malformed parse node:\n"
+                   << static_cast<Error>(trace).message();
+  }
+  return *result;
+}
+
+template <typename T>
+auto TreeAndSubtrees::ExtractAs(NodeId node_id) const -> std::optional<T> {
+  static_assert(HasKindMember<T>, "Not a parse node type");
+  if (!tree_->IsValid<T>(node_id)) {
+    return std::nullopt;
+  }
+
+  return ExtractNodeFromChildren<T>(node_id, children(node_id));
+}
+
+template <typename T>
+auto TreeAndSubtrees::VerifyExtractAs(NodeId node_id, ErrorBuilder* trace) const
+    -> std::optional<T> {
+  static_assert(HasKindMember<T>, "Not a parse node type");
+  if (!tree_->IsValid<T>(node_id)) {
+    if (trace) {
+      *trace << "VerifyExtractAs error: wrong kind "
+             << tree_->node_kind(node_id) << ", expected " << T::Kind << "\n";
+    }
+    return std::nullopt;
+  }
+
+  return TryExtractNodeFromChildren<T>(node_id, children(node_id), trace);
+}
+
+template <typename IdT>
+auto TreeAndSubtrees::Extract(IdT id) const
+    -> std::optional<typename NodeForId<IdT>::TypedNode> {
+  if (!tree_->IsValid(id)) {
+    return std::nullopt;
+  }
+
+  using T = typename NodeForId<IdT>::TypedNode;
+  return ExtractNodeFromChildren<T>(id, children(id));
+}
+
+}  // namespace Carbon::Parse
+
+#endif  // CARBON_TOOLCHAIN_PARSE_TREE_AND_SUBTREES_H_

+ 16 - 8
toolchain/parse/tree_node_diagnostic_converter.h

@@ -5,9 +5,12 @@
 #ifndef CARBON_TOOLCHAIN_PARSE_TREE_NODE_DIAGNOSTIC_CONVERTER_H_
 #ifndef CARBON_TOOLCHAIN_PARSE_TREE_NODE_DIAGNOSTIC_CONVERTER_H_
 #define CARBON_TOOLCHAIN_PARSE_TREE_NODE_DIAGNOSTIC_CONVERTER_H_
 #define CARBON_TOOLCHAIN_PARSE_TREE_NODE_DIAGNOSTIC_CONVERTER_H_
 
 
+#include <utility>
+
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/parse/tree.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
 
 
@@ -32,11 +35,12 @@ class NodeLoc {
 
 
 class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
 class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
  public:
  public:
-  explicit NodeLocConverter(const Lex::TokenizedBuffer* tokens,
-                            llvm::StringRef filename, const Tree* parse_tree)
+  explicit NodeLocConverter(
+      const Lex::TokenizedBuffer* tokens, llvm::StringRef filename,
+      llvm::function_ref<const Parse::TreeAndSubtrees&()> get_tree_and_subtrees)
       : token_converter_(tokens),
       : token_converter_(tokens),
         filename_(filename),
         filename_(filename),
-        parse_tree_(parse_tree) {}
+        get_tree_and_subtrees_(get_tree_and_subtrees) {}
 
 
   // Map the given token into a diagnostic location.
   // Map the given token into a diagnostic location.
   auto ConvertLoc(NodeLoc node_loc, ContextFnT context_fn) const
   auto ConvertLoc(NodeLoc node_loc, ContextFnT context_fn) const
@@ -47,17 +51,19 @@ class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
       return {.filename = filename_};
       return {.filename = filename_};
     }
     }
 
 
+    const auto& tree = get_tree_and_subtrees_();
+
     if (node_loc.token_only()) {
     if (node_loc.token_only()) {
       return token_converter_.ConvertLoc(
       return token_converter_.ConvertLoc(
-          parse_tree_->node_token(node_loc.node_id()), context_fn);
+          tree.tree().node_token(node_loc.node_id()), context_fn);
     }
     }
 
 
     // Construct a location that encompasses all tokens that descend from this
     // Construct a location that encompasses all tokens that descend from this
     // node (including the root).
     // node (including the root).
-    Lex::TokenIndex start_token = parse_tree_->node_token(node_loc.node_id());
+    Lex::TokenIndex start_token = tree.tree().node_token(node_loc.node_id());
     Lex::TokenIndex end_token = start_token;
     Lex::TokenIndex end_token = start_token;
-    for (NodeId desc : parse_tree_->postorder(node_loc.node_id())) {
-      Lex::TokenIndex desc_token = parse_tree_->node_token(desc);
+    for (NodeId desc : tree.postorder(node_loc.node_id())) {
+      Lex::TokenIndex desc_token = tree.tree().node_token(desc);
       if (!desc_token.is_valid()) {
       if (!desc_token.is_valid()) {
         continue;
         continue;
       }
       }
@@ -89,7 +95,9 @@ class NodeLocConverter : public DiagnosticConverter<NodeLoc> {
  private:
  private:
   Lex::TokenDiagnosticConverter token_converter_;
   Lex::TokenDiagnosticConverter token_converter_;
   llvm::StringRef filename_;
   llvm::StringRef filename_;
-  const Tree* parse_tree_;
+
+  // Returns a lazily constructed TreeAndSubtrees.
+  llvm::function_ref<const Parse::TreeAndSubtrees&()> get_tree_and_subtrees_;
 };
 };
 
 
 }  // namespace Carbon::Parse
 }  // namespace Carbon::Parse

+ 5 - 2
toolchain/parse/tree_test.cpp

@@ -16,6 +16,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/parse.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 #include "toolchain/testing/yaml_test_helpers.h"
 #include "toolchain/testing/yaml_test_helpers.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
@@ -60,7 +61,8 @@ TEST_F(TreeTest, AsAndTryAs) {
   Lex::TokenizedBuffer& tokens = GetTokenizedBuffer("fn F();");
   Lex::TokenizedBuffer& tokens = GetTokenizedBuffer("fn F();");
   Tree tree = Parse(tokens, consumer_, /*vlog_stream=*/nullptr);
   Tree tree = Parse(tokens, consumer_, /*vlog_stream=*/nullptr);
   ASSERT_FALSE(tree.has_errors());
   ASSERT_FALSE(tree.has_errors());
-  auto it = tree.roots().begin();
+  TreeAndSubtrees tree_and_subtrees(tokens, tree);
+  auto it = tree_and_subtrees.roots().begin();
   // A FileEnd node, so won't match.
   // A FileEnd node, so won't match.
   NodeId n = *it;
   NodeId n = *it;
 
 
@@ -134,8 +136,9 @@ TEST_F(TreeTest, PrintPreorderAsYAML) {
   Lex::TokenizedBuffer& tokens = GetTokenizedBuffer("fn F();");
   Lex::TokenizedBuffer& tokens = GetTokenizedBuffer("fn F();");
   Tree tree = Parse(tokens, consumer_, /*vlog_stream=*/nullptr);
   Tree tree = Parse(tokens, consumer_, /*vlog_stream=*/nullptr);
   EXPECT_FALSE(tree.has_errors());
   EXPECT_FALSE(tree.has_errors());
+  TreeAndSubtrees tree_and_subtrees(tokens, tree);
   TestRawOstream print_stream;
   TestRawOstream print_stream;
-  tree.Print(print_stream, /*preorder=*/true);
+  tree_and_subtrees.PrintPreorder(print_stream);
 
 
   auto param_list = Yaml::Sequence(ElementsAre(Yaml::Mapping(
   auto param_list = Yaml::Sequence(ElementsAre(Yaml::Mapping(
       ElementsAre(Pair("node_index", "3"), Pair("kind", "TuplePatternStart"),
       ElementsAre(Pair("node_index", "3"), Pair("kind", "TuplePatternStart"),

+ 12 - 8
toolchain/parse/typed_nodes_test.cpp

@@ -12,6 +12,7 @@
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/lex.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/lex/tokenized_buffer.h"
 #include "toolchain/parse/parse.h"
 #include "toolchain/parse/parse.h"
+#include "toolchain/parse/tree_and_subtrees.h"
 
 
 namespace Carbon::Parse {
 namespace Carbon::Parse {
 
 
@@ -20,7 +21,7 @@ namespace Carbon::Parse {
 class TypedNodesTestPeer {
 class TypedNodesTestPeer {
  public:
  public:
   template <typename T>
   template <typename T>
-  static auto VerifyExtractAs(const Tree* tree, NodeId node_id,
+  static auto VerifyExtractAs(const TreeAndSubtrees* tree, NodeId node_id,
                               ErrorBuilder* trace) -> std::optional<T> {
                               ErrorBuilder* trace) -> std::optional<T> {
     return tree->VerifyExtractAs<T>(node_id, trace);
     return tree->VerifyExtractAs<T>(node_id, trace);
   }
   }
@@ -57,14 +58,16 @@ class TypedNodeTest : public ::testing::Test {
     return token_storage_.front();
     return token_storage_.front();
   }
   }
 
 
-  auto GetTree(llvm::StringRef t) -> Tree& {
+  auto GetTree(llvm::StringRef t) -> TreeAndSubtrees& {
     tree_storage_.push_front(Parse(GetTokenizedBuffer(t), consumer_,
     tree_storage_.push_front(Parse(GetTokenizedBuffer(t), consumer_,
                                    /*vlog_stream=*/nullptr));
                                    /*vlog_stream=*/nullptr));
-    return tree_storage_.front();
+    tree_and_subtrees_storage_.push_front(
+        TreeAndSubtrees(token_storage_.front(), tree_storage_.front()));
+    return tree_and_subtrees_storage_.front();
   }
   }
 
 
   auto GetTokenizedBufferAndTree(llvm::StringRef t)
   auto GetTokenizedBufferAndTree(llvm::StringRef t)
-      -> std::pair<Lex::TokenizedBuffer*, Tree*> {
+      -> std::pair<Lex::TokenizedBuffer*, TreeAndSubtrees*> {
     auto* tree = &GetTree(t);
     auto* tree = &GetTree(t);
     return {&token_storage_.front(), tree};
     return {&token_storage_.front(), tree};
   }
   }
@@ -74,6 +77,7 @@ class TypedNodeTest : public ::testing::Test {
   std::forward_list<SourceBuffer> source_storage_;
   std::forward_list<SourceBuffer> source_storage_;
   std::forward_list<Lex::TokenizedBuffer> token_storage_;
   std::forward_list<Lex::TokenizedBuffer> token_storage_;
   std::forward_list<Tree> tree_storage_;
   std::forward_list<Tree> tree_storage_;
+  std::forward_list<TreeAndSubtrees> tree_and_subtrees_storage_;
   DiagnosticConsumer& consumer_ = ConsoleDiagnosticConsumer();
   DiagnosticConsumer& consumer_ = ConsoleDiagnosticConsumer();
 };
 };
 
 
@@ -81,15 +85,15 @@ TEST_F(TypedNodeTest, Empty) {
   auto* tree = &GetTree("");
   auto* tree = &GetTree("");
   auto file = tree->ExtractFile();
   auto file = tree->ExtractFile();
 
 
-  EXPECT_TRUE(tree->IsValid(file.start));
+  EXPECT_TRUE(tree->tree().IsValid(file.start));
   EXPECT_TRUE(tree->ExtractAs<FileStart>(file.start).has_value());
   EXPECT_TRUE(tree->ExtractAs<FileStart>(file.start).has_value());
   EXPECT_TRUE(tree->Extract(file.start).has_value());
   EXPECT_TRUE(tree->Extract(file.start).has_value());
 
 
-  EXPECT_TRUE(tree->IsValid(file.end));
+  EXPECT_TRUE(tree->tree().IsValid(file.end));
   EXPECT_TRUE(tree->ExtractAs<FileEnd>(file.end).has_value());
   EXPECT_TRUE(tree->ExtractAs<FileEnd>(file.end).has_value());
   EXPECT_TRUE(tree->Extract(file.end).has_value());
   EXPECT_TRUE(tree->Extract(file.end).has_value());
 
 
-  EXPECT_FALSE(tree->IsValid<FileEnd>(file.start));
+  EXPECT_FALSE(tree->tree().IsValid<FileEnd>(file.start));
   EXPECT_FALSE(tree->ExtractAs<FileEnd>(file.start).has_value());
   EXPECT_FALSE(tree->ExtractAs<FileEnd>(file.start).has_value());
 }
 }
 
 
@@ -342,7 +346,7 @@ TEST_F(TypedNodeTest, VerifyInvalid) {
   ASSERT_TRUE(f_intro.has_value());
   ASSERT_TRUE(f_intro.has_value());
 
 
   // Change the kind of the introducer and check we get a good trace log.
   // Change the kind of the introducer and check we get a good trace log.
-  TypedNodesTestPeer::SetNodeKind(tree, f_sig->introducer,
+  TypedNodesTestPeer::SetNodeKind(&tree_storage_.front(), f_sig->introducer,
                                   NodeKind::ClassIntroducer);
                                   NodeKind::ClassIntroducer);
 
 
   // The introducer should not extract as a FunctionIntroducer any more because
   // The introducer should not extract as a FunctionIntroducer any more because