Просмотр исходного кода

Simplify interface for getting an instruction from a type. (#3455)

Co-authored-by: Jon Ross-Perkins <jperkins@google.com>
Richard Smith 2 лет назад
Родитель
Сommit
cef7eb5522

+ 8 - 12
toolchain/check/context.cpp

@@ -542,7 +542,7 @@ class TypeCompleter {
  private:
  private:
   // Adds `type_id` to the work list, if it's not already complete.
   // Adds `type_id` to the work list, if it's not already complete.
   auto Push(SemIR::TypeId type_id) -> void {
   auto Push(SemIR::TypeId type_id) -> void {
-    if (!context_.sem_ir().IsTypeComplete(type_id)) {
+    if (!context_.types().IsComplete(type_id)) {
       work_list_.push_back({type_id, Phase::AddNestedIncompleteTypes});
       work_list_.push_back({type_id, Phase::AddNestedIncompleteTypes});
     }
     }
   }
   }
@@ -553,14 +553,12 @@ class TypeCompleter {
 
 
     // We might have enqueued the same type more than once. Just skip the
     // We might have enqueued the same type more than once. Just skip the
     // type if it's already complete.
     // type if it's already complete.
-    if (context_.sem_ir().IsTypeComplete(type_id)) {
+    if (context_.types().IsComplete(type_id)) {
       work_list_.pop_back();
       work_list_.pop_back();
       return true;
       return true;
     }
     }
 
 
-    auto inst_id = context_.sem_ir().GetTypeAllowBuiltinTypes(type_id);
-    auto inst = context_.insts().Get(inst_id);
-
+    auto inst = context_.types().GetAsInst(type_id);
     auto old_work_list_size = work_list_.size();
     auto old_work_list_size = work_list_.size();
 
 
     switch (phase) {
     switch (phase) {
@@ -583,14 +581,14 @@ class TypeCompleter {
         // Also complete the value representation type, if necessary. This
         // Also complete the value representation type, if necessary. This
         // should never fail: the value representation shouldn't require any
         // should never fail: the value representation shouldn't require any
         // additional nested types to be complete.
         // additional nested types to be complete.
-        if (!context_.sem_ir().IsTypeComplete(value_rep.type_id)) {
+        if (!context_.types().IsComplete(value_rep.type_id)) {
           work_list_.push_back({value_rep.type_id, Phase::BuildValueRepr});
           work_list_.push_back({value_rep.type_id, Phase::BuildValueRepr});
         }
         }
         // For a pointer representation, the pointee also needs to be complete.
         // For a pointer representation, the pointee also needs to be complete.
         if (value_rep.kind == SemIR::ValueRepr::Pointer) {
         if (value_rep.kind == SemIR::ValueRepr::Pointer) {
           auto pointee_type_id =
           auto pointee_type_id =
               context_.sem_ir().GetPointeeType(value_rep.type_id);
               context_.sem_ir().GetPointeeType(value_rep.type_id);
-          if (!context_.sem_ir().IsTypeComplete(pointee_type_id)) {
+          if (!context_.types().IsComplete(pointee_type_id)) {
             work_list_.push_back({pointee_type_id, Phase::BuildValueRepr});
             work_list_.push_back({pointee_type_id, Phase::BuildValueRepr});
           }
           }
         }
         }
@@ -684,9 +682,9 @@ class TypeCompleter {
   // Gets the value representation of a nested type, which should already be
   // Gets the value representation of a nested type, which should already be
   // complete.
   // complete.
   auto GetNestedValueRepr(SemIR::TypeId nested_type_id) const {
   auto GetNestedValueRepr(SemIR::TypeId nested_type_id) const {
-    CARBON_CHECK(context_.sem_ir().IsTypeComplete(nested_type_id))
+    CARBON_CHECK(context_.types().IsComplete(nested_type_id))
         << "Nested type should already be complete";
         << "Nested type should already be complete";
-    auto value_rep = context_.sem_ir().GetValueRepr(nested_type_id);
+    auto value_rep = context_.types().GetValueRepr(nested_type_id);
     CARBON_CHECK(value_rep.kind != SemIR::ValueRepr::Unknown)
     CARBON_CHECK(value_rep.kind != SemIR::ValueRepr::Unknown)
         << "Complete type should have a value representation";
         << "Complete type should have a value representation";
     return value_rep;
     return value_rep;
@@ -1128,9 +1126,7 @@ auto Context::GetPointerType(Parse::NodeId parse_node,
 }
 }
 
 
 auto Context::GetUnqualifiedType(SemIR::TypeId type_id) -> SemIR::TypeId {
 auto Context::GetUnqualifiedType(SemIR::TypeId type_id) -> SemIR::TypeId {
-  SemIR::Inst type_inst =
-      insts().Get(sem_ir_->GetTypeAllowBuiltinTypes(type_id));
-  if (auto const_type = type_inst.TryAs<SemIR::ConstType>()) {
+  if (auto const_type = types().TryGetAs<SemIR::ConstType>(type_id)) {
     return const_type->inner_id;
     return const_type->inner_id;
   }
   }
   return type_id;
   return type_id;

+ 1 - 1
toolchain/check/context.h

@@ -347,7 +347,7 @@ class Context {
   auto name_scopes() -> SemIR::NameScopeStore& {
   auto name_scopes() -> SemIR::NameScopeStore& {
     return sem_ir().name_scopes();
     return sem_ir().name_scopes();
   }
   }
-  auto types() -> ValueStore<SemIR::TypeId>& { return sem_ir().types(); }
+  auto types() -> SemIR::TypeStore& { return sem_ir().types(); }
   auto type_blocks() -> SemIR::BlockValueStore<SemIR::TypeBlockId>& {
   auto type_blocks() -> SemIR::BlockValueStore<SemIR::TypeBlockId>& {
     return sem_ir().type_blocks();
     return sem_ir().type_blocks();
   }
   }

+ 13 - 18
toolchain/check/convert.cpp

@@ -546,8 +546,8 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
                            context.sem_ir().StringifyType(target.type_id));
                            context.sem_ir().StringifyType(target.type_id));
     return SemIR::InstId::BuiltinError;
     return SemIR::InstId::BuiltinError;
   }
   }
-  auto dest_struct_type = context.insts().GetAs<SemIR::StructType>(
-      context.sem_ir().GetTypeAllowBuiltinTypes(class_info.object_repr_id));
+  auto dest_struct_type =
+      context.types().GetAs<SemIR::StructType>(class_info.object_repr_id);
 
 
   // If we're trying to create a class value, form a temporary for the value to
   // If we're trying to create a class value, form a temporary for the value to
   // point to.
   // point to.
@@ -599,8 +599,7 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
   auto& sem_ir = context.sem_ir();
   auto& sem_ir = context.sem_ir();
   auto value = sem_ir.insts().Get(value_id);
   auto value = sem_ir.insts().Get(value_id);
   auto value_type_id = value.type_id();
   auto value_type_id = value.type_id();
-  auto target_type_inst =
-      sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(target.type_id));
+  auto target_type_inst = sem_ir.types().GetAsInst(target.type_id);
 
 
   // Various forms of implicit conversion are supported as builtin conversions,
   // Various forms of implicit conversion are supported as builtin conversions,
   // either in addition to or instead of `impl`s of `ImplicitAs` in the Carbon
   // either in addition to or instead of `impl`s of `ImplicitAs` in the Carbon
@@ -660,9 +659,8 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
   // A tuple (T1, T2, ..., Tn) converts to (U1, U2, ..., Un) if each Ti
   // A tuple (T1, T2, ..., Tn) converts to (U1, U2, ..., Un) if each Ti
   // converts to Ui.
   // converts to Ui.
   if (auto target_tuple_type = target_type_inst.TryAs<SemIR::TupleType>()) {
   if (auto target_tuple_type = target_type_inst.TryAs<SemIR::TupleType>()) {
-    auto value_type_inst =
-        sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
-    if (auto src_tuple_type = value_type_inst.TryAs<SemIR::TupleType>()) {
+    if (auto src_tuple_type =
+            sem_ir.types().TryGetAs<SemIR::TupleType>(value_type_id)) {
       return ConvertTupleToTuple(context, *src_tuple_type, *target_tuple_type,
       return ConvertTupleToTuple(context, *src_tuple_type, *target_tuple_type,
                                  value_id, target);
                                  value_id, target);
     }
     }
@@ -673,9 +671,8 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
   // (p(1), ..., p(n)) is a permutation of (1, ..., n) and each Ti converts
   // (p(1), ..., p(n)) is a permutation of (1, ..., n) and each Ti converts
   // to Ui.
   // to Ui.
   if (auto target_struct_type = target_type_inst.TryAs<SemIR::StructType>()) {
   if (auto target_struct_type = target_type_inst.TryAs<SemIR::StructType>()) {
-    auto value_type_inst =
-        sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
-    if (auto src_struct_type = value_type_inst.TryAs<SemIR::StructType>()) {
+    if (auto src_struct_type =
+            sem_ir.types().TryGetAs<SemIR::StructType>(value_type_id)) {
       return ConvertStructToStruct(context, *src_struct_type,
       return ConvertStructToStruct(context, *src_struct_type,
                                    *target_struct_type, value_id, target);
                                    *target_struct_type, value_id, target);
     }
     }
@@ -683,9 +680,8 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
 
 
   // A tuple (T1, T2, ..., Tn) converts to [T; n] if each Ti converts to T.
   // A tuple (T1, T2, ..., Tn) converts to [T; n] if each Ti converts to T.
   if (auto target_array_type = target_type_inst.TryAs<SemIR::ArrayType>()) {
   if (auto target_array_type = target_type_inst.TryAs<SemIR::ArrayType>()) {
-    auto value_type_inst =
-        sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
-    if (auto src_tuple_type = value_type_inst.TryAs<SemIR::TupleType>()) {
+    if (auto src_tuple_type =
+            sem_ir.types().TryGetAs<SemIR::TupleType>(value_type_id)) {
       return ConvertTupleToArray(context, *src_tuple_type, *target_array_type,
       return ConvertTupleToArray(context, *src_tuple_type, *target_array_type,
                                  value_id, target);
                                  value_id, target);
     }
     }
@@ -696,9 +692,8 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
   // (a struct with the same fields as the class, plus a base field where
   // (a struct with the same fields as the class, plus a base field where
   // relevant).
   // relevant).
   if (auto target_class_type = target_type_inst.TryAs<SemIR::ClassType>()) {
   if (auto target_class_type = target_type_inst.TryAs<SemIR::ClassType>()) {
-    auto value_type_inst =
-        sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
-    if (auto src_struct_type = value_type_inst.TryAs<SemIR::StructType>()) {
+    if (auto src_struct_type =
+            sem_ir.types().TryGetAs<SemIR::StructType>(value_type_id)) {
       return ConvertStructToClass(context, *src_struct_type, *target_class_type,
       return ConvertStructToClass(context, *src_struct_type, *target_class_type,
                                   value_id, target);
                                   value_id, target);
     }
     }
@@ -716,7 +711,7 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
         type_ids.push_back(ExprAsType(context, parse_node, tuple_inst_id));
         type_ids.push_back(ExprAsType(context, parse_node, tuple_inst_id));
       }
       }
       auto tuple_type_id = context.CanonicalizeTupleType(parse_node, type_ids);
       auto tuple_type_id = context.CanonicalizeTupleType(parse_node, type_ids);
-      return sem_ir.GetTypeAllowBuiltinTypes(tuple_type_id);
+      return sem_ir.types().GetInstId(tuple_type_id);
     }
     }
 
 
     // `{}` converts to `{} as type`.
     // `{}` converts to `{} as type`.
@@ -725,7 +720,7 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
     if (auto struct_literal = value.TryAs<SemIR::StructLiteral>();
     if (auto struct_literal = value.TryAs<SemIR::StructLiteral>();
         struct_literal &&
         struct_literal &&
         struct_literal->elements_id == SemIR::InstBlockId::Empty) {
         struct_literal->elements_id == SemIR::InstBlockId::Empty) {
-      value_id = sem_ir.GetTypeAllowBuiltinTypes(value_type_id);
+      value_id = sem_ir.types().GetInstId(value_type_id);
     }
     }
   }
   }
 
 

+ 3 - 7
toolchain/check/handle_class.cpp

@@ -144,9 +144,8 @@ auto HandleClassDefinitionStart(Context& context, Parse::NodeId parse_node)
   context.PushScope(class_decl_id, class_info.scope_id);
   context.PushScope(class_decl_id, class_info.scope_id);
 
 
   // Introduce `Self`.
   // Introduce `Self`.
-  context.AddNameToLookup(
-      parse_node, SemIR::NameId::SelfType,
-      context.sem_ir().GetTypeAllowBuiltinTypes(class_info.self_type_id));
+  context.AddNameToLookup(parse_node, SemIR::NameId::SelfType,
+                          context.types().GetInstId(class_info.self_type_id));
 
 
   context.inst_block_stack().Push();
   context.inst_block_stack().Push();
   context.node_stack().Push(parse_node, class_id);
   context.node_stack().Push(parse_node, class_id);
@@ -227,10 +226,7 @@ auto HandleBaseDecl(Context& context, Parse::NodeId parse_node) -> bool {
     // declaration as being final classes.
     // declaration as being final classes.
     // TODO: Once we have a better idea of which types are considered to be
     // TODO: Once we have a better idea of which types are considered to be
     // classes, produce a better diagnostic for deriving from a non-class type.
     // classes, produce a better diagnostic for deriving from a non-class type.
-    auto base_class =
-        context.insts()
-            .Get(context.sem_ir().GetTypeAllowBuiltinTypes(base_type_id))
-            .TryAs<SemIR::ClassType>();
+    auto base_class = context.types().TryGetAs<SemIR::ClassType>(base_type_id);
     if (!base_class ||
     if (!base_class ||
         context.classes().Get(base_class->class_id).inheritance_kind ==
         context.classes().Get(base_class->class_id).inheritance_kind ==
             SemIR::Class::Final) {
             SemIR::Class::Final) {

+ 1 - 2
toolchain/check/handle_index.cpp

@@ -42,8 +42,7 @@ auto HandleIndexExpr(Context& context, Parse::NodeId parse_node) -> bool {
   operand_inst_id = ConvertToValueOrRefExpr(context, operand_inst_id);
   operand_inst_id = ConvertToValueOrRefExpr(context, operand_inst_id);
   auto operand_inst = context.insts().Get(operand_inst_id);
   auto operand_inst = context.insts().Get(operand_inst_id);
   auto operand_type_id = operand_inst.type_id();
   auto operand_type_id = operand_inst.type_id();
-  auto operand_type_inst = context.insts().Get(
-      context.sem_ir().GetTypeAllowBuiltinTypes(operand_type_id));
+  auto operand_type_inst = context.types().GetAsInst(operand_type_id);
 
 
   switch (operand_type_inst.kind()) {
   switch (operand_type_inst.kind()) {
     case SemIR::ArrayType::Kind: {
     case SemIR::ArrayType::Kind: {

+ 5 - 8
toolchain/check/handle_name.cpp

@@ -45,7 +45,7 @@ static auto GetExprValueForLookupResult(Context& context,
   // If lookup finds a class declaration, the value is its `Self` type.
   // If lookup finds a class declaration, the value is its `Self` type.
   auto lookup_result = context.insts().Get(lookup_result_id);
   auto lookup_result = context.insts().Get(lookup_result_id);
   if (auto class_decl = lookup_result.TryAs<SemIR::ClassDecl>()) {
   if (auto class_decl = lookup_result.TryAs<SemIR::ClassDecl>()) {
-    return context.sem_ir().GetTypeAllowBuiltinTypes(
+    return context.types().GetInstId(
         context.classes().Get(class_decl->class_id).self_type_id);
         context.classes().Get(class_decl->class_id).self_type_id);
   }
   }
 
 
@@ -108,10 +108,8 @@ auto HandleMemberAccessExpr(Context& context, Parse::NodeId parse_node)
   base_id = ConvertToValueOrRefExpr(context, base_id);
   base_id = ConvertToValueOrRefExpr(context, base_id);
   base_type_id = context.insts().Get(base_id).type_id();
   base_type_id = context.insts().Get(base_id).type_id();
 
 
-  auto base_type = context.insts().Get(
-      context.sem_ir().GetTypeAllowBuiltinTypes(base_type_id));
-
-  switch (base_type.kind()) {
+  switch (auto base_type = context.types().GetAsInst(base_type_id);
+          base_type.kind()) {
     case SemIR::ClassType::Kind: {
     case SemIR::ClassType::Kind: {
       // Perform lookup for the name in the class scope.
       // Perform lookup for the name in the class scope.
       auto class_scope_id = context.classes()
       auto class_scope_id = context.classes()
@@ -123,10 +121,9 @@ auto HandleMemberAccessExpr(Context& context, Parse::NodeId parse_node)
 
 
       // Perform instance binding if we found an instance member.
       // Perform instance binding if we found an instance member.
       auto member_type_id = context.insts().Get(member_id).type_id();
       auto member_type_id = context.insts().Get(member_id).type_id();
-      auto member_type_inst = context.insts().Get(
-          context.sem_ir().GetTypeAllowBuiltinTypes(member_type_id));
       if (auto unbound_element_type =
       if (auto unbound_element_type =
-              member_type_inst.TryAs<SemIR::UnboundElementType>()) {
+              context.types().TryGetAs<SemIR::UnboundElementType>(
+                  member_type_id)) {
         // TODO: Check that the unbound element type describes a member of this
         // TODO: Check that the unbound element type describes a member of this
         // class. Perform a conversion of the base if necessary.
         // class. Perform a conversion of the base if necessary.
 
 

+ 2 - 3
toolchain/check/handle_operator.cpp

@@ -150,10 +150,9 @@ auto HandlePrefixOperator(Context& context, Parse::NodeId parse_node) -> bool {
       value_id = ConvertToValueExpr(context, value_id);
       value_id = ConvertToValueExpr(context, value_id);
       auto type_id =
       auto type_id =
           context.GetUnqualifiedType(context.insts().Get(value_id).type_id());
           context.GetUnqualifiedType(context.insts().Get(value_id).type_id());
-      auto type_inst = context.insts().Get(
-          context.sem_ir().GetTypeAllowBuiltinTypes(type_id));
       auto result_type_id = SemIR::TypeId::Error;
       auto result_type_id = SemIR::TypeId::Error;
-      if (auto pointer_type = type_inst.TryAs<SemIR::PointerType>()) {
+      if (auto pointer_type =
+              context.types().TryGetAs<SemIR::PointerType>(type_id)) {
         result_type_id = pointer_type->pointee_id;
         result_type_id = pointer_type->pointee_id;
       } else if (type_id != SemIR::TypeId::Error) {
       } else if (type_id != SemIR::TypeId::Error) {
         CARBON_DIAGNOSTIC(
         CARBON_DIAGNOSTIC(

+ 1 - 1
toolchain/lower/handle.cpp

@@ -210,7 +210,7 @@ auto HandleIntLiteral(FunctionContext& context, SemIR::InstId inst_id,
 
 
 auto HandleNameRef(FunctionContext& context, SemIR::InstId inst_id,
 auto HandleNameRef(FunctionContext& context, SemIR::InstId inst_id,
                    SemIR::NameRef inst) -> void {
                    SemIR::NameRef inst) -> void {
-  auto type_inst_id = context.sem_ir().GetTypeAllowBuiltinTypes(inst.type_id);
+  auto type_inst_id = context.sem_ir().types().GetInstId(inst.type_id);
   if (type_inst_id == SemIR::InstId::BuiltinNamespaceType) {
   if (type_inst_id == SemIR::InstId::BuiltinNamespaceType) {
     return;
     return;
   }
   }

+ 3 - 8
toolchain/lower/handle_aggregates.cpp

@@ -91,9 +91,8 @@ static auto GetStructFieldName(FunctionContext& context,
                                SemIR::ElementIndex index) -> llvm::StringRef {
                                SemIR::ElementIndex index) -> llvm::StringRef {
   auto fields = context.sem_ir().inst_blocks().Get(
   auto fields = context.sem_ir().inst_blocks().Get(
       context.sem_ir()
       context.sem_ir()
-          .insts()
-          .GetAs<SemIR::StructType>(
-              context.sem_ir().types().Get(struct_type_id).inst_id)
+          .types()
+          .GetAs<SemIR::StructType>(struct_type_id)
           .fields_id);
           .fields_id);
   auto field = context.sem_ir().insts().GetAs<SemIR::StructTypeField>(
   auto field = context.sem_ir().insts().GetAs<SemIR::StructTypeField>(
       fields[index.index]);
       fields[index.index]);
@@ -105,11 +104,7 @@ auto HandleClassElementAccess(FunctionContext& context, SemIR::InstId inst_id,
   // Find the class that we're performing access into.
   // Find the class that we're performing access into.
   auto class_type_id = context.sem_ir().insts().Get(inst.base_id).type_id();
   auto class_type_id = context.sem_ir().insts().Get(inst.base_id).type_id();
   auto class_id =
   auto class_id =
-      context.sem_ir()
-          .insts()
-          .GetAs<SemIR::ClassType>(
-              context.sem_ir().GetTypeAllowBuiltinTypes(class_type_id))
-          .class_id;
+      context.sem_ir().types().GetAs<SemIR::ClassType>(class_type_id).class_id;
   const auto& class_info = context.sem_ir().classes().Get(class_id);
   const auto& class_info = context.sem_ir().classes().Get(class_id);
 
 
   // Translate the class field access into a struct access on the object
   // Translate the class field access into a struct access on the object

+ 12 - 0
toolchain/sem_ir/BUILD

@@ -68,6 +68,7 @@ cc_library(
         ":ids",
         ":ids",
         ":inst",
         ":inst",
         ":inst_kind",
         ":inst_kind",
+        ":type_info",
         ":value_stores",
         ":value_stores",
         "//common:check",
         "//common:check",
         "//toolchain/base:value_store",
         "//toolchain/base:value_store",
@@ -101,12 +102,23 @@ cc_library(
     ],
     ],
 )
 )
 
 
+cc_library(
+    name = "type_info",
+    hdrs = ["type_info.h"],
+    deps = [
+        ":ids",
+        ":inst",
+        "//common:ostream",
+    ],
+)
+
 cc_library(
 cc_library(
     name = "value_stores",
     name = "value_stores",
     srcs = ["value_stores.cpp"],
     srcs = ["value_stores.cpp"],
     hdrs = ["value_stores.h"],
     hdrs = ["value_stores.h"],
     deps = [
     deps = [
         ":inst",
         ":inst",
+        ":type_info",
         "//toolchain/base:value_store",
         "//toolchain/base:value_store",
         "//toolchain/base:yaml",
         "//toolchain/base:yaml",
         "//toolchain/lex:token_kind",
         "//toolchain/lex:token_kind",

+ 7 - 9
toolchain/sem_ir/file.cpp

@@ -238,7 +238,7 @@ static auto GetTypePrecedence(InstKind kind) -> int {
 }
 }
 
 
 auto File::StringifyType(TypeId type_id) const -> std::string {
 auto File::StringifyType(TypeId type_id) const -> std::string {
-  return StringifyTypeExpr(GetTypeAllowBuiltinTypes(type_id));
+  return StringifyTypeExpr(types().GetInstId(type_id));
 }
 }
 
 
 auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
 auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
@@ -280,7 +280,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
           out << "[";
           out << "[";
           steps.push_back(step.Next());
           steps.push_back(step.Next());
           steps.push_back(
           steps.push_back(
-              {.inst_id = GetTypeAllowBuiltinTypes(array.element_type_id)});
+              {.inst_id = types().GetInstId(array.element_type_id)});
         } else if (step.index == 1) {
         } else if (step.index == 1) {
           out << "; " << GetArrayBoundValue(array.bound_id) << "]";
           out << "; " << GetArrayBoundValue(array.bound_id) << "]";
         }
         }
@@ -298,7 +298,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
 
 
           // Add parentheses if required.
           // Add parentheses if required.
           auto inner_type_inst_id =
           auto inner_type_inst_id =
-              GetTypeAllowBuiltinTypes(inst.As<ConstType>().inner_id);
+              types().GetInstId(inst.As<ConstType>().inner_id);
           if (GetTypePrecedence(insts().Get(inner_type_inst_id).kind()) <
           if (GetTypePrecedence(insts().Get(inner_type_inst_id).kind()) <
               GetTypePrecedence(inst.kind())) {
               GetTypePrecedence(inst.kind())) {
             out << "(";
             out << "(";
@@ -318,7 +318,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
       case PointerType::Kind: {
       case PointerType::Kind: {
         if (step.index == 0) {
         if (step.index == 0) {
           steps.push_back(step.Next());
           steps.push_back(step.Next());
-          steps.push_back({.inst_id = GetTypeAllowBuiltinTypes(
+          steps.push_back({.inst_id = types().GetInstId(
                                inst.As<PointerType>().pointee_id)});
                                inst.As<PointerType>().pointee_id)});
         } else if (step.index == 1) {
         } else if (step.index == 1) {
           out << "*";
           out << "*";
@@ -346,8 +346,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
       case StructTypeField::Kind: {
       case StructTypeField::Kind: {
         auto field = inst.As<StructTypeField>();
         auto field = inst.As<StructTypeField>();
         out << "." << names().GetFormatted(field.name_id) << ": ";
         out << "." << names().GetFormatted(field.name_id) << ": ";
-        steps.push_back(
-            {.inst_id = GetTypeAllowBuiltinTypes(field.field_type_id)});
+        steps.push_back({.inst_id = types().GetInstId(field.field_type_id)});
         break;
         break;
       }
       }
       case TupleType::Kind: {
       case TupleType::Kind: {
@@ -369,15 +368,14 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
           break;
           break;
         }
         }
         steps.push_back(step.Next());
         steps.push_back(step.Next());
-        steps.push_back(
-            {.inst_id = GetTypeAllowBuiltinTypes(refs[step.index])});
+        steps.push_back({.inst_id = types().GetInstId(refs[step.index])});
         break;
         break;
       }
       }
       case UnboundElementType::Kind: {
       case UnboundElementType::Kind: {
         if (step.index == 0) {
         if (step.index == 0) {
           out << "<unbound element of class ";
           out << "<unbound element of class ";
           steps.push_back(step.Next());
           steps.push_back(step.Next());
-          steps.push_back({.inst_id = GetTypeAllowBuiltinTypes(
+          steps.push_back({.inst_id = types().GetInstId(
                                inst.As<UnboundElementType>().class_type_id)});
                                inst.As<UnboundElementType>().class_type_id)});
         } else {
         } else {
           out << ">";
           out << ">";

+ 14 - 103
toolchain/sem_ir/file.h

@@ -12,6 +12,7 @@
 #include "toolchain/base/value_store.h"
 #include "toolchain/base/value_store.h"
 #include "toolchain/base/yaml.h"
 #include "toolchain/base/yaml.h"
 #include "toolchain/sem_ir/ids.h"
 #include "toolchain/sem_ir/ids.h"
+#include "toolchain/sem_ir/type_info.h"
 #include "toolchain/sem_ir/value_stores.h"
 #include "toolchain/sem_ir/value_stores.h"
 
 
 namespace Carbon::SemIR {
 namespace Carbon::SemIR {
@@ -150,69 +151,6 @@ struct Interface : public Printable<Interface> {
   bool defined = true;
   bool defined = true;
 };
 };
 
 
-// The value representation to use when passing by value.
-struct ValueRepr : public Printable<ValueRepr> {
-  auto Print(llvm::raw_ostream& out) const -> void;
-
-  enum Kind : int8_t {
-    // The value representation is not yet known. This is used for incomplete
-    // types.
-    Unknown,
-    // The type has no value representation. This is used for empty types, such
-    // as `()`, where there is no value.
-    None,
-    // The value representation is a copy of the value. On call boundaries, the
-    // value itself will be passed. `type` is the value type.
-    Copy,
-    // The value representation is a pointer to the value. When used as a
-    // parameter, the argument is a reference expression. `type` is the pointee
-    // type.
-    Pointer,
-    // The value representation has been customized, and has the same behavior
-    // as the value representation of some other type.
-    // TODO: This is not implemented or used yet.
-    Custom,
-  };
-
-  enum AggregateKind : int8_t {
-    // This type is not an aggregation of other types.
-    NotAggregate,
-    // This type is an aggregate that holds the value representations of its
-    // elements.
-    ValueAggregate,
-    // This type is an aggregate that holds the object representations of its
-    // elements.
-    ObjectAggregate,
-    // This type is an aggregate for which the value and object representation
-    // of all elements are the same, so it effectively holds both.
-    ValueAndObjectAggregate,
-  };
-
-  // Returns whether this is an aggregate that holds its elements by value.
-  auto elements_are_values() const {
-    return aggregate_kind == ValueAggregate ||
-           aggregate_kind == ValueAndObjectAggregate;
-  }
-
-  // The kind of value representation used by this type.
-  Kind kind = Unknown;
-  // The kind of aggregate representation used by this type.
-  AggregateKind aggregate_kind = AggregateKind::NotAggregate;
-  // The type used to model the value representation.
-  TypeId type_id = TypeId::Invalid;
-};
-
-// Information stored about a TypeId.
-struct TypeInfo : public Printable<TypeInfo> {
-  auto Print(llvm::raw_ostream& out) const -> void;
-
-  // The instruction that defines this type.
-  InstId inst_id;
-  // The value representation for this type. Will be `Unknown` if the type is
-  // not complete.
-  ValueRepr value_repr = ValueRepr();
-};
-
 // Provides semantic analysis on a Parse::Tree.
 // Provides semantic analysis on a Parse::Tree.
 class File : public Printable<File> {
 class File : public Printable<File> {
  public:
  public:
@@ -223,6 +161,9 @@ class File : public Printable<File> {
   explicit File(SharedValueStores& value_stores, std::string filename,
   explicit File(SharedValueStores& value_stores, std::string filename,
                 const File* builtins);
                 const File* builtins);
 
 
+  File(const File&) = delete;
+  File& operator=(const File&) = delete;
+
   // Verifies that invariants of the semantics IR hold.
   // Verifies that invariants of the semantics IR hold.
   auto Verify() const -> ErrorOr<Success>;
   auto Verify() const -> ErrorOr<Success>;
 
 
@@ -257,39 +198,9 @@ class File : public Printable<File> {
     complete_types_.push_back(object_type_id);
     complete_types_.push_back(object_type_id);
   }
   }
 
 
-  auto GetTypeAllowBuiltinTypes(TypeId type_id) const -> InstId {
-    if (type_id == TypeId::TypeType) {
-      return InstId::BuiltinTypeType;
-    } else if (type_id == TypeId::Error) {
-      return InstId::BuiltinError;
-    } else if (type_id == TypeId::Invalid) {
-      return InstId::Invalid;
-    } else {
-      return types().Get(type_id).inst_id;
-    }
-  }
-
-  // Gets the value representation to use for a type. This returns an
-  // invalid type if the given type is not complete.
-  auto GetValueRepr(TypeId type_id) const -> ValueRepr {
-    if (type_id.index < 0) {
-      // TypeType and InvalidType are their own value representation.
-      return {.kind = ValueRepr::Copy, .type_id = type_id};
-    }
-    return types().Get(type_id).value_repr;
-  }
-
-  // Determines whether the given type is known to be complete. This does not
-  // determine whether the type could be completed, only whether it has been.
-  auto IsTypeComplete(TypeId type_id) const -> bool {
-    return GetValueRepr(type_id).kind != ValueRepr::Unknown;
-  }
-
   // Gets the pointee type of the given type, which must be a pointer type.
   // Gets the pointee type of the given type, which must be a pointer type.
   auto GetPointeeType(TypeId pointer_id) const -> TypeId {
   auto GetPointeeType(TypeId pointer_id) const -> TypeId {
-    return insts()
-        .GetAs<PointerType>(types().Get(pointer_id).inst_id)
-        .pointee_id;
+    return types().GetAs<PointerType>(pointer_id).pointee_id;
   }
   }
 
 
   // Produces a string version of a type.
   // Produces a string version of a type.
@@ -338,8 +249,8 @@ class File : public Printable<File> {
   }
   }
   auto name_scopes() -> NameScopeStore& { return name_scopes_; }
   auto name_scopes() -> NameScopeStore& { return name_scopes_; }
   auto name_scopes() const -> const NameScopeStore& { return name_scopes_; }
   auto name_scopes() const -> const NameScopeStore& { return name_scopes_; }
-  auto types() -> ValueStore<TypeId>& { return types_; }
-  auto types() const -> const ValueStore<TypeId>& { return types_; }
+  auto types() -> TypeStore& { return types_; }
+  auto types() const -> const TypeStore& { return types_; }
   auto type_blocks() -> BlockValueStore<TypeBlockId>& { return type_blocks_; }
   auto type_blocks() -> BlockValueStore<TypeBlockId>& { return type_blocks_; }
   auto type_blocks() const -> const BlockValueStore<TypeBlockId>& {
   auto type_blocks() const -> const BlockValueStore<TypeBlockId>& {
     return type_blocks_;
     return type_blocks_;
@@ -399,12 +310,6 @@ class File : public Printable<File> {
   // Storage for name scopes.
   // Storage for name scopes.
   NameScopeStore name_scopes_;
   NameScopeStore name_scopes_;
 
 
-  // Descriptions of types used in this file.
-  ValueStore<TypeId> types_;
-
-  // Types that were completed in this file.
-  llvm::SmallVector<TypeId> complete_types_;
-
   // Type blocks within the IR. These reference entries in types_. Storage for
   // Type blocks within the IR. These reference entries in types_. Storage for
   // the data is provided by allocator_.
   // the data is provided by allocator_.
   BlockValueStore<TypeBlockId> type_blocks_;
   BlockValueStore<TypeBlockId> type_blocks_;
@@ -423,6 +328,12 @@ class File : public Printable<File> {
   // Storage for instructions that represent computed global constants, such as
   // Storage for instructions that represent computed global constants, such as
   // types.
   // types.
   ConstantStore constants_;
   ConstantStore constants_;
+
+  // Descriptions of types used in this file.
+  TypeStore types_ = TypeStore(&insts_);
+
+  // Types that were completed in this file.
+  llvm::SmallVector<TypeId> complete_types_;
 };
 };
 
 
 // The expression category of a sem_ir instruction. See /docs/design/values.md
 // The expression category of a sem_ir instruction. See /docs/design/values.md
@@ -459,7 +370,7 @@ auto GetExprCategory(const File& file, InstId inst_id) -> ExprCategory;
 
 
 // Returns information about the value representation to use for a type.
 // Returns information about the value representation to use for a type.
 inline auto GetValueRepr(const File& file, TypeId type_id) -> ValueRepr {
 inline auto GetValueRepr(const File& file, TypeId type_id) -> ValueRepr {
-  return file.GetValueRepr(type_id);
+  return file.types().GetValueRepr(type_id);
 }
 }
 
 
 // The initializing representation to use when returning by value.
 // The initializing representation to use when returning by value.

+ 78 - 0
toolchain/sem_ir/type_info.h

@@ -0,0 +1,78 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_SEM_IR_TYPE_INFO_H_
+#define CARBON_TOOLCHAIN_SEM_IR_TYPE_INFO_H_
+
+#include "common/ostream.h"
+#include "toolchain/sem_ir/ids.h"
+
+namespace Carbon::SemIR {
+
+// The value representation to use when passing by value.
+struct ValueRepr : public Printable<ValueRepr> {
+  auto Print(llvm::raw_ostream& out) const -> void;
+
+  enum Kind : int8_t {
+    // The value representation is not yet known. This is used for incomplete
+    // types.
+    Unknown,
+    // The type has no value representation. This is used for empty types, such
+    // as `()`, where there is no value.
+    None,
+    // The value representation is a copy of the value. On call boundaries, the
+    // value itself will be passed. `type` is the value type.
+    Copy,
+    // The value representation is a pointer to the value. When used as a
+    // parameter, the argument is a reference expression. `type` is the pointee
+    // type.
+    Pointer,
+    // The value representation has been customized, and has the same behavior
+    // as the value representation of some other type.
+    // TODO: This is not implemented or used yet.
+    Custom,
+  };
+
+  enum AggregateKind : int8_t {
+    // This type is not an aggregation of other types.
+    NotAggregate,
+    // This type is an aggregate that holds the value representations of its
+    // elements.
+    ValueAggregate,
+    // This type is an aggregate that holds the object representations of its
+    // elements.
+    ObjectAggregate,
+    // This type is an aggregate for which the value and object representation
+    // of all elements are the same, so it effectively holds both.
+    ValueAndObjectAggregate,
+  };
+
+  // Returns whether this is an aggregate that holds its elements by value.
+  auto elements_are_values() const {
+    return aggregate_kind == ValueAggregate ||
+           aggregate_kind == ValueAndObjectAggregate;
+  }
+
+  // The kind of value representation used by this type.
+  Kind kind = Unknown;
+  // The kind of aggregate representation used by this type.
+  AggregateKind aggregate_kind = AggregateKind::NotAggregate;
+  // The type used to model the value representation.
+  TypeId type_id = TypeId::Invalid;
+};
+
+// Information stored about a TypeId.
+struct TypeInfo : public Printable<TypeInfo> {
+  auto Print(llvm::raw_ostream& out) const -> void;
+
+  // The instruction that defines this type.
+  InstId inst_id;
+  // The value representation for this type. Will be `Unknown` if the type is
+  // not complete.
+  ValueRepr value_repr = ValueRepr();
+};
+
+}  // namespace Carbon::SemIR
+
+#endif  // CARBON_TOOLCHAIN_SEM_IR_TYPE_INFO_H_

+ 63 - 0
toolchain/sem_ir/value_stores.h

@@ -9,6 +9,7 @@
 #include "toolchain/base/value_store.h"
 #include "toolchain/base/value_store.h"
 #include "toolchain/base/yaml.h"
 #include "toolchain/base/yaml.h"
 #include "toolchain/sem_ir/inst.h"
 #include "toolchain/sem_ir/inst.h"
+#include "toolchain/sem_ir/type_info.h"
 
 
 namespace Carbon::SemIR {
 namespace Carbon::SemIR {
 
 
@@ -58,6 +59,68 @@ class ConstantStore {
   llvm::SmallVector<InstId> values_;
   llvm::SmallVector<InstId> values_;
 };
 };
 
 
+// Provides a ValueStore wrapper with an API specific to types.
+class TypeStore : public ValueStore<TypeId> {
+ public:
+  explicit TypeStore(InstStore* insts) : insts_(insts) {}
+
+  // Returns the ID of the instruction used to define the specified type.
+  auto GetInstId(TypeId type_id) const -> InstId {
+    if (type_id == TypeId::TypeType) {
+      return InstId::BuiltinTypeType;
+    } else if (type_id == TypeId::Error) {
+      return InstId::BuiltinError;
+    } else if (type_id == TypeId::Invalid) {
+      return InstId::Invalid;
+    } else {
+      return Get(type_id).inst_id;
+    }
+  }
+
+  // Returns the instruction used to define the specified type.
+  auto GetAsInst(TypeId type_id) const -> Inst {
+    return insts_->Get(GetInstId(type_id));
+  }
+
+  // Returns the instruction used to define the specified type, which is known
+  // to be a particular kind of instruction.
+  template <typename InstT>
+  auto GetAs(TypeId type_id) const -> InstT {
+    if constexpr (std::is_same_v<InstT, Builtin>) {
+      return GetAsInst(type_id).As<InstT>();
+    } else {
+      // The type is not a builtin, so no need to check for special values.
+      return insts_->Get(Get(type_id).inst_id).As<InstT>();
+    }
+  }
+
+  // Returns the instruction used to define the specified type, if it is of a
+  // particular kind.
+  template <typename InstT>
+  auto TryGetAs(TypeId type_id) const -> std::optional<InstT> {
+    return GetAsInst(type_id).TryAs<InstT>();
+  }
+
+  // Gets the value representation to use for a type. This returns an
+  // invalid type if the given type is not complete.
+  auto GetValueRepr(TypeId type_id) const -> ValueRepr {
+    if (type_id.index < 0) {
+      // TypeType and InvalidType are their own value representation.
+      return {.kind = ValueRepr::Copy, .type_id = type_id};
+    }
+    return Get(type_id).value_repr;
+  }
+
+  // Determines whether the given type is known to be complete. This does not
+  // determine whether the type could be completed, only whether it has been.
+  auto IsComplete(TypeId type_id) const -> bool {
+    return GetValueRepr(type_id).kind != ValueRepr::Unknown;
+  }
+
+ private:
+  InstStore* insts_;
+};
+
 // Provides a ValueStore-like interface for names.
 // Provides a ValueStore-like interface for names.
 //
 //
 // A name is either an identifier name or a special name such as `self` that
 // A name is either an identifier name or a special name such as `self` that