Browse Source

Bare-bones support for forward-declared classes. (#3294)

This is mostly scaffolding, but is just about enough for pointers to
classes to work properly as types.
Richard Smith 2 years ago
parent
commit
4e64b1948d

+ 6 - 0
toolchain/check/context.cpp

@@ -336,6 +336,10 @@ static auto ProfileType(Context& semantics_context, SemIR::Node node,
     case SemIR::Builtin::Kind:
     case SemIR::Builtin::Kind:
       canonical_id.AddInteger(node.As<SemIR::Builtin>().builtin_kind.AsInt());
       canonical_id.AddInteger(node.As<SemIR::Builtin>().builtin_kind.AsInt());
       break;
       break;
+    case SemIR::ClassDeclaration::Kind:
+      canonical_id.AddInteger(
+          node.As<SemIR::ClassDeclaration>().class_id.index);
+      break;
     case SemIR::CrossReference::Kind: {
     case SemIR::CrossReference::Kind: {
       // TODO: Cross-references should be canonicalized by looking at their
       // TODO: Cross-references should be canonicalized by looking at their
       // target rather than treating them as new unique types.
       // target rather than treating them as new unique types.
@@ -385,6 +389,8 @@ auto Context::CanonicalizeTypeAndAddNodeIfNew(SemIR::Node node)
 }
 }
 
 
 auto Context::CanonicalizeType(SemIR::NodeId node_id) -> SemIR::TypeId {
 auto Context::CanonicalizeType(SemIR::NodeId node_id) -> SemIR::TypeId {
+  node_id = FollowNameReferences(node_id);
+
   auto it = canonical_types_.find(node_id);
   auto it = canonical_types_.find(node_id);
   if (it != canonical_types_.end()) {
   if (it != canonical_types_.end()) {
     return it->second;
     return it->second;

+ 37 - 6
toolchain/check/handle_class.cpp

@@ -6,21 +6,52 @@
 
 
 namespace Carbon::Check {
 namespace Carbon::Check {
 
 
-auto HandleClassDeclaration(Context& context, Parse::Node parse_node) -> bool {
-  return context.TODO(parse_node, "HandleClassDeclaration");
+auto HandleClassIntroducer(Context& context, Parse::Node parse_node) -> bool {
+  // Create a node block to hold the nodes created as part of the class
+  // signature, such as generic parameters.
+  context.node_block_stack().Push();
+  // Push the bracketing node.
+  context.node_stack().Push(parse_node);
+  // A name should always follow.
+  context.declaration_name_stack().Push();
+  return true;
 }
 }
 
 
-auto HandleClassDefinition(Context& context, Parse::Node parse_node) -> bool {
-  return context.TODO(parse_node, "HandleClassDefinition");
+static auto BuildClassDeclaration(Context& context) -> void {
+  auto name_context = context.declaration_name_stack().Pop();
+
+  auto class_keyword =
+      context.node_stack()
+          .PopForSoloParseNode<Parse::NodeKind::ClassIntroducer>();
+
+  // TODO: Track this somewhere.
+  context.node_block_stack().Pop();
+
+  auto class_id = context.semantics_ir().AddClass(
+      {.name_id = name_context.state ==
+                          DeclarationNameStack::NameContext::State::Unresolved
+                      ? name_context.unresolved_name_id
+                      : SemIR::StringId(SemIR::StringId::InvalidIndex)});
+  auto class_decl_id = context.AddNode(SemIR::ClassDeclaration(
+      class_keyword, SemIR::TypeId::TypeType, class_id));
+  context.declaration_name_stack().AddNameToLookup(name_context, class_decl_id);
+}
+
+auto HandleClassDeclaration(Context& context, Parse::Node /*parse_node*/)
+    -> bool {
+  BuildClassDeclaration(context);
+  return true;
 }
 }
 
 
 auto HandleClassDefinitionStart(Context& context, Parse::Node parse_node)
 auto HandleClassDefinitionStart(Context& context, Parse::Node parse_node)
     -> bool {
     -> bool {
+  BuildClassDeclaration(context);
+  // TODO: Introduce `Self`.
   return context.TODO(parse_node, "HandleClassDefinitionStart");
   return context.TODO(parse_node, "HandleClassDefinitionStart");
 }
 }
 
 
-auto HandleClassIntroducer(Context& context, Parse::Node parse_node) -> bool {
-  return context.TODO(parse_node, "HandleClassIntroducer");
+auto HandleClassDefinition(Context& context, Parse::Node parse_node) -> bool {
+  return context.TODO(parse_node, "HandleClassDefinition");
 }
 }
 
 
 }  // namespace Carbon::Check
 }  // namespace Carbon::Check

+ 13 - 0
toolchain/check/node_stack.h

@@ -108,6 +108,11 @@ class NodeStack {
       RequireParseKind<RequiredParseKind>(back.first);
       RequireParseKind<RequiredParseKind>(back.first);
       return back;
       return back;
     }
     }
+    if constexpr (RequiredIdKind == IdKind::ClassId) {
+      auto back = PopWithParseNode<SemIR::ClassId>();
+      RequireParseKind<RequiredParseKind>(back.first);
+      return back;
+    }
     if constexpr (RequiredIdKind == IdKind::StringId) {
     if constexpr (RequiredIdKind == IdKind::StringId) {
       auto back = PopWithParseNode<SemIR::StringId>();
       auto back = PopWithParseNode<SemIR::StringId>();
       RequireParseKind<RequiredParseKind>(back.first);
       RequireParseKind<RequiredParseKind>(back.first);
@@ -154,6 +159,9 @@ class NodeStack {
     if constexpr (RequiredIdKind == IdKind::FunctionId) {
     if constexpr (RequiredIdKind == IdKind::FunctionId) {
       return back.id<SemIR::FunctionId>();
       return back.id<SemIR::FunctionId>();
     }
     }
+    if constexpr (RequiredIdKind == IdKind::ClassId) {
+      return back.id<SemIR::ClassId>();
+    }
     if constexpr (RequiredIdKind == IdKind::StringId) {
     if constexpr (RequiredIdKind == IdKind::StringId) {
       return back.id<SemIR::StringId>();
       return back.id<SemIR::StringId>();
     }
     }
@@ -177,6 +185,7 @@ class NodeStack {
     NodeId,
     NodeId,
     NodeBlockId,
     NodeBlockId,
     FunctionId,
     FunctionId,
+    ClassId,
     StringId,
     StringId,
     TypeId,
     TypeId,
     // No associated ID type.
     // No associated ID type.
@@ -272,6 +281,7 @@ class NodeStack {
       case Parse::NodeKind::Name:
       case Parse::NodeKind::Name:
         return IdKind::StringId;
         return IdKind::StringId;
       case Parse::NodeKind::ArrayExpressionSemi:
       case Parse::NodeKind::ArrayExpressionSemi:
+      case Parse::NodeKind::ClassIntroducer:
       case Parse::NodeKind::CodeBlockStart:
       case Parse::NodeKind::CodeBlockStart:
       case Parse::NodeKind::FunctionIntroducer:
       case Parse::NodeKind::FunctionIntroducer:
       case Parse::NodeKind::IfStatementElse:
       case Parse::NodeKind::IfStatementElse:
@@ -302,6 +312,9 @@ class NodeStack {
     if constexpr (std::is_same_v<IdT, SemIR::FunctionId>) {
     if constexpr (std::is_same_v<IdT, SemIR::FunctionId>) {
       return IdKind::FunctionId;
       return IdKind::FunctionId;
     }
     }
+    if constexpr (std::is_same_v<IdT, SemIR::ClassId>) {
+      return IdKind::ClassId;
+    }
     if constexpr (std::is_same_v<IdT, SemIR::StringId>) {
     if constexpr (std::is_same_v<IdT, SemIR::StringId>) {
       return IdKind::StringId;
       return IdKind::StringId;
     }
     }

+ 2 - 0
toolchain/check/testdata/basics/builtin_nodes.carbon

@@ -11,6 +11,8 @@
 // CHECK:STDOUT:   - cross_reference_irs_size: 1
 // CHECK:STDOUT:   - cross_reference_irs_size: 1
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     reals: [
 // CHECK:STDOUT:     reals: [

+ 4 - 0
toolchain/check/testdata/basics/multifile_raw_and_textual_ir.carbon

@@ -20,6 +20,8 @@ fn B() {}
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     reals: [
 // CHECK:STDOUT:     reals: [
@@ -61,6 +63,8 @@ fn B() {}
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     reals: [
 // CHECK:STDOUT:     reals: [

+ 4 - 0
toolchain/check/testdata/basics/multifile_raw_ir.carbon

@@ -20,6 +20,8 @@ fn B() {}
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     reals: [
 // CHECK:STDOUT:     reals: [
@@ -52,6 +54,8 @@ fn B() {}
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:       {name: str0, param_refs: block0, body: [block1]},
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     reals: [
 // CHECK:STDOUT:     reals: [

+ 2 - 0
toolchain/check/testdata/basics/raw_and_textual_ir.carbon

@@ -18,6 +18,8 @@ fn Foo(n: i32) -> (i32, f64) {
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:       {name: str0, param_refs: block1, return_type: type3, return_slot: node+4, body: [block4]},
 // CHECK:STDOUT:       {name: str0, param_refs: block1, return_type: type3, return_slot: node+4, body: [block4]},
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:       2,
 // CHECK:STDOUT:       2,
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]

+ 2 - 0
toolchain/check/testdata/basics/raw_ir.carbon

@@ -18,6 +18,8 @@ fn Foo(n: i32) -> (i32, f64) {
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:     functions: [
 // CHECK:STDOUT:       {name: str0, param_refs: block1, return_type: type3, return_slot: node+4, body: [block4]},
 // CHECK:STDOUT:       {name: str0, param_refs: block1, return_type: type3, return_slot: node+4, body: [block4]},
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]
+// CHECK:STDOUT:     classes: [
+// CHECK:STDOUT:     ]
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:     integers: [
 // CHECK:STDOUT:       2,
 // CHECK:STDOUT:       2,
 // CHECK:STDOUT:     ]
 // CHECK:STDOUT:     ]

+ 22 - 0
toolchain/check/testdata/class/forward_declared.carbon

@@ -0,0 +1,22 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// AUTOUPDATE
+
+class Class;
+
+fn F(p: Class*) -> Class* { return p; }
+
+// CHECK:STDOUT: file "forward_declared.carbon" {
+// CHECK:STDOUT:   %Class: type = class_declaration @Class
+// CHECK:STDOUT:   %F: <function> = fn_decl @F
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: class @Class;
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @F(%p: Class*) -> Class* {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   %p.ref: Class* = name_reference "p", %p
+// CHECK:STDOUT:   return %p.ref
+// CHECK:STDOUT: }

+ 6 - 0
toolchain/lower/handle.cpp

@@ -163,6 +163,12 @@ auto HandleCall(FunctionContext& context, SemIR::NodeId node_id,
   }
   }
 }
 }
 
 
+auto HandleClassDeclaration(FunctionContext& /*context*/,
+                            SemIR::NodeId /*node_id*/,
+                            SemIR::ClassDeclaration /*node*/) -> void {
+  // No action to perform.
+}
+
 auto HandleDereference(FunctionContext& context, SemIR::NodeId node_id,
 auto HandleDereference(FunctionContext& context, SemIR::NodeId node_id,
                        SemIR::Dereference node) -> void {
                        SemIR::Dereference node) -> void {
   context.SetLocal(node_id, context.GetLocal(node.pointer_id));
   context.SetLocal(node_id, context.GetLocal(node.pointer_id));

+ 15 - 0
toolchain/sem_ir/file.cpp

@@ -150,6 +150,7 @@ auto File::Print(llvm::raw_ostream& out, bool include_builtins) const -> void {
       << "\n";
       << "\n";
 
 
   PrintList(out, "functions", functions_);
   PrintList(out, "functions", functions_);
+  PrintList(out, "classes", classes_);
   // Integer values are APInts, and default to a signed print, but we currently
   // Integer values are APInts, and default to a signed print, but we currently
   // treat them as unsigned.
   // treat them as unsigned.
   PrintList(out, "integers", integers_,
   PrintList(out, "integers", integers_,
@@ -179,6 +180,7 @@ static auto GetTypePrecedence(NodeKind kind) -> int {
   switch (kind) {
   switch (kind) {
     case ArrayType::Kind:
     case ArrayType::Kind:
     case Builtin::Kind:
     case Builtin::Kind:
+    case ClassDeclaration::Kind:
     case StructType::Kind:
     case StructType::Kind:
     case TupleType::Kind:
     case TupleType::Kind:
       return 0;
       return 0;
@@ -285,6 +287,12 @@ auto File::StringifyType(TypeId type_id, bool in_type_context) const
         }
         }
         break;
         break;
       }
       }
+      case ClassDeclaration::Kind: {
+        auto class_name_id =
+            GetClass(node.As<ClassDeclaration>().class_id).name_id;
+        out << GetString(class_name_id);
+        break;
+      }
       case ConstType::Kind: {
       case ConstType::Kind: {
         if (step.index == 0) {
         if (step.index == 0) {
           out << "const ";
           out << "const ";
@@ -463,6 +471,7 @@ auto GetExpressionCategory(const File& file, NodeId node_id)
       case BindValue::Kind:
       case BindValue::Kind:
       case BlockArg::Kind:
       case BlockArg::Kind:
       case BoolLiteral::Kind:
       case BoolLiteral::Kind:
+      case ClassDeclaration::Kind:
       case ConstType::Kind:
       case ConstType::Kind:
       case IntegerLiteral::Kind:
       case IntegerLiteral::Kind:
       case Parameter::Kind:
       case Parameter::Kind:
@@ -637,6 +646,12 @@ auto GetValueRepresentation(const File& file, TypeId type_id)
         return {.kind = ValueRepresentation::Pointer, .type = type_id};
         return {.kind = ValueRepresentation::Pointer, .type = type_id};
       }
       }
 
 
+      case ClassDeclaration::Kind: {
+        // TODO: Pick the default value representation in a smarter way.
+        // TODO: Allow the value representation for a class to be customized.
+        return {.kind = ValueRepresentation::Pointer, .type = type_id};
+      }
+
       case Builtin::Kind:
       case Builtin::Kind:
         // clang warns on unhandled enum values; clang-tidy is incorrect here.
         // clang warns on unhandled enum values; clang-tidy is incorrect here.
         // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
         // NOLINTNEXTLINE(bugprone-switch-missing-default-case)

+ 32 - 0
toolchain/sem_ir/file.h

@@ -51,6 +51,17 @@ struct Function : public Printable<Function> {
   llvm::SmallVector<NodeBlockId> body_block_ids;
   llvm::SmallVector<NodeBlockId> body_block_ids;
 };
 };
 
 
+// A class.
+struct Class : public Printable<Class> {
+  auto Print(llvm::raw_ostream& out) const -> void {
+    out << "{name: " << name_id;
+    out << "}";
+  }
+
+  // The class name.
+  StringId name_id;
+};
+
 // TODO: Replace this with a Rational type, per the design:
 // TODO: Replace this with a Rational type, per the design:
 // docs/design/expressions/literals.md
 // docs/design/expressions/literals.md
 struct Real : public Printable<Real> {
 struct Real : public Printable<Real> {
@@ -121,6 +132,23 @@ class File : public Printable<File> {
     return functions_[function_id.index];
     return functions_[function_id.index];
   }
   }
 
 
+  // Adds a class, returning an ID to reference it.
+  auto AddClass(Class class_info) -> ClassId {
+    ClassId id(classes_.size());
+    // TODO: Return failure on overflow instead of crashing.
+    CARBON_CHECK(id.index >= 0);
+    classes_.push_back(class_info);
+    return id;
+  }
+
+  // Returns the requested class.
+  auto GetClass(ClassId class_id) const -> const Class& {
+    return classes_[class_id.index];
+  }
+
+  // Returns the requested class.
+  auto GetClass(ClassId class_id) -> Class& { return classes_[class_id.index]; }
+
   // Adds an integer value, returning an ID to reference it.
   // Adds an integer value, returning an ID to reference it.
   auto AddInteger(llvm::APInt integer) -> IntegerId {
   auto AddInteger(llvm::APInt integer) -> IntegerId {
     IntegerId id(integers_.size());
     IntegerId id(integers_.size());
@@ -327,6 +355,7 @@ class File : public Printable<File> {
       -> std::string;
       -> std::string;
 
 
   auto functions_size() const -> int { return functions_.size(); }
   auto functions_size() const -> int { return functions_.size(); }
+  auto classes_size() const -> int { return classes_.size(); }
   auto nodes_size() const -> int { return nodes_.size(); }
   auto nodes_size() const -> int { return nodes_.size(); }
   auto node_blocks_size() const -> int { return node_blocks_.size(); }
   auto node_blocks_size() const -> int { return node_blocks_.size(); }
 
 
@@ -375,6 +404,9 @@ class File : public Printable<File> {
   // Storage for callable objects.
   // Storage for callable objects.
   llvm::SmallVector<Function> functions_;
   llvm::SmallVector<Function> functions_;
 
 
+  // Storage for classes.
+  llvm::SmallVector<Class> classes_;
+
   // Related IRs. There will always be at least 2 entries, the builtin IR (used
   // Related IRs. There will always be at least 2 entries, the builtin IR (used
   // for references of builtins) followed by the current IR (used for references
   // for references of builtins) followed by the current IR (used for references
   // crossing node blocks).
   // crossing node blocks).

+ 1 - 0
toolchain/sem_ir/file_test.cpp

@@ -47,6 +47,7 @@ TEST(SemIRTest, YAML) {
   auto file = Yaml::Sequence(ElementsAre(Yaml::Mapping(ElementsAre(
   auto file = Yaml::Sequence(ElementsAre(Yaml::Mapping(ElementsAre(
       Pair("cross_reference_irs_size", "1"),
       Pair("cross_reference_irs_size", "1"),
       Pair("functions", Yaml::Sequence(SizeIs(1))),
       Pair("functions", Yaml::Sequence(SizeIs(1))),
+      Pair("classes", Yaml::Sequence(SizeIs(0))),
       Pair("integers", Yaml::Sequence(ElementsAre("0"))),
       Pair("integers", Yaml::Sequence(ElementsAre("0"))),
       Pair("reals", Yaml::Sequence(IsEmpty())),
       Pair("reals", Yaml::Sequence(IsEmpty())),
       Pair("strings", Yaml::Sequence(ElementsAre("F", "x"))),
       Pair("strings", Yaml::Sequence(ElementsAre("F", "x"))),

+ 59 - 2
toolchain/sem_ir/formatter.cpp

@@ -37,7 +37,8 @@ class NodeNamer {
         semantics_ir_(semantics_ir) {
         semantics_ir_(semantics_ir) {
     nodes.resize(semantics_ir.nodes_size());
     nodes.resize(semantics_ir.nodes_size());
     labels.resize(semantics_ir.node_blocks_size());
     labels.resize(semantics_ir.node_blocks_size());
-    scopes.resize(1 + semantics_ir.functions_size());
+    scopes.resize(1 + semantics_ir.functions_size() +
+                  semantics_ir.classes_size());
 
 
     // Build the package scope.
     // Build the package scope.
     GetScopeInfo(ScopeIndex::Package).name =
     GetScopeInfo(ScopeIndex::Package).name =
@@ -74,11 +75,33 @@ class NodeNamer {
         AddBlockLabel(fn_scope, block_id);
         AddBlockLabel(fn_scope, block_id);
       }
       }
     }
     }
+
+    // Build each class scope.
+    for (int i : llvm::seq(semantics_ir.classes_size())) {
+      auto class_id = ClassId(i);
+      auto class_scope = GetScopeFor(class_id);
+      const auto& class_info = semantics_ir.GetClass(class_id);
+      // TODO: Provide a location for the class for use as a
+      // disambiguator.
+      auto class_loc = Parse::Node::Invalid;
+      GetScopeInfo(class_scope).name = globals.AllocateName(
+          *this, class_loc,
+          class_info.name_id.is_valid()
+              ? semantics_ir.GetString(class_info.name_id).str()
+              : "");
+      // TODO: Handle names declared in the class scope.
+    }
   }
   }
 
 
   // Returns the scope index corresponding to a function.
   // Returns the scope index corresponding to a function.
   auto GetScopeFor(FunctionId fn_id) -> ScopeIndex {
   auto GetScopeFor(FunctionId fn_id) -> ScopeIndex {
-    return static_cast<ScopeIndex>(fn_id.index + 1);
+    return static_cast<ScopeIndex>(1 + fn_id.index);
+  }
+
+  // Returns the scope index corresponding to a class.
+  auto GetScopeFor(ClassId class_id) -> ScopeIndex {
+    return static_cast<ScopeIndex>(1 + semantics_ir_.functions_size() +
+                                   class_id.index);
   }
   }
 
 
   // Returns the IR name to use for a function.
   // Returns the IR name to use for a function.
@@ -89,6 +112,14 @@ class NodeNamer {
     return GetScopeInfo(GetScopeFor(fn_id)).name.str();
     return GetScopeInfo(GetScopeFor(fn_id)).name.str();
   }
   }
 
 
+  // Returns the IR name to use for a class.
+  auto GetNameFor(ClassId class_id) -> llvm::StringRef {
+    if (!class_id.is_valid()) {
+      return "invalid";
+    }
+    return GetScopeInfo(GetScopeFor(class_id)).name.str();
+  }
+
   // Returns the IR name to use for a node, when referenced from a given scope.
   // Returns the IR name to use for a node, when referenced from a given scope.
   auto GetNameFor(ScopeIndex scope_idx, NodeId node_id) -> std::string {
   auto GetNameFor(ScopeIndex scope_idx, NodeId node_id) -> std::string {
     if (!node_id.is_valid()) {
     if (!node_id.is_valid()) {
@@ -393,6 +424,12 @@ class NodeNamer {
                   .name_id);
                   .name_id);
           continue;
           continue;
         }
         }
+        case ClassDeclaration::Kind: {
+          add_node_name_id(
+              semantics_ir_.GetClass(node.As<ClassDeclaration>().class_id)
+                  .name_id);
+          continue;
+        }
         case NameReference::Kind: {
         case NameReference::Kind: {
           add_node_name(
           add_node_name(
               semantics_ir_.GetString(node.As<NameReference>().name_id).str() +
               semantics_ir_.GetString(node.As<NameReference>().name_id).str() +
@@ -458,11 +495,25 @@ class Formatter {
     }
     }
     out_ << "}\n";
     out_ << "}\n";
 
 
+    for (int i : llvm::seq(semantics_ir_.classes_size())) {
+      FormatClass(ClassId(i));
+    }
+
     for (int i : llvm::seq(semantics_ir_.functions_size())) {
     for (int i : llvm::seq(semantics_ir_.functions_size())) {
       FormatFunction(FunctionId(i));
       FormatFunction(FunctionId(i));
     }
     }
   }
   }
 
 
+  auto FormatClass(ClassId id) -> void {
+    const Class& class_info = semantics_ir_.GetClass(id);
+
+    out_ << "\nclass ";
+    FormatClassName(id);
+    // TODO: Format class definitions.
+    (void)class_info;
+    out_ << ";\n";
+  }
+
   auto FormatFunction(FunctionId id) -> void {
   auto FormatFunction(FunctionId id) -> void {
     const Function& fn = semantics_ir_.GetFunction(id);
     const Function& fn = semantics_ir_.GetFunction(id);
 
 
@@ -728,6 +779,8 @@ class Formatter {
 
 
   auto FormatArg(FunctionId id) -> void { FormatFunctionName(id); }
   auto FormatArg(FunctionId id) -> void { FormatFunctionName(id); }
 
 
+  auto FormatArg(ClassId id) -> void { FormatClassName(id); }
+
   auto FormatArg(IntegerId id) -> void {
   auto FormatArg(IntegerId id) -> void {
     semantics_ir_.GetInteger(id).print(out_, /*isSigned=*/false);
     semantics_ir_.GetInteger(id).print(out_, /*isSigned=*/false);
   }
   }
@@ -815,6 +868,10 @@ class Formatter {
     out_ << node_namer_.GetNameFor(id);
     out_ << node_namer_.GetNameFor(id);
   }
   }
 
 
+  auto FormatClassName(ClassId id) -> void {
+    out_ << node_namer_.GetNameFor(id);
+  }
+
   auto FormatType(TypeId id) -> void {
   auto FormatType(TypeId id) -> void {
     if (!id.is_valid()) {
     if (!id.is_valid()) {
       out_ << "invalid";
       out_ << "invalid";

+ 13 - 0
toolchain/sem_ir/node.h

@@ -59,6 +59,15 @@ struct FunctionId : public IndexBase, public Printable<FunctionId> {
   }
   }
 };
 };
 
 
+// The ID of a class.
+struct ClassId : public IndexBase, public Printable<ClassId> {
+  using IndexBase::IndexBase;
+  auto Print(llvm::raw_ostream& out) const -> void {
+    out << "class";
+    IndexBase::Print(out);
+  }
+};
+
 // The ID of a cross-referenced IR.
 // The ID of a cross-referenced IR.
 struct CrossReferenceIRId : public IndexBase,
 struct CrossReferenceIRId : public IndexBase,
                             public Printable<CrossReferenceIRId> {
                             public Printable<CrossReferenceIRId> {
@@ -308,6 +317,10 @@ struct Call {
   FunctionId function_id;
   FunctionId function_id;
 };
 };
 
 
+struct ClassDeclaration {
+  ClassId class_id;
+};
+
 struct ConstType {
 struct ConstType {
   TypeId inner_id;
   TypeId inner_id;
 };
 };

+ 2 - 0
toolchain/sem_ir/node_kind.def

@@ -59,6 +59,8 @@ CARBON_SEMANTICS_NODE_KIND_IMPL(BranchIf, "br", None, TerminatorSequence)
 CARBON_SEMANTICS_NODE_KIND_IMPL(BranchWithArg, "br", None, Terminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(BranchWithArg, "br", None, Terminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(Builtin, "builtin", Typed, NotTerminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(Builtin, "builtin", Typed, NotTerminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(Call, "call", Typed, NotTerminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(Call, "call", Typed, NotTerminator)
+CARBON_SEMANTICS_NODE_KIND_IMPL(ClassDeclaration, "class_declaration", Typed,
+                                NotTerminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(ConstType, "const_type", Typed, NotTerminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(ConstType, "const_type", Typed, NotTerminator)
 CARBON_SEMANTICS_NODE_KIND_IMPL(Dereference, "dereference", Typed,
 CARBON_SEMANTICS_NODE_KIND_IMPL(Dereference, "dereference", Typed,
                                 NotTerminator)
                                 NotTerminator)