Przeglądaj źródła

Change NodeIdOneOf and similar to use "requires" and explicit UnsafeMake (#5084)

This doesn't change functionality, but I was seeing better diagnostics
in VS Code.

This also changes the NodeId constructors for related types (also
NodeCategory and NodeIdForKind) to use UnsafeMake for construction. That
originated from avoiding ambiguity coming from `requires`, but the
constructor mode is also one we should typically avoid (e.g., preferring
`Parse::Tree::As`).
Jon Ross-Perkins 1 rok temu
rodzic
commit
e6872f9499

+ 5 - 4
toolchain/check/check_unit.cpp

@@ -378,10 +378,11 @@ auto CheckUnit::ProcessNodeIds() -> bool {
     bool result;
     auto parse_kind = context_.parse_tree().node_kind(node_id);
     switch (parse_kind) {
-#define CARBON_PARSE_NODE_KIND(Name)                              \
-  case Parse::NodeKind::Name: {                                   \
-    result = HandleParseNode(context_, Parse::Name##Id(node_id)); \
-    break;                                                        \
+#define CARBON_PARSE_NODE_KIND(Name)                                   \
+  case Parse::NodeKind::Name: {                                        \
+    result = HandleParseNode(                                          \
+        context_, context_.parse_tree().As<Parse::Name##Id>(node_id)); \
+    break;                                                             \
   }
 #include "toolchain/parse/node_kind.def"
     }

+ 3 - 2
toolchain/check/import.cpp

@@ -133,8 +133,9 @@ auto AddImportNamespace(Context& context, SemIR::TypeId namespace_type_id,
           ? MakeImportedLocIdAndInst(context, import_loc_id.import_ir_inst_id(),
                                      namespace_inst)
           // TODO: Check that this actually is an `AnyNamespaceId`.
-          : SemIR::LocIdAndInst(Parse::AnyNamespaceId(import_loc_id.node_id()),
-                                namespace_inst);
+          : SemIR::LocIdAndInst(
+                Parse::AnyNamespaceId::UnsafeMake(import_loc_id.node_id()),
+                namespace_inst);
   auto namespace_id =
       AddPlaceholderInstInNoBlock(context, namespace_inst_and_loc);
   context.import_ref_ids().push_back(namespace_id);

+ 9 - 14
toolchain/check/node_stack.h

@@ -140,8 +140,8 @@ class NodeStack {
   auto PopForSoloNodeId() -> Parse::NodeIdForKind<RequiredParseKind> {
     Entry back = PopEntry<SemIR::InstId>();
     RequireIdKind(RequiredParseKind, Id::Kind::None);
-    RequireParseKind<RequiredParseKind>(back.node_id);
-    return Parse::NodeIdForKind<RequiredParseKind>(back.node_id);
+    return parse_tree_->As<Parse::NodeIdForKind<RequiredParseKind>>(
+        back.node_id);
   }
 
   // Pops the top of the stack if it is the given kind, and returns the
@@ -192,7 +192,7 @@ class NodeStack {
   template <const Parse::NodeKind& RequiredParseKind>
   auto PopWithNodeId() -> auto {
     auto id = Peek<RequiredParseKind>();
-    Parse::NodeIdForKind<RequiredParseKind> node_id(
+    auto node_id = parse_tree_->As<Parse::NodeIdForKind<RequiredParseKind>>(
         stack_.pop_back_val().node_id);
     return std::make_pair(node_id, id);
   }
@@ -201,8 +201,9 @@ class NodeStack {
   template <Parse::NodeCategory::RawEnumType RequiredParseCategory>
   auto PopWithNodeId() -> auto {
     auto id = Peek<RequiredParseCategory>();
-    Parse::NodeIdInCategory<RequiredParseCategory> node_id(
-        stack_.pop_back_val().node_id);
+    auto node_id =
+        parse_tree_->As<Parse::NodeIdInCategory<RequiredParseCategory>>(
+            stack_.pop_back_val().node_id);
     return std::make_pair(node_id, id);
   }
 
@@ -302,7 +303,9 @@ class NodeStack {
   template <const Parse::NodeKind& RequiredParseKind>
   auto Peek() const -> auto {
     Entry back = stack_.back();
-    RequireParseKind<RequiredParseKind>(back.node_id);
+    CARBON_CHECK(RequiredParseKind == parse_tree_->node_kind(back.node_id),
+                 "Expected {0}, found {1}", RequiredParseKind,
+                 parse_tree_->node_kind(back.node_id));
     constexpr Id::Kind RequiredIdKind = NodeKindToIdKind(RequiredParseKind);
     return Peek<RequiredIdKind>();
   }
@@ -589,14 +592,6 @@ class NodeStack {
                  SemIR::IdKind(NodeKindToIdKind(parse_kind)));
   }
 
-  // Require an entry to have the given Parse::NodeKind.
-  template <const Parse::NodeKind& RequiredParseKind>
-  auto RequireParseKind(Parse::NodeId node_id) const -> void {
-    auto actual_kind = parse_tree_->node_kind(node_id);
-    CARBON_CHECK(RequiredParseKind == actual_kind, "Expected {0}, found {1}",
-                 RequiredParseKind, actual_kind);
-  }
-
   // Require an entry to have the given Parse::NodeCategory.
   template <Parse::NodeCategory::RawEnumType RequiredParseCategory>
   auto RequireParseCategory(Parse::NodeId node_id) const -> void {

+ 3 - 3
toolchain/parse/context.cpp

@@ -449,8 +449,8 @@ auto Context::AddFunctionDefinitionStart(Lex::TokenIndex token, bool has_error)
     -> void {
   if (ParsingInDeferredDefinitionScope(*this)) {
     deferred_definition_stack_.push_back(tree_->deferred_definitions_.Add(
-        {.start_id =
-             FunctionDefinitionStartId(NodeId(tree_->node_impls_.size()))}));
+        {.start_id = FunctionDefinitionStartId::UnsafeMake(
+             NodeId(tree_->node_impls_.size()))}));
   }
 
   AddNode(NodeKind::FunctionDefinitionStart, token, has_error);
@@ -462,7 +462,7 @@ auto Context::AddFunctionDefinition(Lex::TokenIndex token, bool has_error)
     auto definition_index = deferred_definition_stack_.pop_back_val();
     auto& definition = tree_->deferred_definitions_.Get(definition_index);
     definition.definition_id =
-        FunctionDefinitionId(NodeId(tree_->node_impls_.size()));
+        FunctionDefinitionId::UnsafeMake(NodeId(tree_->node_impls_.size()));
     definition.next_definition_index =
         DeferredDefinitionIndex(tree_->deferred_definitions().size());
   }

+ 3 - 3
toolchain/parse/extract.cpp

@@ -155,7 +155,7 @@ struct Extractable<NodeIdForKind<Kind>> {
   static auto Extract(NodeExtractor& extractor)
       -> std::optional<NodeIdForKind<Kind>> {
     if (extractor.MatchesNodeIdForKind(Kind)) {
-      return NodeIdForKind<Kind>(extractor.ExtractNode());
+      return NodeIdForKind<Kind>::UnsafeMake(extractor.ExtractNode());
     } else {
       return std::nullopt;
     }
@@ -182,7 +182,7 @@ struct Extractable<NodeIdInCategory<Category>> {
   static auto Extract(NodeExtractor& extractor)
       -> std::optional<NodeIdInCategory<Category>> {
     if (extractor.MatchesNodeIdInCategory(Category)) {
-      return NodeIdInCategory<Category>(extractor.ExtractNode());
+      return NodeIdInCategory<Category>::UnsafeMake(extractor.ExtractNode());
     } else {
       return std::nullopt;
     }
@@ -227,7 +227,7 @@ struct Extractable<NodeIdOneOf<T...>> {
   static auto Extract(NodeExtractor& extractor)
       -> std::optional<NodeIdOneOf<T...>> {
     if (extractor.MatchesNodeIdOneOf({T::Kind...})) {
-      return NodeIdOneOf<T...>(extractor.ExtractNode());
+      return NodeIdOneOf<T...>::UnsafeMake(extractor.ExtractNode());
     } else {
       return std::nullopt;
     }

+ 1 - 1
toolchain/parse/handle_import_and_package.cpp

@@ -40,7 +40,7 @@ static auto HandleDeclContent(Context& context, Context::StateStackEntry state,
                               llvm::function_ref<auto()->void> on_parse_error)
     -> void {
   Tree::PackagingNames names{
-      .node_id = ImportDeclId(NodeId(state.subtree_start)),
+      .node_id = ImportDeclId::UnsafeMake(NodeId(state.subtree_start)),
       .is_export = is_export};
 
   // Parse the package name.

+ 38 - 7
toolchain/parse/node_ids.h

@@ -40,9 +40,20 @@ template <const NodeKind& K>
 struct NodeIdForKind : public NodeId {
   // NOLINTNEXTLINE(readability-identifier-naming)
   static const NodeKind& Kind;
-  constexpr explicit NodeIdForKind(NodeId node_id) : NodeId(node_id) {}
+
+  // Provide a factory function for construction from `NodeId`. This doesn't
+  // validate the type, so it's unsafe.
+  static constexpr auto UnsafeMake(NodeId node_id) -> NodeIdForKind {
+    return NodeIdForKind(node_id);
+  }
+
   // NOLINTNEXTLINE(google-explicit-constructor)
   constexpr NodeIdForKind(NoneNodeId /*none*/) : NodeId(NoneIndex) {}
+
+ private:
+  // Private to prevent accidental explicit construction from an untyped
+  // NodeId.
+  explicit constexpr NodeIdForKind(NodeId node_id) : NodeId(node_id) {}
 };
 template <const NodeKind& K>
 const NodeKind& NodeIdForKind<K>::Kind = K;
@@ -54,6 +65,12 @@ const NodeKind& NodeIdForKind<K>::Kind = K;
 // NodeId that matches any NodeKind whose `category()` overlaps with `Category`.
 template <NodeCategory::RawEnumType Category>
 struct NodeIdInCategory : public NodeId {
+  // Provide a factory function for construction from `NodeId`. This doesn't
+  // validate the type, so it's unsafe.
+  static constexpr auto UnsafeMake(NodeId node_id) -> NodeIdInCategory {
+    return NodeIdInCategory(node_id);
+  }
+
   // Support conversion from `NodeIdForKind<Kind>` if Kind's category
   // overlaps with `Category`.
   template <const NodeKind& Kind>
@@ -62,9 +79,13 @@ struct NodeIdInCategory : public NodeId {
     CARBON_CHECK(Kind.category().HasAnyOf(Category));
   }
 
-  constexpr explicit NodeIdInCategory(NodeId node_id) : NodeId(node_id) {}
   // NOLINTNEXTLINE(google-explicit-constructor)
   constexpr NodeIdInCategory(NoneNodeId /*none*/) : NodeId(NoneIndex) {}
+
+ private:
+  // Private to prevent accidental explicit construction from an untyped
+  // NodeId.
+  explicit constexpr NodeIdInCategory(NodeId node_id) : NodeId(node_id) {}
 };
 
 // Aliases for `NodeIdInCategory` to describe particular categories of nodes.
@@ -84,16 +105,26 @@ using AnyPackageNameId = NodeIdInCategory<NodeCategory::PackageName>;
 
 // NodeId with kind that matches one of the `T::Kind`s.
 template <typename... T>
+  requires(sizeof...(T) >= 2)
 struct NodeIdOneOf : public NodeId {
-  static_assert(sizeof...(T) >= 2, "Expected at least two types.");
-  constexpr explicit NodeIdOneOf(NodeId node_id) : NodeId(node_id) {}
+  // Provide a factory function for construction from `NodeId`. This doesn't
+  // validate the type, so it's unsafe.
+  static constexpr auto UnsafeMake(NodeId node_id) -> NodeIdOneOf {
+    return NodeIdOneOf(node_id);
+  }
+
   template <const NodeKind& Kind>
+    requires((T::Kind == Kind) || ...)
   // NOLINTNEXTLINE(google-explicit-constructor)
-  NodeIdOneOf(NodeIdForKind<Kind> node_id) : NodeId(node_id) {
-    static_assert(((T::Kind == Kind) || ...));
-  }
+  NodeIdOneOf(NodeIdForKind<Kind> node_id) : NodeId(node_id) {}
+
   // NOLINTNEXTLINE(google-explicit-constructor)
   constexpr NodeIdOneOf(NoneNodeId /*none*/) : NodeId(NoneIndex) {}
+
+ private:
+  // Private to prevent accidental explicit construction from an untyped
+  // NodeId.
+  explicit constexpr NodeIdOneOf(NodeId node_id) : NodeId(node_id) {}
 };
 
 using AnyClassDeclId =

+ 3 - 3
toolchain/parse/tree.h

@@ -146,7 +146,7 @@ class Tree : public Printable<Tree> {
   auto TryAs(NodeId n) const -> std::optional<T> {
     CARBON_DCHECK(n.has_value());
     if (ConvertTo<T>::AllowedFor(node_kind(n))) {
-      return T(n);
+      return T::UnsafeMake(n);
     } else {
       return std::nullopt;
     }
@@ -157,8 +157,8 @@ class Tree : public Printable<Tree> {
   template <typename T>
   auto As(NodeId n) const -> T {
     CARBON_DCHECK(n.has_value());
-    CARBON_CHECK(ConvertTo<T>::AllowedFor(node_kind(n)));
-    return T(n);
+    CARBON_DCHECK(ConvertTo<T>::AllowedFor(node_kind(n)));
+    return T::UnsafeMake(n);
   }
 
   auto packaging_decl() const -> const std::optional<PackagingDecl>& {