Sfoglia il codice sorgente

Add a typed node return to AddNode (#5123)

Building on #5120, make a variant of `AddNode` that returns typed nodes,
and replace `UnsafeMake` uses with it. This switches to templating in
import parsing so that we can get type validation.

With this change, `UnsafeMake` ends up used in three places: `Tree::As`,
`Tree::TryAs`, and `Context::AddNode`. That should mean that all typed
nodes are verified.
Jon Ross-Perkins 1 anno fa
parent
commit
b555392cee

+ 5 - 9
toolchain/parse/context.cpp

@@ -447,27 +447,23 @@ static auto ParsingInDeferredDefinitionScope(Context& context) -> bool {
 
 auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token, bool has_error)
     -> void {
+  auto start_id = AddNode<NodeKind::FunctionDefinitionStart>(token, has_error);
   if (ParsingInDeferredDefinitionScope(*this)) {
-    deferred_definition_stack_.push_back(tree_->deferred_definitions_.Add(
-        {.start_id = FunctionDefinitionStartId::UnsafeMake(
-             NodeId(tree_->node_impls_.size()))}));
+    deferred_definition_stack_.push_back(
+        tree_->deferred_definitions_.Add({.start_id = start_id}));
   }
-
-  AddNode(NodeKind::FunctionDefinitionStart, token, has_error);
 }
 
 auto Context::AddFunctionDefinition(Lex::TokenIndex token, bool has_error)
     -> void {
+  auto definition_id = AddNode<NodeKind::FunctionDefinition>(token, has_error);
   if (ParsingInDeferredDefinitionScope(*this)) {
     auto definition_index = deferred_definition_stack_.pop_back_val();
     auto& definition = tree_->deferred_definitions_.Get(definition_index);
-    definition.definition_id =
-        FunctionDefinitionId::UnsafeMake(NodeId(tree_->node_impls_.size()));
+    definition.definition_id = definition_id;
     definition.next_definition_index =
         DeferredDefinitionIndex(tree_->deferred_definitions().size());
   }
-
-  AddNode(NodeKind::FunctionDefinition, token, has_error);
 }
 
 auto Context::PrintForStackDump(llvm::raw_ostream& output) const -> void {

+ 9 - 3
toolchain/parse/context.h

@@ -132,14 +132,20 @@ class Context {
   }
 
   // Adds a node to the parse tree that has children.
-  // TODO: Look into switching to a typed node return.
-  auto AddNode(NodeKind kind, Lex::TokenIndex token, bool has_error) -> NodeId {
+  auto AddNode(NodeKind kind, Lex::TokenIndex token, bool has_error) -> void {
     CARBON_CHECK(has_error || (kind != NodeKind::InvalidParse &&
                                kind != NodeKind::InvalidParseStart &&
                                kind != NodeKind::InvalidParseSubtree),
                  "{0} nodes must always have an error", kind);
     tree_->node_impls_.push_back(Tree::NodeImpl(kind, has_error, token));
-    return NodeId(tree_->node_impls_.size() - 1);
+  }
+
+  // Adds a node and returns its typed NodeId.
+  template <const Parse::NodeKind& Kind>
+  auto AddNode(Lex::TokenIndex token, bool has_error) -> NodeIdForKind<Kind> {
+    AddNode(Kind, token, has_error);
+    return NodeIdForKind<Kind>::UnsafeMake(
+        NodeId(tree_->node_impls_.size() - 1));
   }
 
   // Adds an invalid parse node.

+ 6 - 3
toolchain/parse/extract.cpp

@@ -88,6 +88,8 @@ class NodeExtractor {
     }
   }
 
+  auto tree() -> const Tree& { return tree_->tree(); }
+
  private:
   const TreeAndSubtrees* tree_;
   const Lex::TokenizedBuffer* tokens_;
@@ -155,7 +157,7 @@ struct Extractable<NodeIdForKind<Kind>> {
   static auto Extract(NodeExtractor& extractor)
       -> std::optional<NodeIdForKind<Kind>> {
     if (extractor.MatchesNodeIdForKind(Kind)) {
-      return NodeIdForKind<Kind>::UnsafeMake(extractor.ExtractNode());
+      return extractor.tree().As<NodeIdForKind<Kind>>(extractor.ExtractNode());
     } else {
       return std::nullopt;
     }
@@ -182,7 +184,8 @@ struct Extractable<NodeIdInCategory<Category>> {
   static auto Extract(NodeExtractor& extractor)
       -> std::optional<NodeIdInCategory<Category>> {
     if (extractor.MatchesNodeIdInCategory(Category)) {
-      return NodeIdInCategory<Category>::UnsafeMake(extractor.ExtractNode());
+      return extractor.tree().As<NodeIdInCategory<Category>>(
+          extractor.ExtractNode());
     } else {
       return std::nullopt;
     }
@@ -227,7 +230,7 @@ struct Extractable<NodeIdOneOf<T...>> {
   static auto Extract(NodeExtractor& extractor)
       -> std::optional<NodeIdOneOf<T...>> {
     if (extractor.MatchesNodeIdOneOf({T::Kind...})) {
-      return NodeIdOneOf<T...>::UnsafeMake(extractor.ExtractNode());
+      return extractor.tree().As<NodeIdOneOf<T...>>(extractor.ExtractNode());
     } else {
       return std::nullopt;
     }

+ 23 - 22
toolchain/parse/handle_import_and_package.cpp

@@ -34,16 +34,16 @@ static auto HasModifier(Context& context, Context::StateStackEntry state,
 }
 
 // Handles everything after the declaration's introducer.
+template <const Parse::NodeKind& DeclKind>
 static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
-                              NodeKind declaration, bool is_export,
-                              bool is_impl,
+                              bool is_export, bool is_impl,
                               llvm::function_ref<auto()->void> on_parse_error)
     -> void {
   Tree::PackagingNames names = {.is_export = is_export};
 
   // Parse the package name.
-  if (declaration == NodeKind::LibraryDecl ||
-      (declaration == NodeKind::ImportDecl &&
+  if (DeclKind == NodeKind::LibraryDecl ||
+      (DeclKind == NodeKind::ImportDecl &&
        context.PositionIs(Lex::TokenKind::Library))) {
     // This is either `library ...` or `import library ...`, so no package name
     // is expected.
@@ -64,7 +64,7 @@ static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
       CARBON_DIAGNOSTIC(ExpectedIdentifierAfterImport, Error,
                         "expected identifier or `library` after `import`");
       context.emitter().Emit(package_name_position,
-                             declaration == NodeKind::PackageDecl
+                             DeclKind == NodeKind::PackageDecl
                                  ? ExpectedIdentifierAfterPackage
                                  : ExpectedIdentifierAfterImport);
       on_parse_error();
@@ -83,7 +83,7 @@ static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
 
   // Parse the optional library keyword.
   bool accept_default = !names.package_id.has_value();
-  if (declaration == NodeKind::LibraryDecl) {
+  if constexpr (DeclKind == NodeKind::LibraryDecl) {
     auto library_id = context.ParseLibraryName(accept_default);
     if (!library_id) {
       on_parse_error();
@@ -113,10 +113,9 @@ static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
   }
 
   if (auto semi = context.ConsumeIf(Lex::TokenKind::Semi)) {
-    auto node_id = context.AddNode(declaration, *semi, state.has_error);
-    names.node_id = context.tree().As<AnyPackagingDeclId>(node_id);
+    names.node_id = context.AddNode<DeclKind>(*semi, state.has_error);
 
-    if (declaration == NodeKind::ImportDecl) {
+    if constexpr (DeclKind == NodeKind::ImportDecl) {
       context.AddImport(names);
     } else {
       context.set_packaging_decl(names, is_impl);
@@ -179,8 +178,9 @@ static auto RestrictExportToApi(Context& context,
 auto HandleImport(Context& context) -> void {
   auto state = context.PopState();
 
-  auto declaration = NodeKind::ImportDecl;
-  auto on_parse_error = [&] { OnParseError(context, state, declaration); };
+  auto on_parse_error = [&] {
+    OnParseError(context, state, NodeKind::ImportDecl);
+  };
 
   if (VerifyInImports(context, state.token)) {
     // Scan the modifiers to see if this import declaration is exported.
@@ -189,8 +189,8 @@ auto HandleImport(Context& context) -> void {
       RestrictExportToApi(context, state);
     }
 
-    HandleDeclContent(context, state, declaration, is_export,
-                      /*is_impl=*/false, on_parse_error);
+    HandleDeclContent<NodeKind::ImportDecl>(context, state, is_export,
+                                            /*is_impl=*/false, on_parse_error);
   } else {
     on_parse_error();
   }
@@ -214,14 +214,15 @@ auto HandleExportNameFinish(Context& context) -> void {
 }
 
 // Handles common logic for `package` and `library`.
+template <const Parse::NodeKind& DeclKind>
 static auto HandlePackageAndLibraryDecls(Context& context,
-                                         Lex::TokenKind intro_token_kind,
-                                         NodeKind declaration) -> void {
+                                         Lex::TokenKind intro_token_kind)
+    -> void {
   auto state = context.PopState();
 
   bool is_impl = HasModifier(context, state, Lex::TokenKind::Impl);
 
-  auto on_parse_error = [&] { OnParseError(context, state, declaration); };
+  auto on_parse_error = [&] { OnParseError(context, state, DeclKind); };
 
   if (state.token != Lex::TokenIndex::FirstNonCommentToken) {
     CARBON_DIAGNOSTIC(
@@ -241,18 +242,18 @@ static auto HandlePackageAndLibraryDecls(Context& context,
   // `package`/`library` is no longer allowed, but `import` may repeat.
   context.set_packaging_state(Context::PackagingState::InImports);
 
-  HandleDeclContent(context, state, declaration, /*is_export=*/false, is_impl,
-                    on_parse_error);
+  HandleDeclContent<DeclKind>(context, state, /*is_export=*/false, is_impl,
+                              on_parse_error);
 }
 
 auto HandlePackage(Context& context) -> void {
-  HandlePackageAndLibraryDecls(context, Lex::TokenKind::Package,
-                               NodeKind::PackageDecl);
+  HandlePackageAndLibraryDecls<NodeKind::PackageDecl>(context,
+                                                      Lex::TokenKind::Package);
 }
 
 auto HandleLibrary(Context& context) -> void {
-  HandlePackageAndLibraryDecls(context, Lex::TokenKind::Library,
-                               NodeKind::LibraryDecl);
+  HandlePackageAndLibraryDecls<NodeKind::LibraryDecl>(context,
+                                                      Lex::TokenKind::Library);
 }
 
 }  // namespace Carbon::Parse