Просмотр исходного кода

Use typed parse node ids in SemIR instruction types (#3560)

This involves a number of supporting changes:
* The `parse_node;` member of instruction types may now have any type
derived from `Parse::NodeId` and is no longer required to have that
exact type.
* `Parse::Node::Invalid` is now a singleton object of a separate type
that is convertible to `Parse::NodeId` and its descendants. This
replaces the `Invalid` member of its descendants, and avoids having to
write long `NodeIdOneOf<...>` types when initializing variables to
invalid.
* `IndexBase` now allows `==` and `!=` comparisons between its derived
classes and types that are convertible to those types.
* A number of functions in the check stage have been changed to preserve
more type information instead of using `Parse::NodeId`.
* `NodeIdForKind<K>` (also known as `KId`) now has a `Kind` member so it
may be used to declare `NodeIdOneOf<T, U>` types without #including
`parse/typed_nodes.h`.
* `NodeIdForKind<K>` (also known as `KId`) may be implicitly converted
to `NodeIdOneOf<T, U>` if `T::Kind == K` or `U::Kind == K` (executing a
TODO).

Many of the `parse_node` members were not converted since they would
have required more extensive changes. They have been marked with "TODO"
comments.

---------

Co-authored-by: Richard Smith <richard@metafoo.co.uk>
josh11b 2 лет назад
Родитель
Сommit
b0da52a3d7

+ 34 - 0
toolchain/base/index_base.h

@@ -66,6 +66,40 @@ auto operator!=(IndexType lhs, IndexType rhs) -> bool {
   return lhs.index != rhs.index;
 }
 
+template <
+    typename IndexType, typename RHSType,
+    typename std::enable_if_t<std::is_base_of_v<IdBase, IndexType>>* = nullptr,
+    typename std::enable_if_t<std::is_convertible_v<RHSType, IndexType>>* =
+        nullptr>
+auto operator==(IndexType lhs, RHSType rhs) -> bool {
+  return lhs.index == IndexType(rhs).index;
+}
+template <
+    typename IndexType, typename RHSType,
+    typename std::enable_if_t<std::is_base_of_v<IdBase, IndexType>>* = nullptr,
+    typename std::enable_if_t<std::is_convertible_v<RHSType, IndexType>>* =
+        nullptr>
+auto operator!=(IndexType lhs, RHSType rhs) -> bool {
+  return lhs.index != IndexType(rhs).index;
+}
+
+template <
+    typename LHSType, typename IndexType,
+    typename std::enable_if_t<std::is_base_of_v<IdBase, IndexType>>* = nullptr,
+    typename std::enable_if_t<std::is_convertible_v<LHSType, IndexType>>* =
+        nullptr>
+auto operator==(LHSType lhs, IndexType rhs) -> bool {
+  return IndexType(lhs).index == rhs.index;
+}
+template <
+    typename LHSType, typename IndexType,
+    typename std::enable_if_t<std::is_base_of_v<IdBase, IndexType>>* = nullptr,
+    typename std::enable_if_t<std::is_convertible_v<LHSType, IndexType>>* =
+        nullptr>
+auto operator!=(LHSType lhs, IndexType rhs) -> bool {
+  return IndexType(lhs).index != rhs.index;
+}
+
 // The < and > comparisons for only IndexBase.
 template <typename IndexType,
           typename std::enable_if_t<std::is_base_of_v<IndexBase, IndexType>>* =

+ 4 - 2
toolchain/check/handle_class.cpp

@@ -21,8 +21,10 @@ auto HandleClassIntroducer(Context& context,
   return true;
 }
 
-static auto BuildClassDecl(Context& context, Parse::NodeId parse_node)
-    -> std::tuple<SemIR::ClassId, SemIR::InstId> {
+static auto BuildClassDecl(
+    Context& context,
+    Parse::NodeIdOneOf<Parse::ClassDeclId, Parse::ClassDefinitionStartId>
+        parse_node) -> std::tuple<SemIR::ClassId, SemIR::InstId> {
   auto name_context = context.decl_name_stack().FinishName();
   context.node_stack()
       .PopAndDiscardSoloParseNode<Parse::NodeKind::ClassIntroducer>();

+ 5 - 3
toolchain/check/handle_function.cpp

@@ -43,9 +43,11 @@ static auto DiagnoseModifiers(Context& context) -> KeywordModifierSet {
 // Build a FunctionDecl describing the signature of a function. This
 // handles the common logic shared by function declaration syntax and function
 // definition syntax.
-static auto BuildFunctionDecl(Context& context, Parse::NodeId parse_node,
-                              bool is_definition)
-    -> std::pair<SemIR::FunctionId, SemIR::InstId> {
+static auto BuildFunctionDecl(
+    Context& context,
+    Parse::NodeIdOneOf<Parse::FunctionDeclId, Parse::FunctionDefinitionStartId>
+        parse_node,
+    bool is_definition) -> std::pair<SemIR::FunctionId, SemIR::InstId> {
   // TODO: This contains the IR block for the parameters and return type. At
   // present, it's just loose, but it's not strictly required for parameter
   // refs; we should either stop constructing it completely or, if it turns out

+ 4 - 1
toolchain/check/handle_interface.cpp

@@ -21,7 +21,10 @@ auto HandleInterfaceIntroducer(Context& context,
   return true;
 }
 
-static auto BuildInterfaceDecl(Context& context, Parse::NodeId parse_node)
+static auto BuildInterfaceDecl(
+    Context& context, Parse::NodeIdOneOf<Parse::InterfaceDeclId,
+                                         Parse::InterfaceDefinitionStartId>
+                          parse_node)
     -> std::tuple<SemIR::InterfaceId, SemIR::InstId> {
   auto name_context = context.decl_name_stack().FinishName();
   context.node_stack()

+ 12 - 7
toolchain/check/node_stack.h

@@ -144,40 +144,45 @@ class NodeStack {
   template <const Parse::NodeKind& RequiredParseKind>
   auto PopWithParseNode() -> auto {
     constexpr IdKind RequiredIdKind = ParseNodeKindToIdKind(RequiredParseKind);
+    auto NodeIdCast = [&](auto back) {
+      using NodeIdT = Parse::NodeIdForKind<RequiredParseKind>;
+      return std::pair<NodeIdT, decltype(back.second)>(back);
+    };
+
     if constexpr (RequiredIdKind == IdKind::InstId) {
       auto back = PopWithParseNode<SemIR::InstId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     if constexpr (RequiredIdKind == IdKind::InstBlockId) {
       auto back = PopWithParseNode<SemIR::InstBlockId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     if constexpr (RequiredIdKind == IdKind::FunctionId) {
       auto back = PopWithParseNode<SemIR::FunctionId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     if constexpr (RequiredIdKind == IdKind::ClassId) {
       auto back = PopWithParseNode<SemIR::ClassId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     if constexpr (RequiredIdKind == IdKind::InterfaceId) {
       auto back = PopWithParseNode<SemIR::InterfaceId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     if constexpr (RequiredIdKind == IdKind::NameId) {
       auto back = PopWithParseNode<SemIR::NameId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     if constexpr (RequiredIdKind == IdKind::TypeId) {
       auto back = PopWithParseNode<SemIR::TypeId>();
       RequireParseKind<RequiredParseKind>(back.first);
-      return back;
+      return NodeIdCast(back);
     }
     CARBON_FATAL() << "Unpoppable IdKind for parse kind: " << RequiredParseKind
                    << "; see value in ParseNodeKindToIdKind";

+ 5 - 3
toolchain/check/return.cpp

@@ -108,7 +108,8 @@ auto RegisterReturnedVar(Context& context, SemIR::InstId bind_id) -> void {
   }
 }
 
-auto BuildReturnWithNoExpr(Context& context, Parse::NodeId parse_node) -> void {
+auto BuildReturnWithNoExpr(Context& context,
+                           Parse::ReturnStatementId parse_node) -> void {
   const auto& function = GetCurrentFunction(context);
 
   if (function.return_type_id.is_valid()) {
@@ -122,7 +123,7 @@ auto BuildReturnWithNoExpr(Context& context, Parse::NodeId parse_node) -> void {
   context.AddInst(SemIR::Return{parse_node});
 }
 
-auto BuildReturnWithExpr(Context& context, Parse::NodeId parse_node,
+auto BuildReturnWithExpr(Context& context, Parse::ReturnStatementId parse_node,
                          SemIR::InstId expr_id) -> void {
   const auto& function = GetCurrentFunction(context);
   auto returned_var_id = GetCurrentReturnedVar(context);
@@ -154,7 +155,8 @@ auto BuildReturnWithExpr(Context& context, Parse::NodeId parse_node,
   context.AddInst(SemIR::ReturnExpr{parse_node, expr_id});
 }
 
-auto BuildReturnVar(Context& context, Parse::NodeId parse_node) -> void {
+auto BuildReturnVar(Context& context, Parse::ReturnStatementId parse_node)
+    -> void {
   const auto& function = GetCurrentFunction(context);
   auto returned_var_id = GetCurrentReturnedVar(context);
 

+ 5 - 3
toolchain/check/return.h

@@ -22,14 +22,16 @@ auto CheckReturnedVar(Context& context, Parse::NodeId returned_node,
 auto RegisterReturnedVar(Context& context, SemIR::InstId bind_id) -> void;
 
 // Checks and builds SemIR for a `return;` statement.
-auto BuildReturnWithNoExpr(Context& context, Parse::NodeId parse_node) -> void;
+auto BuildReturnWithNoExpr(Context& context,
+                           Parse::ReturnStatementId parse_node) -> void;
 
 // Checks and builds SemIR for a `return <expression>;` statement.
-auto BuildReturnWithExpr(Context& context, Parse::NodeId parse_node,
+auto BuildReturnWithExpr(Context& context, Parse::ReturnStatementId parse_node,
                          SemIR::InstId expr_id) -> void;
 
 // Checks and builds SemIR for a `return var;` statement.
-auto BuildReturnVar(Context& context, Parse::NodeId parse_node) -> void;
+auto BuildReturnVar(Context& context, Parse::ReturnStatementId parse_node)
+    -> void;
 
 }  // namespace Carbon::Check
 

+ 21 - 37
toolchain/parse/node_ids.h

@@ -10,6 +10,9 @@
 
 namespace Carbon::Parse {
 
+// Represents an invalid node id of any type
+struct InvalidNodeId {};
+
 // A lightweight handle representing a node in the tree.
 //
 // Objects of this type are small and cheap to copy and store. They don't
@@ -17,28 +20,26 @@ namespace Carbon::Parse {
 // can be used with the underlying tree to query for detailed information.
 struct NodeId : public IdBase {
   // An explicitly invalid instance.
-  static const NodeId Invalid;
+  static constexpr InvalidNodeId Invalid;
 
   using IdBase::IdBase;
+  constexpr NodeId(InvalidNodeId) : IdBase(NodeId::InvalidIndex) {}
 };
 
-constexpr NodeId NodeId::Invalid = NodeId(NodeId::InvalidIndex);
-
 // For looking up the type associated with a given id type.
 template <typename T>
 struct NodeForId;
 
 // `<KindName>Id` is a typed version of `NodeId` that references a node of kind
 // `<KindName>`:
-template <const NodeKind&>
+template <const NodeKind& K>
 struct NodeIdForKind : public NodeId {
-  static const NodeIdForKind Invalid;
-
-  explicit NodeIdForKind(NodeId node_id) : NodeId(node_id) {}
+  static const NodeKind& Kind;
+  constexpr explicit NodeIdForKind(NodeId node_id) : NodeId(node_id) {}
+  constexpr NodeIdForKind(InvalidNodeId) : NodeId(NodeId::InvalidIndex) {}
 };
-template <const NodeKind& Kind>
-constexpr NodeIdForKind<Kind> NodeIdForKind<Kind>::Invalid =
-    NodeIdForKind(NodeId::Invalid.index);
+template <const NodeKind& K>
+const NodeKind& NodeIdForKind<K>::Kind = K;
 
 #define CARBON_PARSE_NODE_KIND(KindName) \
   using KindName##Id = NodeIdForKind<NodeKind::KindName>;
@@ -47,19 +48,13 @@ constexpr NodeIdForKind<Kind> NodeIdForKind<Kind>::Invalid =
 // NodeId that matches any NodeKind whose `category()` overlaps with `Category`.
 template <NodeCategory Category>
 struct NodeIdInCategory : public NodeId {
-  // An explicitly invalid instance.
-  static const NodeIdInCategory<Category> Invalid;
-
   // TODO: Support conversion from `NodeIdForKind<Kind>` if `Kind::category()`
   // overlaps with `Category`.
 
-  explicit NodeIdInCategory(NodeId node_id) : NodeId(node_id) {}
+  constexpr explicit NodeIdInCategory(NodeId node_id) : NodeId(node_id) {}
+  constexpr NodeIdInCategory(InvalidNodeId) : NodeId(NodeId::InvalidIndex) {}
 };
 
-template <NodeCategory Category>
-constexpr NodeIdInCategory<Category> NodeIdInCategory<Category>::Invalid =
-    NodeIdInCategory<Category>(NodeId::InvalidIndex);
-
 // Aliases for `NodeIdInCategory` to describe particular categories of nodes.
 using AnyDeclId = NodeIdInCategory<NodeCategory::Decl>;
 using AnyExprId = NodeIdInCategory<NodeCategory::Expr>;
@@ -72,32 +67,21 @@ using AnyStatementId = NodeIdInCategory<NodeCategory::Statement>;
 // NodeId with kind that matches either T::Kind or U::Kind.
 template <typename T, typename U>
 struct NodeIdOneOf : public NodeId {
-  // An explicitly invalid instance.
-  static const NodeIdOneOf<T, U> Invalid;
-
-  // TODO: Support conversion from `NodeIdForKind<Kind>` if `Kind` is
-  // `T::Kind` or `U::Kind`.
-
-  explicit NodeIdOneOf(NodeId node_id) : NodeId(node_id) {}
+  constexpr explicit NodeIdOneOf(NodeId node_id) : NodeId(node_id) {}
+  template <const NodeKind& Kind>
+  NodeIdOneOf(NodeIdForKind<Kind> node_id) : NodeId(node_id) {
+    static_assert(T::Kind == Kind || U::Kind == Kind);
+  }
+  constexpr NodeIdOneOf(InvalidNodeId) : NodeId(NodeId::InvalidIndex) {}
 };
 
-template <typename T, typename U>
-constexpr NodeIdOneOf<T, U> NodeIdOneOf<T, U>::Invalid =
-    NodeIdOneOf<T, U>(NodeId::InvalidIndex);
-
 // NodeId with kind that is anything but T::Kind.
 template <typename T>
 struct NodeIdNot : public NodeId {
-  // An explicitly invalid instance.
-  static const NodeIdNot<T> Invalid;
-
-  explicit NodeIdNot(NodeId node_id) : NodeId(node_id) {}
+  constexpr explicit NodeIdNot(NodeId node_id) : NodeId(node_id) {}
+  constexpr NodeIdNot(InvalidNodeId) : NodeId(NodeId::InvalidIndex) {}
 };
 
-template <typename T>
-constexpr NodeIdNot<T> NodeIdNot<T>::Invalid =
-    NodeIdNot<T>(NodeId::InvalidIndex);
-
 // Note that the support for extracting these types using the `Tree::Extract*`
 // functions is defined in `extract.cpp`.
 

+ 3 - 0
toolchain/parse/tree_node_location_translator.h

@@ -16,6 +16,9 @@ class NodeLocation {
   NodeLocation(NodeId node_id) : NodeLocation(node_id, false) {}
   NodeLocation(NodeId node_id, bool token_only)
       : node_id_(node_id), token_only_(token_only) {}
+  // TODO: Have some other way of representing diagnostic that applies to a file
+  // as a whole.
+  NodeLocation(InvalidNodeId node_id) : NodeLocation(node_id, false) {}
 
   auto node_id() const -> NodeId { return node_id_; }
   auto token_only() const -> bool { return token_only_; }

+ 1 - 0
toolchain/sem_ir/BUILD

@@ -36,6 +36,7 @@ cc_library(
     textual_hdrs = ["inst_kind.def"],
     deps = [
         "//common:enum_base",
+        "//toolchain/parse:node_kind",
         "//toolchain/parse:tree",
         "//toolchain/sem_ir:builtin_kind",
         "//toolchain/sem_ir:ids",

+ 2 - 1
toolchain/sem_ir/inst.h

@@ -101,7 +101,8 @@ class Inst : public Printable<Inst> {
                                   << " to wrong kind " << TypedInst::Kind;
     auto build_with_type_id_and_args = [&](auto... type_id_and_args) {
       if constexpr (HasParseNodeMember<TypedInst>) {
-        return TypedInst{parse_node(), type_id_and_args...};
+        return TypedInst{decltype(TypedInst::parse_node)(parse_node()),
+                         type_id_and_args...};
       } else {
         return TypedInst{type_id_and_args...};
       }

+ 71 - 25
toolchain/sem_ir/typed_insts.h

@@ -5,6 +5,7 @@
 #ifndef CARBON_TOOLCHAIN_SEM_IR_TYPED_INSTS_H_
 #define CARBON_TOOLCHAIN_SEM_IR_TYPED_INSTS_H_
 
+#include "toolchain/parse/node_ids.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/sem_ir/builtin_kind.h"
 #include "toolchain/sem_ir/ids.h"
@@ -40,6 +41,7 @@ namespace Carbon::SemIR {
 struct AddressOf {
   static constexpr auto Kind = InstKind::AddressOf.Define("address_of");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId lvalue_id;
@@ -48,7 +50,7 @@ struct AddressOf {
 struct AddrPattern {
   static constexpr auto Kind = InstKind::AddrPattern.Define("addr_pattern");
 
-  Parse::NodeId parse_node;
+  Parse::AddressId parse_node;
   TypeId type_id;
   // The `self` parameter.
   InstId inner_id;
@@ -57,6 +59,7 @@ struct AddrPattern {
 struct ArrayIndex {
   static constexpr auto Kind = InstKind::ArrayIndex.Define("array_index");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId array_id;
@@ -69,6 +72,7 @@ struct ArrayIndex {
 struct ArrayInit {
   static constexpr auto Kind = InstKind::ArrayInit.Define("array_init");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId inits_id;
@@ -78,7 +82,7 @@ struct ArrayInit {
 struct ArrayType {
   static constexpr auto Kind = InstKind::ArrayType.Define("array_type");
 
-  Parse::NodeId parse_node;
+  Parse::ArrayExprId parse_node;
   TypeId type_id;
   InstId bound_id;
   TypeId element_type_id;
@@ -90,7 +94,8 @@ struct ArrayType {
 struct Assign {
   static constexpr auto Kind = InstKind::Assign.Define("assign");
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::InfixOperatorEqualId, Parse::VariableDeclId>
+      parse_node;
   // Assignments are statements, and so have no type.
   InstId lhs_id;
   InstId rhs_id;
@@ -102,7 +107,7 @@ struct Assign {
 struct BaseDecl {
   static constexpr auto Kind = InstKind::BaseDecl.Define("base_decl");
 
-  Parse::NodeId parse_node;
+  Parse::BaseDeclId parse_node;
   TypeId type_id;
   TypeId base_type_id;
   ElementIndex index;
@@ -111,6 +116,7 @@ struct BaseDecl {
 struct BindName {
   static constexpr auto Kind = InstKind::BindName.Define("bind_name");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   NameId name_id;
@@ -120,6 +126,7 @@ struct BindName {
 struct BindValue {
   static constexpr auto Kind = InstKind::BindValue.Define("bind_value");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId value_id;
@@ -128,6 +135,7 @@ struct BindValue {
 struct BlockArg {
   static constexpr auto Kind = InstKind::BlockArg.Define("block_arg");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId block_id;
@@ -136,6 +144,7 @@ struct BlockArg {
 struct BoolLiteral {
   static constexpr auto Kind = InstKind::BoolLiteral.Define("bool_literal");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   BoolValue value;
@@ -146,7 +155,7 @@ struct BoolLiteral {
 struct BoundMethod {
   static constexpr auto Kind = InstKind::BoundMethod.Define("bound_method");
 
-  Parse::NodeId parse_node;
+  Parse::MemberAccessExprId parse_node;
   TypeId type_id;
   // The object argument in the bound method, which will be used to initialize
   // `self`, or whose address will be used to initialize `self` for an `addr
@@ -159,6 +168,7 @@ struct Branch {
   static constexpr auto Kind =
       InstKind::Branch.Define("br", TerminatorKind::Terminator);
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   // Branches don't produce a value, so have no type.
   InstBlockId target_id;
@@ -168,6 +178,7 @@ struct BranchIf {
   static constexpr auto Kind =
       InstKind::BranchIf.Define("br", TerminatorKind::TerminatorSequence);
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   // Branches don't produce a value, so have no type.
   InstBlockId target_id;
@@ -178,6 +189,7 @@ struct BranchWithArg {
   static constexpr auto Kind =
       InstKind::BranchWithArg.Define("br", TerminatorKind::Terminator);
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   // Branches don't produce a value, so have no type.
   InstBlockId target_id;
@@ -195,7 +207,7 @@ struct Builtin {
 struct Call {
   static constexpr auto Kind = InstKind::Call.Define("call");
 
-  Parse::NodeId parse_node;
+  Parse::CallExprStartId parse_node;
   TypeId type_id;
   InstId callee_id;
   // The arguments block contains IDs for the following arguments, in order:
@@ -208,7 +220,8 @@ struct Call {
 struct ClassDecl {
   static constexpr auto Kind = InstKind::ClassDecl.Define("class_decl");
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::ClassDeclId, Parse::ClassDefinitionStartId>
+      parse_node;
   // No type: a class declaration is not itself a value. The name of a class
   // declaration becomes a class type value.
   // TODO: For a generic class declaration, the name of the class declaration
@@ -223,6 +236,7 @@ struct ClassElementAccess {
   static constexpr auto Kind =
       InstKind::ClassElementAccess.Define("class_element_access");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId base_id;
@@ -232,6 +246,7 @@ struct ClassElementAccess {
 struct ClassInit {
   static constexpr auto Kind = InstKind::ClassInit.Define("class_init");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
@@ -241,7 +256,8 @@ struct ClassInit {
 struct ClassType {
   static constexpr auto Kind = InstKind::ClassType.Define("class_type");
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::ClassDeclId, Parse::ClassDefinitionStartId>
+      parse_node;
   TypeId type_id;
   ClassId class_id;
   // TODO: Once we support generic classes, include the class's arguments here.
@@ -250,7 +266,7 @@ struct ClassType {
 struct ConstType {
   static constexpr auto Kind = InstKind::ConstType.Define("const_type");
 
-  Parse::NodeId parse_node;
+  Parse::PrefixOperatorConstId parse_node;
   TypeId type_id;
   TypeId inner_id;
 };
@@ -258,6 +274,7 @@ struct ConstType {
 struct Converted {
   static constexpr auto Kind = InstKind::Converted.Define("converted");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId original_id;
@@ -279,6 +296,7 @@ struct CrossRef {
 struct Deref {
   static constexpr auto Kind = InstKind::Deref.Define("deref");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId pointer_id;
@@ -289,7 +307,7 @@ struct Deref {
 struct FieldDecl {
   static constexpr auto Kind = InstKind::FieldDecl.Define("field_decl");
 
-  Parse::NodeId parse_node;
+  Parse::BindingPatternId parse_node;
   TypeId type_id;
   NameId name_id;
   ElementIndex index;
@@ -298,7 +316,8 @@ struct FieldDecl {
 struct FunctionDecl {
   static constexpr auto Kind = InstKind::FunctionDecl.Define("fn_decl");
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::FunctionDeclId, Parse::FunctionDefinitionStartId>
+      parse_node;
   TypeId type_id;
   FunctionId function_id;
 };
@@ -310,6 +329,7 @@ struct FunctionDecl {
 struct Import {
   static constexpr auto Kind = InstKind::Import.Define("import");
 
+  // TODO: Should always be an ImportDirectiveId?
   Parse::NodeId parse_node;
   TypeId type_id;
   CrossRefIRId first_cross_ref_ir_id;
@@ -323,6 +343,7 @@ struct InitializeFrom {
   static constexpr auto Kind =
       InstKind::InitializeFrom.Define("initialize_from");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId src_id;
@@ -332,7 +353,8 @@ struct InitializeFrom {
 struct InterfaceDecl {
   static constexpr auto Kind = InstKind::InterfaceDecl.Define("interface_decl");
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::InterfaceDeclId, Parse::InterfaceDefinitionStartId>
+      parse_node;
   // No type: an interface declaration is not itself a value. The name of an
   // interface declaration becomes a facet type value.
   // TODO: For a generic interface declaration, the name of the interface
@@ -346,6 +368,7 @@ struct InterfaceDecl {
 struct IntLiteral {
   static constexpr auto Kind = InstKind::IntLiteral.Define("int_literal");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   IntId int_id;
@@ -368,6 +391,7 @@ struct LazyImportRef {
 struct NameRef {
   static constexpr auto Kind = InstKind::NameRef.Define("name_ref");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   NameId name_id;
@@ -377,7 +401,7 @@ struct NameRef {
 struct Namespace {
   static constexpr auto Kind = InstKind::Namespace.Define("namespace");
 
-  Parse::NodeId parse_node;
+  Parse::NamespaceId parse_node;
   TypeId type_id;
   NameScopeId name_scope_id;
 };
@@ -385,6 +409,7 @@ struct Namespace {
 struct NoOp {
   static constexpr auto Kind = InstKind::NoOp.Define("no_op");
 
+  // TODO: Delete since now unused.
   Parse::NodeId parse_node;
   // This instruction doesn't produce a value, so has no type.
 };
@@ -392,6 +417,7 @@ struct NoOp {
 struct Param {
   static constexpr auto Kind = InstKind::Param.Define("param");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   NameId name_id;
@@ -400,6 +426,7 @@ struct Param {
 struct PointerType {
   static constexpr auto Kind = InstKind::PointerType.Define("ptr_type");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   TypeId pointee_id;
@@ -408,7 +435,7 @@ struct PointerType {
 struct RealLiteral {
   static constexpr auto Kind = InstKind::RealLiteral.Define("real_literal");
 
-  Parse::NodeId parse_node;
+  Parse::RealLiteralId parse_node;
   TypeId type_id;
   RealId real_id;
 };
@@ -417,7 +444,8 @@ struct Return {
   static constexpr auto Kind =
       InstKind::Return.Define("return", TerminatorKind::Terminator);
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::FunctionDefinitionId, Parse::ReturnStatementId>
+      parse_node;
   // This is a statement, so has no type.
 };
 
@@ -425,7 +453,7 @@ struct ReturnExpr {
   static constexpr auto Kind =
       InstKind::ReturnExpr.Define("return", TerminatorKind::Terminator);
 
-  Parse::NodeId parse_node;
+  Parse::ReturnStatementId parse_node;
   // This is a statement, so has no type.
   InstId expr_id;
 };
@@ -433,6 +461,7 @@ struct ReturnExpr {
 struct SpliceBlock {
   static constexpr auto Kind = InstKind::SpliceBlock.Define("splice_block");
 
+  // TODO: Can we make this more specific?
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId block_id;
@@ -442,7 +471,7 @@ struct SpliceBlock {
 struct StringLiteral {
   static constexpr auto Kind = InstKind::StringLiteral.Define("string_literal");
 
-  Parse::NodeId parse_node;
+  Parse::StringLiteralId parse_node;
   TypeId type_id;
   StringLiteralValueId string_literal_id;
 };
@@ -450,6 +479,7 @@ struct StringLiteral {
 struct StructAccess {
   static constexpr auto Kind = InstKind::StructAccess.Define("struct_access");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId struct_id;
@@ -459,6 +489,7 @@ struct StructAccess {
 struct StructInit {
   static constexpr auto Kind = InstKind::StructInit.Define("struct_init");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
@@ -468,7 +499,7 @@ struct StructInit {
 struct StructLiteral {
   static constexpr auto Kind = InstKind::StructLiteral.Define("struct_literal");
 
-  Parse::NodeId parse_node;
+  Parse::StructLiteralId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
 };
@@ -476,6 +507,8 @@ struct StructLiteral {
 struct StructType {
   static constexpr auto Kind = InstKind::StructType.Define("struct_type");
 
+  // TODO: Make this more specific. It can be one of: ClassDefinitionId,
+  // StructLiteralId, StructTypeLiteralId
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId fields_id;
@@ -485,6 +518,7 @@ struct StructTypeField {
   static constexpr auto Kind =
       InstKind::StructTypeField.Define("struct_type_field");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   // This instruction is an implementation detail of `StructType`, and doesn't
   // produce a value, so has no type, even though it declares a field with a
@@ -496,6 +530,7 @@ struct StructTypeField {
 struct StructValue {
   static constexpr auto Kind = InstKind::StructValue.Define("struct_value");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
@@ -504,6 +539,7 @@ struct StructValue {
 struct Temporary {
   static constexpr auto Kind = InstKind::Temporary.Define("temporary");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId storage_id;
@@ -514,6 +550,7 @@ struct TemporaryStorage {
   static constexpr auto Kind =
       InstKind::TemporaryStorage.Define("temporary_storage");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
 };
@@ -521,6 +558,7 @@ struct TemporaryStorage {
 struct TupleAccess {
   static constexpr auto Kind = InstKind::TupleAccess.Define("tuple_access");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId tuple_id;
@@ -530,7 +568,7 @@ struct TupleAccess {
 struct TupleIndex {
   static constexpr auto Kind = InstKind::TupleIndex.Define("tuple_index");
 
-  Parse::NodeId parse_node;
+  Parse::IndexExprId parse_node;
   TypeId type_id;
   InstId tuple_id;
   InstId index_id;
@@ -539,6 +577,7 @@ struct TupleIndex {
 struct TupleInit {
   static constexpr auto Kind = InstKind::TupleInit.Define("tuple_init");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
@@ -548,7 +587,7 @@ struct TupleInit {
 struct TupleLiteral {
   static constexpr auto Kind = InstKind::TupleLiteral.Define("tuple_literal");
 
-  Parse::NodeId parse_node;
+  Parse::TupleLiteralId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
 };
@@ -556,6 +595,7 @@ struct TupleLiteral {
 struct TupleType {
   static constexpr auto Kind = InstKind::TupleType.Define("tuple_type");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   TypeBlockId elements_id;
@@ -564,6 +604,7 @@ struct TupleType {
 struct TupleValue {
   static constexpr auto Kind = InstKind::TupleValue.Define("tuple_value");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstBlockId elements_id;
@@ -572,6 +613,7 @@ struct TupleValue {
 struct UnaryOperatorNot {
   static constexpr auto Kind = InstKind::UnaryOperatorNot.Define("not");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId operand_id;
@@ -584,7 +626,7 @@ struct UnboundElementType {
   static constexpr auto Kind =
       InstKind::UnboundElementType.Define("unbound_element_type");
 
-  Parse::NodeId parse_node;
+  Parse::NodeIdOneOf<Parse::BaseDeclId, Parse::BindingPatternId> parse_node;
   TypeId type_id;
   // The class that a value of this type is an element of.
   TypeId class_type_id;
@@ -595,7 +637,7 @@ struct UnboundElementType {
 struct ValueAsRef {
   static constexpr auto Kind = InstKind::ValueAsRef.Define("value_as_ref");
 
-  Parse::NodeId parse_node;
+  Parse::IndexExprId parse_node;
   TypeId type_id;
   InstId value_id;
 };
@@ -604,6 +646,7 @@ struct ValueOfInitializer {
   static constexpr auto Kind =
       InstKind::ValueOfInitializer.Define("value_of_initializer");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   InstId init_id;
@@ -612,16 +655,19 @@ struct ValueOfInitializer {
 struct VarStorage {
   static constexpr auto Kind = InstKind::VarStorage.Define("var");
 
+  // TODO: Make this more specific.
   Parse::NodeId parse_node;
   TypeId type_id;
   NameId name_id;
 };
 
-// HasParseNodeMember<T> is true if T has a `Parse::NodeId parse_node` field.
-template <typename T, typename ParseNodeType = Parse::NodeId T::*>
+// HasParseNodeMember<T> is true if T has a `U parse_node` field,
+// where `U` extends `Parse::NodeId`.
+template <typename T, bool Enabled = true>
 inline constexpr bool HasParseNodeMember = false;
 template <typename T>
-inline constexpr bool HasParseNodeMember<T, decltype(&T::parse_node)> = true;
+inline constexpr bool HasParseNodeMember<
+    T, bool(std::is_base_of_v<Parse::NodeId, decltype(T::parse_node)>)> = true;
 
 // HasTypeIdMember<T> is true if T has a `TypeId type_id` field.
 template <typename T, typename TypeIdType = TypeId T::*>

+ 2 - 1
toolchain/sem_ir/typed_insts_test.cpp

@@ -44,7 +44,8 @@ auto CommonFieldOrder() -> void {
   Inst inst = MakeInstWithNumberedFields(TypedInst::Kind);
   auto typed = inst.As<TypedInst>();
   if constexpr (HasParseNodeMember<TypedInst>) {
-    EXPECT_EQ(typed.parse_node, Parse::NodeId(1));
+    EXPECT_EQ(typed.parse_node,
+              decltype(TypedInst::parse_node)(Parse::NodeId(1)));
   }
   if constexpr (HasTypeIdMember<TypedInst>) {
     EXPECT_EQ(typed.type_id, TypeId(2));