Переглянути джерело

Make structs to distinguish ID versus NodeStore index (#2171)

WDYT of this, to avoid raw int32_t indices? I was looking at the code again and found it hard to sort out. I think this doesn't have overhead.
Jon Ross-Perkins 3 роки тому
батько
коміт
b80e294b6c

+ 14 - 0
toolchain/semantics/node_kind.h

@@ -7,8 +7,22 @@
 
 #include <cstdint>
 
+#include "common/ostream.h"
+
 namespace Carbon::Semantics {
 
+// Type-safe storage of Node IDs.
+struct NodeId {
+  explicit NodeId(int32_t id) : id(id) {}
+
+  void Print(llvm::raw_ostream& out) const { out << "%" << id; }
+
+  // Comparison to help tests.
+  auto operator==(int32_t other) const -> bool { return id == other; }
+
+  int32_t id;
+};
+
 // Meta node information for declarations.
 enum class NodeKind {
   BinaryOperator,

+ 12 - 3
toolchain/semantics/node_ref.h

@@ -12,6 +12,15 @@
 
 namespace Carbon::Semantics {
 
+// Type-safe storage of NodeStore indices.
+struct NodeStoreIndex {
+  explicit NodeStoreIndex(int32_t index) : index(index) {}
+
+  operator int32_t() const { return index; }
+
+  int32_t index;
+};
+
 // The standard structure for nodes.
 //
 // This flyweight pattern is used so that each subtype can be stored in its own
@@ -19,7 +28,7 @@ namespace Carbon::Semantics {
 // quantities are being created.
 class NodeRef {
  public:
-  NodeRef() : NodeRef(NodeKind::Invalid, -1) {}
+  NodeRef() : NodeRef(NodeKind::Invalid, NodeStoreIndex(-1)) {}
 
   auto kind() -> NodeKind { return kind_; }
 
@@ -27,12 +36,12 @@ class NodeRef {
   template <typename... StoredNodeT>
   friend class NodeStoreBase;
 
-  NodeRef(NodeKind kind, int32_t index) : kind_(kind), index_(index) {}
+  NodeRef(NodeKind kind, NodeStoreIndex index) : kind_(kind), index_(index) {}
 
   NodeKind kind_;
 
   // The index of the named entity within its list.
-  int32_t index_;
+  NodeStoreIndex index_;
 };
 
 }  // namespace Carbon::Semantics

+ 5 - 4
toolchain/semantics/node_store.h

@@ -29,7 +29,7 @@ class NodeStoreBase {
   template <typename NodeT>
   auto Store(NodeT node) -> NodeRef {
     auto& node_store = std::get<static_cast<size_t>(NodeT::Kind)>(node_stores_);
-    int32_t index = node_store.size();
+    NodeStoreIndex index(node_store.size());
     node_store.push_back(node);
     return NodeRef(NodeT::Kind, index);
   }
@@ -38,13 +38,14 @@ class NodeStoreBase {
   // store.
   template <typename NodeT>
   auto Get(NodeRef node_ref) const -> const NodeT& {
-    CARBON_CHECK(node_ref.index_ >= 0);
+    CARBON_CHECK(node_ref.index_.index >= 0);
     CARBON_CHECK(node_ref.kind_ == NodeT::Kind)
         << "Kind mismatch: " << static_cast<int>(node_ref.kind_) << " vs "
         << static_cast<int>(NodeT::Kind);
     auto& node_store = std::get<static_cast<size_t>(NodeT::Kind)>(node_stores_);
-    CARBON_CHECK(static_cast<size_t>(node_ref.index_) < node_store.size());
-    return node_store[node_ref.index_];
+    CARBON_CHECK(static_cast<size_t>(node_ref.index_.index) <
+                 node_store.size());
+    return node_store[node_ref.index_.index];
   }
 
  private:

+ 10 - 10
toolchain/semantics/nodes/binary_operator.h

@@ -20,32 +20,32 @@ class BinaryOperator {
 
   static constexpr NodeKind Kind = NodeKind::BinaryOperator;
 
-  explicit BinaryOperator(ParseTree::Node node, int32_t id, Op op,
-                          int32_t lhs_id, int32_t rhs_id)
+  explicit BinaryOperator(ParseTree::Node node, NodeId id, Op op, NodeId lhs_id,
+                          NodeId rhs_id)
       : node_(node), id_(id), op_(op), lhs_id_(lhs_id), rhs_id_(rhs_id) {}
 
   void Print(llvm::raw_ostream& out) const {
-    out << "BinaryOperator(%" << id_ << ", ";
+    out << "BinaryOperator(" << id_ << ", ";
     switch (op_) {
       case Op::Add:
         out << "+";
         break;
     }
-    out << ", %" << lhs_id_ << ", %" << rhs_id_ << ")";
+    out << ", " << lhs_id_ << ", %" << rhs_id_ << ")";
   }
 
   auto node() const -> ParseTree::Node { return node_; }
-  auto id() const -> int32_t { return id_; }
+  auto id() const -> NodeId { return id_; }
   auto op() const -> Op { return op_; }
-  auto lhs_id() const -> int32_t { return lhs_id_; }
-  auto rhs_id() const -> int32_t { return rhs_id_; }
+  auto lhs_id() const -> NodeId { return lhs_id_; }
+  auto rhs_id() const -> NodeId { return rhs_id_; }
 
  private:
   ParseTree::Node node_;
-  int32_t id_;
+  NodeId id_;
   Op op_;
-  int32_t lhs_id_;
-  int32_t rhs_id_;
+  NodeId lhs_id_;
+  NodeId rhs_id_;
 };
 
 }  // namespace Carbon::Semantics

+ 4 - 4
toolchain/semantics/nodes/binary_operator_test_matchers.h

@@ -16,11 +16,11 @@ namespace Carbon::Testing {
 MATCHER_P4(
     BinaryOperator, id_matcher, op_matcher, lhs_id_matcher, rhs_id_matcher,
     llvm::formatv(
-        "BinaryOperator(%{0}, {1}, %{2}, %{3})",
-        ::testing::DescribeMatcher<int32_t>(id_matcher),
+        "BinaryOperator(`{0}`, `{1}`, `{2}`, `{3}`)",
+        ::testing::DescribeMatcher<Semantics::NodeId>(id_matcher),
         ::testing::DescribeMatcher<Semantics::BinaryOperator::Op>(op_matcher),
-        ::testing::DescribeMatcher<int32_t>(lhs_id_matcher),
-        ::testing::DescribeMatcher<int32_t>(rhs_id_matcher))) {
+        ::testing::DescribeMatcher<Semantics::NodeId>(lhs_id_matcher),
+        ::testing::DescribeMatcher<Semantics::NodeId>(rhs_id_matcher))) {
   const Semantics::NodeRef& node_ref = arg;
   if (auto op =
           SemanticsIRForTest::GetNode<Semantics::BinaryOperator>(node_ref)) {

+ 4 - 4
toolchain/semantics/nodes/function.h

@@ -19,7 +19,7 @@ class Function {
  public:
   static constexpr NodeKind Kind = NodeKind::Function;
 
-  Function(ParseTree::Node node, int32_t id,
+  Function(ParseTree::Node node, NodeId id,
            // llvm::SmallVector<PatternBinding, 0> params,
            // llvm::SmallVector<NodeRef, 0> return_type,
            llvm::SmallVector<NodeRef, 0> body)
@@ -31,7 +31,7 @@ class Function {
 
   void Print(llvm::raw_ostream& out,
              std::function<void(NodeRef)> print_node_ref) const {
-    out << "Function(%" << id_ << ", {";
+    out << "Function(" << id_ << ", {";
     llvm::ListSeparator sep(", ");
     for (auto& node_ref : body_) {
       out << sep;
@@ -41,7 +41,7 @@ class Function {
   }
 
   auto node() const -> ParseTree::Node { return node_; }
-  auto id() const -> int32_t { return id_; }
+  auto id() const -> NodeId { return id_; }
   // auto params() const -> llvm::ArrayRef<PatternBinding> { return params_; }
   // auto return_expr() const -> llvm::Optional<Statement> { return
   // return_expr_; }
@@ -52,7 +52,7 @@ class Function {
   ParseTree::Node node_;
 
   // The function's ID.
-  int32_t id_;
+  NodeId id_;
 
   // Regular function parameters.
   // llvm::SmallVector<PatternBinding, 0> params_;

+ 2 - 2
toolchain/semantics/nodes/function_test_matchers.h

@@ -15,8 +15,8 @@ namespace Carbon::Testing {
 
 MATCHER_P2(Function, id_matcher, body_matcher,
            llvm::formatv(
-               "Function(%{0}, {1})",
-               ::testing::DescribeMatcher<int32_t>(id_matcher),
+               "Function(`{0}`, `{1}`)",
+               ::testing::DescribeMatcher<Semantics::NodeId>(id_matcher),
                ::testing::DescribeMatcher<llvm::ArrayRef<Semantics::NodeRef>>(
                    body_matcher))) {
   const Semantics::NodeRef& node_ref = arg;

+ 4 - 4
toolchain/semantics/nodes/integer_literal.h

@@ -16,21 +16,21 @@ class IntegerLiteral {
  public:
   static constexpr NodeKind Kind = NodeKind::IntegerLiteral;
 
-  explicit IntegerLiteral(ParseTree::Node node, int32_t id,
+  explicit IntegerLiteral(ParseTree::Node node, NodeId id,
                           const llvm::APInt& value)
       : node_(node), id_(id), value_(&value) {}
 
   void Print(llvm::raw_ostream& out) const {
-    out << "IntegerLiteral(%" << id_ << ", " << *value_ << ")";
+    out << "IntegerLiteral(" << id_ << ", " << *value_ << ")";
   }
 
   auto node() const -> ParseTree::Node { return node_; }
-  auto id() const -> int32_t { return id_; }
+  auto id() const -> NodeId { return id_; }
   auto value() const -> const llvm::APInt& { return *value_; }
 
  private:
   ParseTree::Node node_;
-  int32_t id_;
+  NodeId id_;
   const llvm::APInt* value_;
 };
 

+ 2 - 2
toolchain/semantics/nodes/integer_literal_test_matchers.h

@@ -15,8 +15,8 @@ namespace Carbon::Testing {
 
 MATCHER_P2(
     IntegerLiteral, id_matcher, value_matcher,
-    llvm::formatv("IntegerLiteral(%{0}, {1})",
-                  ::testing::DescribeMatcher<int32_t>(id_matcher),
+    llvm::formatv("IntegerLiteral(`{0}`, `{1}`)",
+                  ::testing::DescribeMatcher<Semantics::NodeId>(id_matcher),
                   ::testing::DescribeMatcher<llvm::APInt>(value_matcher))) {
   const Semantics::NodeRef& node_ref = arg;
   if (auto lit =

+ 4 - 6
toolchain/semantics/nodes/return.h

@@ -19,13 +19,13 @@ class Return {
  public:
   static constexpr NodeKind Kind = NodeKind::Return;
 
-  Return(ParseTree::Node node, llvm::Optional<int32_t> target_id)
+  Return(ParseTree::Node node, llvm::Optional<NodeId> target_id)
       : node_(node), target_id_(target_id) {}
 
   void Print(llvm::raw_ostream& out) const {
     out << "Return(";
     if (target_id_) {
-      out << "%" << *target_id_;
+      out << *target_id_;
     } else {
       out << "None";
     }
@@ -33,13 +33,11 @@ class Return {
   }
 
   auto node() const -> ParseTree::Node { return node_; }
-  auto target_id() const -> const llvm::Optional<int32_t>& {
-    return target_id_;
-  }
+  auto target_id() const -> const llvm::Optional<NodeId>& { return target_id_; }
 
  private:
   ParseTree::Node node_;
-  llvm::Optional<int32_t> target_id_;
+  llvm::Optional<NodeId> target_id_;
 };
 
 }  // namespace Carbon::Semantics

+ 5 - 4
toolchain/semantics/nodes/return_test_matchers.h

@@ -13,10 +13,11 @@
 
 namespace Carbon::Testing {
 
-MATCHER_P(Return, target_id_matcher,
-          llvm::formatv("Return({0})",
-                        ::testing::DescribeMatcher<llvm::Optional<int32_t>>(
-                            target_id_matcher))) {
+MATCHER_P(
+    Return, target_id_matcher,
+    llvm::formatv("Return(`{0}`)",
+                  ::testing::DescribeMatcher<llvm::Optional<Semantics::NodeId>>(
+                      target_id_matcher))) {
   const Semantics::NodeRef& node_ref = arg;
   if (auto ret = SemanticsIRForTest::GetNode<Semantics::Return>(node_ref)) {
     return ExplainMatchResult(target_id_matcher, ret->target_id(),

+ 4 - 4
toolchain/semantics/nodes/set_name.h

@@ -17,16 +17,16 @@ class SetName {
  public:
   static constexpr NodeKind Kind = NodeKind::SetName;
 
-  SetName(ParseTree::Node node, llvm::StringRef name, int32_t target_id)
+  SetName(ParseTree::Node node, llvm::StringRef name, NodeId target_id)
       : node_(node), name_(name), target_id_(target_id) {}
 
   void Print(llvm::raw_ostream& out) const {
-    out << "SetName(`" << name_ << "`, %" << target_id_ << ")";
+    out << "SetName(`" << name_ << "`, " << target_id_ << ")";
   }
 
   auto node() const -> ParseTree::Node { return node_; }
   auto name() const -> llvm::StringRef { return name_; }
-  auto target_id() const -> int32_t { return target_id_; }
+  auto target_id() const -> NodeId { return target_id_; }
 
  private:
   // The name node.
@@ -36,7 +36,7 @@ class SetName {
   llvm::StringRef name_;
 
   // The ID being named.
-  int32_t target_id_;
+  NodeId target_id_;
 };
 
 }  // namespace Carbon::Semantics

+ 4 - 3
toolchain/semantics/nodes/set_name_test_matchers.h

@@ -15,9 +15,10 @@ namespace Carbon::Testing {
 
 MATCHER_P2(
     SetName, name_matcher, target_id_matcher,
-    llvm::formatv("SetName(`{0}`, %`{1}`)",
-                  ::testing::DescribeMatcher<llvm::StringRef>(name_matcher),
-                  ::testing::DescribeMatcher<int32_t>(target_id_matcher))) {
+    llvm::formatv(
+        "SetName(`{0}`, `{1}`)",
+        ::testing::DescribeMatcher<llvm::StringRef>(name_matcher),
+        ::testing::DescribeMatcher<Semantics::NodeId>(target_id_matcher))) {
   const Semantics::NodeRef& node_ref = arg;
   if (auto node = SemanticsIRForTest::GetNode<Semantics::SetName>(node_ref)) {
     return ExplainMatchResult(name_matcher, node->name(), result_listener) &&

+ 3 - 3
toolchain/semantics/semantics_ir_factory.cpp

@@ -87,7 +87,7 @@ auto SemanticsIRFactory::TransformCodeBlock(ParseTree::Node node)
 
 void SemanticsIRFactory::TransformDeclaredName(
     llvm::SmallVector<Semantics::NodeRef, 0>& nodes, ParseTree::Node node,
-    int32_t target_id) {
+    Semantics::NodeId target_id) {
   CARBON_CHECK(parse_tree().node_kind(node) == ParseNodeKind::DeclaredName());
   RequireNodeEmpty(node);
 
@@ -97,7 +97,7 @@ void SemanticsIRFactory::TransformDeclaredName(
 
 void SemanticsIRFactory::TransformExpression(
     llvm::SmallVector<Semantics::NodeRef, 0>& nodes, ParseTree::Node node,
-    int32_t target_id) {
+    Semantics::NodeId target_id) {
   switch (auto node_kind = parse_tree().node_kind(node)) {
     case ParseNodeKind::Literal(): {
       RequireNodeEmpty(node);
@@ -167,7 +167,7 @@ static auto GetBinaryOp(TokenKind kind) -> Semantics::BinaryOperator::Op {
 
 void SemanticsIRFactory::TransformInfixOperator(
     llvm::SmallVector<Semantics::NodeRef, 0>& nodes, ParseTree::Node node,
-    int32_t target_id) {
+    Semantics::NodeId target_id) {
   CARBON_CHECK(parse_tree().node_kind(node) == ParseNodeKind::InfixOperator());
 
   auto token = parse_tree().node_token(node);

+ 7 - 4
toolchain/semantics/semantics_ir_factory.h

@@ -40,15 +40,16 @@ class SemanticsIRFactory {
   auto TransformCodeBlock(ParseTree::Node node)
       -> llvm::SmallVector<Semantics::NodeRef, 0>;
   void TransformDeclaredName(llvm::SmallVector<Semantics::NodeRef, 0>& nodes,
-                             ParseTree::Node node, int32_t target_id);
+                             ParseTree::Node node, Semantics::NodeId target_id);
   void TransformExpression(llvm::SmallVector<Semantics::NodeRef, 0>& nodes,
-                           ParseTree::Node node, int32_t target_id);
+                           ParseTree::Node node, Semantics::NodeId target_id);
   // auto TransformExpressionStatement(ParseTree::Node node)
   //   -> Semantics::Statement;
   void TransformFunctionDeclaration(
       llvm::SmallVector<Semantics::NodeRef, 0>& nodes, ParseTree::Node node);
   void TransformInfixOperator(llvm::SmallVector<Semantics::NodeRef, 0>& nodes,
-                              ParseTree::Node node, int32_t target_id);
+                              ParseTree::Node node,
+                              Semantics::NodeId target_id);
   // auto TransformParameterList(ParseTree::Node node)
   //   -> llvm::SmallVector<Semantics::PatternBinding, 0>
   // auto TransformPatternBinding(ParseTree::Node node)
@@ -58,7 +59,9 @@ class SemanticsIRFactory {
                                 ParseTree::Node node);
 
   // Returns a unique ID for the SemanticsIR.
-  auto next_id() -> int32_t { return id_counter_++; }
+  auto next_id() -> Semantics::NodeId {
+    return Semantics::NodeId(id_counter_++);
+  }
 
   // Convenience accessor.
   auto parse_tree() -> const ParseTree& { return *semantics_.parse_tree_; }