Bladeren bron

Encapsulate `clang::ASTUnit` in `SemIR::CppFile`. (#6459)

This intends to avoid proliferation of dependencies on the exact API of
`clang::ASTUnit`, and would enable us to more easily switch to a
different approach that gives us more control over the construction of
the Clang AST.

Also remove some unnecessary tracking of the `CppFile` and instead
always retrieve it from the `SemIR::File`.

---------

Co-authored-by: Jon Ross-Perkins <jperkins@google.com>
Richard Smith 4 maanden geleden
bovenliggende
commit
d208e950c7

+ 1 - 0
toolchain/check/BUILD

@@ -159,6 +159,7 @@ cc_library(
         "//toolchain/parse:tree",
         "//toolchain/sem_ir:absolute_node_id",
         "//toolchain/sem_ir:clang_decl",
+        "//toolchain/sem_ir:cpp_file",
         "//toolchain/sem_ir:expr_info",
         "//toolchain/sem_ir:file",
         "//toolchain/sem_ir:formatter",

+ 5 - 4
toolchain/check/check.cpp

@@ -394,11 +394,12 @@ static auto MaybeDumpCppAST(llvm::ArrayRef<Unit> units,
   }
 
   for (const Unit& unit : units) {
-    if (!unit.clang_ast_unit || !*unit.clang_ast_unit) {
-      continue;
+    if (options.include_in_dumps->Get(unit.sem_ir->check_ir_id())) {
+      if (auto* cpp_file = unit.sem_ir->cpp_file()) {
+        cpp_file->ast_context().getTranslationUnitDecl()->dump(
+            *options.dump_cpp_ast_stream);
+      }
     }
-    clang::ASTContext& ast_context = (*unit.clang_ast_unit)->getASTContext();
-    ast_context.getTranslationUnitDecl()->dump(*options.dump_cpp_ast_stream);
   }
 }
 

+ 0 - 4
toolchain/check/check.h

@@ -29,10 +29,6 @@ struct Unit {
   SemIR::File* sem_ir;
   // The total number of files.
   int total_ir_count;
-
-  // Storage for the unit's Clang AST. The unique_ptr should start empty, and
-  // can be assigned as part of checking.
-  std::unique_ptr<clang::ASTUnit>* clang_ast_unit;
 };
 
 struct CheckParseTreesOptions {

+ 3 - 7
toolchain/check/check_unit.cpp

@@ -154,11 +154,7 @@ auto CheckUnit::InitPackageScopeAndImports() -> void {
 
   const auto& cpp_imports = unit_and_imports_->cpp_imports;
   if (!cpp_imports.empty()) {
-    auto* clang_ast_unit = unit_and_imports_->unit->clang_ast_unit;
-    CARBON_CHECK(clang_ast_unit);
-    CARBON_CHECK(!clang_ast_unit->get());
-    *clang_ast_unit =
-        ImportCppFiles(context_, cpp_imports, fs_, clang_invocation_);
+    ImportCpp(context_, cpp_imports, fs_, clang_invocation_);
   }
 }
 
@@ -584,10 +580,10 @@ auto CheckUnit::FinishRun() -> void {
   CheckPoisonedConcreteImplLookupQueries();
   CheckImpls();
 
-  if (auto* clang_ast = context_.sem_ir().clang_ast_unit()) {
+  if (auto* cpp_file = context_.sem_ir().cpp_file()) {
     // Ask Clang to perform any cleanups required, including instantiating used
     // templates.
-    clang_ast->getSema().ActOnEndOfTranslationUnit();
+    cpp_file->sema().ActOnEndOfTranslationUnit();
     context_.emitter().Flush();
   }
 

+ 2 - 4
toolchain/check/context.h

@@ -311,11 +311,9 @@ class Context {
     return sem_ir().import_ir_insts();
   }
   auto ast_context() -> clang::ASTContext& {
-    return sem_ir().clang_ast_unit()->getASTContext();
-  }
-  auto clang_sema() -> clang::Sema& {
-    return sem_ir().clang_ast_unit()->getSema();
+    return sem_ir().cpp_file()->ast_context();
   }
+  auto clang_sema() -> clang::Sema& { return sem_ir().cpp_file()->sema(); }
   auto clang_decls() -> SemIR::ClangDeclStore& {
     return sem_ir().clang_decls();
   }

+ 26 - 28
toolchain/check/cpp/import.cpp

@@ -16,7 +16,6 @@
 #include "clang/AST/UnresolvedSet.h"
 #include "clang/AST/VTableBuilder.h"
 #include "clang/Basic/FileManager.h"
-#include "clang/Frontend/ASTUnit.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/CompilerInvocation.h"
 #include "clang/Frontend/TextDiagnostic.h"
@@ -62,6 +61,7 @@
 #include "toolchain/parse/node_ids.h"
 #include "toolchain/sem_ir/clang_decl.h"
 #include "toolchain/sem_ir/class.h"
+#include "toolchain/sem_ir/cpp_file.h"
 #include "toolchain/sem_ir/cpp_overload_set.h"
 #include "toolchain/sem_ir/function.h"
 #include "toolchain/sem_ir/ids.h"
@@ -278,8 +278,9 @@ class CarbonClangDiagnosticConsumer : public clang::DiagnosticConsumer {
   // Outputs Carbon diagnostics based on the collected Clang diagnostics. Must
   // be called after the AST is set in the context.
   auto EmitDiagnostics() -> void {
-    CARBON_CHECK(sem_ir_->clang_ast_unit(),
-                 "Attempted to emit diagnostics before the AST Unit is loaded");
+    CARBON_CHECK(
+        sem_ir_->cpp_file(),
+        "Attempted to emit C++ diagnostics before the C++ file is set");
 
     for (size_t i = 0; i != diagnostic_infos_.size(); ++i) {
       const ClangDiagnosticInfo& info = diagnostic_infos_[i];
@@ -388,15 +389,14 @@ class ShallowCopyCompilerInvocation : public clang::CompilerInvocation {
 
 }  // namespace
 
-// Returns an AST for the C++ imports and a bool that represents whether
-// compilation errors where encountered or the generated AST is null due to an
-// error. Sets the AST in the context's `sem_ir`.
-// TODO: Consider to always have a (non-null) AST.
+// Generates a Clang AST for the C++ imports and sets it in the context's
+// `sem_ir`. Returns a bool that represents whether compilation was successful.
+// TODO: Consider to always have a (non-null) AST even if there are no Cpp
+// imports.
 static auto GenerateAst(
     Context& context, llvm::ArrayRef<Parse::Tree::PackagingNames> imports,
     llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
-    std::shared_ptr<clang::CompilerInvocation> base_invocation)
-    -> std::pair<std::unique_ptr<clang::ASTUnit>, bool> {
+    std::shared_ptr<clang::CompilerInvocation> base_invocation) -> bool {
   auto invocation =
       std::make_shared<ShallowCopyCompilerInvocation>(*base_invocation);
 
@@ -434,12 +434,13 @@ static auto GenerateAst(
   // Attach the AST to SemIR. This needs to be done before we can emit any
   // diagnostics, so their locations can be properly interpreted by our
   // diagnostics machinery.
-  context.sem_ir().set_clang_ast_unit(ast.get());
+  context.sem_ir().set_cpp_file(
+      std::make_unique<SemIR::CppFile>(std::move(ast)));
 
   // Emit any diagnostics we queued up while building the AST.
   context.emitter().Flush();
 
-  return {std::move(ast), !ast || trap.hasErrorOccurred()};
+  return !trap.hasErrorOccurred();
 }
 
 // Adds a namespace for the `Cpp` import and returns its `NameScopeId`.
@@ -462,16 +463,15 @@ static auto AddNamespace(Context& context, PackageNameId cpp_package_id,
       .add_result.name_scope_id;
 }
 
-auto ImportCppFiles(Context& context,
-                    llvm::ArrayRef<Parse::Tree::PackagingNames> imports,
-                    llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
-                    std::shared_ptr<clang::CompilerInvocation> invocation)
-    -> std::unique_ptr<clang::ASTUnit> {
+auto ImportCpp(Context& context,
+               llvm::ArrayRef<Parse::Tree::PackagingNames> imports,
+               llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
+               std::shared_ptr<clang::CompilerInvocation> invocation) -> void {
   if (imports.empty()) {
-    return nullptr;
+    return;
   }
 
-  CARBON_CHECK(!context.sem_ir().clang_ast_unit());
+  CARBON_CHECK(!context.sem_ir().cpp_file());
 
   PackageNameId package_id = imports.front().package_id;
   CARBON_CHECK(
@@ -480,21 +480,19 @@ auto ImportCppFiles(Context& context,
       }));
   auto name_scope_id = AddNamespace(context, package_id, imports);
 
-  auto [generated_ast, ast_has_error] =
-      GenerateAst(context, imports, fs, std::move(invocation));
+  bool ast_has_error =
+      !GenerateAst(context, imports, fs, std::move(invocation));
 
   SemIR::NameScope& name_scope = context.name_scopes().Get(name_scope_id);
   name_scope.set_is_closed_import(true);
   name_scope.set_clang_decl_context_id(context.clang_decls().Add(
-      {.key = SemIR::ClangDeclKey(
-           generated_ast->getASTContext().getTranslationUnitDecl()),
+      {.key =
+           SemIR::ClangDeclKey(context.ast_context().getTranslationUnitDecl()),
        .inst_id = name_scope.inst_id()}));
 
   if (ast_has_error) {
     name_scope.set_has_error();
   }
-
-  return std::move(generated_ast);
 }
 
 // Returns the Clang `DeclContext` for the given name scope. Return the
@@ -2472,8 +2470,8 @@ auto ImportClassDefinitionForClangDecl(Context& context, SemIR::LocId loc_id,
                                        SemIR::ClassId class_id,
                                        SemIR::ClangDeclId clang_decl_id)
     -> bool {
-  clang::ASTUnit* ast = context.sem_ir().clang_ast_unit();
-  CARBON_CHECK(ast);
+  SemIR::CppFile* cpp_file = context.sem_ir().cpp_file();
+  CARBON_CHECK(cpp_file);
 
   auto* clang_decl =
       cast<clang::TagDecl>(context.clang_decls().Get(clang_decl_id).key.decl);
@@ -2494,8 +2492,8 @@ auto ImportClassDefinitionForClangDecl(Context& context, SemIR::LocId loc_id,
 
   // Ask Clang whether the type is complete. This triggers template
   // instantiation if necessary.
-  clang::DiagnosticErrorTrap trap(ast->getDiagnostics());
-  if (!ast->getSema().isCompleteType(
+  clang::DiagnosticErrorTrap trap(cpp_file->diagnostics());
+  if (!cpp_file->sema().isCompleteType(
           loc, context.ast_context().getCanonicalTagType(clang_decl))) {
     // Type is incomplete. Nothing more to do, but tell the caller if we
     // produced an error.

+ 6 - 7
toolchain/check/cpp/import.h

@@ -17,13 +17,12 @@
 namespace Carbon::Check {
 
 // Generates a C++ header that includes the imported cpp files, parses it,
-// generates the AST from it and links `SemIR::File` to it. Report C++ errors
-// and warnings. If successful, adds a `Cpp` namespace and returns the AST.
-auto ImportCppFiles(Context& context,
-                    llvm::ArrayRef<Parse::Tree::PackagingNames> imports,
-                    llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
-                    std::shared_ptr<clang::CompilerInvocation> invocation)
-    -> std::unique_ptr<clang::ASTUnit>;
+// generates the AST from it and links `SemIR::File` to it. Reports C++ errors
+// and warnings. If successful, adds a `Cpp` namespace.
+auto ImportCpp(Context& context,
+               llvm::ArrayRef<Parse::Tree::PackagingNames> imports,
+               llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs,
+               std::shared_ptr<clang::CompilerInvocation> invocation) -> void;
 
 // Imports a function declaration from Clang to Carbon. If successful, returns
 // the new Carbon function declaration `InstId`. If the declaration was already

+ 1 - 1
toolchain/check/cpp/location.cpp

@@ -56,7 +56,7 @@ static auto GetFileInfo(Context& context, SemIR::CheckIRId ir_id) -> FileInfo {
 
 auto GetCppLocation(Context& context, SemIR::LocId loc_id)
     -> clang::SourceLocation {
-  if (!context.sem_ir().clang_ast_unit()) {
+  if (!context.sem_ir().cpp_file()) {
     return clang::SourceLocation();
   }
 

+ 1 - 3
toolchain/driver/compile_subcommand.cpp

@@ -571,7 +571,6 @@ class CompilationUnit {
   std::optional<std::function<auto()->const Parse::TreeAndSubtrees&>>
       tree_and_subtrees_getter_;
   std::optional<SemIR::File> sem_ir_;
-  std::unique_ptr<clang::ASTUnit> clang_ast_unit_;
   std::unique_ptr<llvm::LLVMContext> llvm_context_;
   std::unique_ptr<llvm::Module> module_;
   std::unique_ptr<llvm::TargetMachine> target_machine_;
@@ -764,8 +763,7 @@ auto CompilationUnit::GetCheckUnit() -> Check::Unit {
           .value_stores = &value_stores_,
           .timings = timings_ ? &*timings_ : nullptr,
           .sem_ir = &*sem_ir_,
-          .total_ir_count = total_ir_count_,
-          .clang_ast_unit = &clang_ast_unit_};
+          .total_ir_count = total_ir_count_};
 }
 
 auto CompilationUnit::PostCheck() -> void {

+ 1 - 3
toolchain/language_server/context.cpp

@@ -148,14 +148,12 @@ auto Context::File::SetText(Context& context, std::optional<int64_t> version,
 
   SemIR::File sem_ir(tree_.get(), SemIR::CheckIRId(0), tree_->packaging_decl(),
                      *value_stores_, uri_.file().str());
-  std::unique_ptr<clang::ASTUnit> clang_ast_unt;
   // TODO: Support cross-file checking when multiple files have edits.
   llvm::SmallVector<Check::Unit> units = {{{.consumer = &consumer,
                                             .value_stores = value_stores_.get(),
                                             .timings = nullptr,
                                             .sem_ir = &sem_ir,
-                                            .total_ir_count = 1,
-                                            .clang_ast_unit = &clang_ast_unt}}};
+                                            .total_ir_count = 1}}};
 
   auto getter = [this]() -> const Parse::TreeAndSubtrees& {
     return *tree_and_subtrees_;

+ 13 - 29
toolchain/lower/file_context.cpp

@@ -73,27 +73,16 @@ auto FileContext::PrepareToLower() -> void {
     // Clang code generation should not actually modify the AST, but isn't
     // const-correct.
     cpp_code_generator_->Initialize(
-        const_cast<clang::ASTContext&>(clang_ast_unit()->getASTContext()));
-
-    // Work around `visitLocalTopLevelDecls` not being const. It doesn't modify
-    // the AST unit other than triggering deserialization.
-    auto* non_const_ast_unit = const_cast<clang::ASTUnit*>(clang_ast_unit());
+        const_cast<clang::ASTContext&>(cpp_file()->ast_context()));
 
     // Emit any top-level declarations now.
-    // TODO: This may miss things that we need to emit which are handed to the
-    // ASTConsumer in other ways. Instead of doing this, we should create the
-    // CodeGenerator earlier and register it as an ASTConsumer before we parse
-    // the C++ inputs.
-    non_const_ast_unit->visitLocalTopLevelDecls(
-        cpp_code_generator_.get(),
-        [](void* codegen_ptr, const clang::Decl* decl) {
-          auto* codegen = static_cast<clang::CodeGenerator*>(codegen_ptr);
-          // CodeGenerator won't modify the declaration it's given, but we can
-          // only call it via the ASTConsumer interface which doesn't know that.
-          auto* non_const_decl = const_cast<clang::Decl*>(decl);
-          codegen->HandleTopLevelDecl(clang::DeclGroupRef(non_const_decl));
-          return true;
-        });
+    cpp_file()->VisitLocalTopLevelDecls([&](const clang::Decl* decl) {
+      // CodeGenerator won't modify the declaration it's given, but we can
+      // only call it via the ASTConsumer interface which doesn't know that.
+      auto* non_const_decl = const_cast<clang::Decl*>(decl);
+      cpp_code_generator_->HandleTopLevelDecl(
+          clang::DeclGroupRef(non_const_decl));
+    });
   }
 
   // Lower all types that were required to be complete.
@@ -180,7 +169,7 @@ auto FileContext::Finalize() -> void {
     // Clang code generation should not actually modify the AST, but isn't
     // const-correct.
     cpp_code_generator_->HandleTranslationUnit(
-        const_cast<clang::ASTContext&>(clang_ast_unit()->getASTContext()));
+        const_cast<clang::ASTContext&>(cpp_file()->ast_context()));
     bool link_error = llvm::Linker::linkModules(
         /*Dest=*/llvm_module(),
         /*Src=*/std::unique_ptr<llvm::Module>(
@@ -196,7 +185,7 @@ auto FileContext::Finalize() -> void {
 
 auto FileContext::CreateCppCodeGenerator()
     -> std::unique_ptr<clang::CodeGenerator> {
-  if (!clang_ast_unit()) {
+  if (!cpp_file()) {
     return nullptr;
   }
 
@@ -207,10 +196,9 @@ auto FileContext::CreateCppCodeGenerator()
   cpp_code_gen_options_.EmitVersionIdentMetadata = false;
 
   return std::unique_ptr<clang::CodeGenerator>(clang::CreateLLVMCodeGen(
-      clang_ast_unit()->getASTContext().getDiagnostics(),
-      clang_module_name_stream.TakeStr(), context().file_system(),
-      cpp_header_search_options_, cpp_preprocessor_options_,
-      cpp_code_gen_options_, llvm_context()));
+      cpp_file()->diagnostics(), clang_module_name_stream.TakeStr(),
+      context().file_system(), cpp_header_search_options_,
+      cpp_preprocessor_options_, cpp_code_gen_options_, llvm_context()));
 }
 
 auto FileContext::GetConstant(SemIR::ConstantId const_id,
@@ -403,10 +391,6 @@ auto FileContext::BuildFunctionTypeInfo(const SemIR::Function& function,
 
 auto FileContext::HandleReferencedCppFunction(clang::FunctionDecl* cpp_decl)
     -> void {
-  // TODO: To support recursive inline functions, collect all calls to
-  // `HandleTopLevelDecl()` in a custom `ASTConsumer` configured in the
-  // `ASTUnit`, and replay them in lowering in the `CodeGenerator`. See
-  // https://discord.com/channels/655572317891461132/768530752592805919/1370509111585935443
   clang::FunctionDecl* cpp_def = cpp_decl->getDefinition();
   if (!cpp_def) {
     return;

+ 1 - 3
toolchain/lower/file_context.h

@@ -118,9 +118,7 @@ class FileContext {
     return *cpp_code_generator_;
   }
   auto sem_ir() const -> const SemIR::File& { return *sem_ir_; }
-  auto clang_ast_unit() -> const clang::ASTUnit* {
-    return sem_ir().clang_ast_unit();
-  }
+  auto cpp_file() -> const SemIR::CppFile* { return sem_ir().cpp_file(); }
   auto inst_namer() -> const SemIR::InstNamer* { return inst_namer_; }
   auto global_variables() -> const Map<SemIR::InstId, llvm::GlobalVariable*>& {
     return global_variables_;

+ 15 - 0
toolchain/sem_ir/BUILD

@@ -63,6 +63,20 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "cpp_file",
+    srcs = ["cpp_file.cpp"],
+    hdrs = ["cpp_file.h"],
+    deps = [
+        "//common:check",
+        "@llvm-project//clang:ast",
+        "@llvm-project//clang:basic",
+        "@llvm-project//clang:frontend",
+        "@llvm-project//clang:sema",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "file",
     srcs = [
@@ -118,6 +132,7 @@ cc_library(
     ],
     deps = [
         ":clang_decl",
+        ":cpp_file",
         ":typed_insts",
         "//common:check",
         "//common:enum_base",

+ 19 - 0
toolchain/sem_ir/cpp_file.cpp

@@ -0,0 +1,19 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "toolchain/sem_ir/cpp_file.h"
+
+namespace Carbon::SemIR {
+
+auto CppFile::VisitLocalTopLevelDecls(
+    llvm::function_ref<void(const clang::Decl*)> visitor) const -> void {
+  ast_unit_->visitLocalTopLevelDecls(
+      &visitor, [](void* erased_visitor_ptr, const clang::Decl* decl) {
+        auto* visitor_ptr = static_cast<decltype(visitor)*>(erased_visitor_ptr);
+        (*visitor_ptr)(decl);
+        return true;
+      });
+}
+
+}  // namespace Carbon::SemIR

+ 61 - 0
toolchain/sem_ir/cpp_file.h

@@ -0,0 +1,61 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_TOOLCHAIN_SEM_IR_CPP_FILE_H_
+#define CARBON_TOOLCHAIN_SEM_IR_CPP_FILE_H_
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Frontend/ASTUnit.h"
+#include "clang/Frontend/CompilerInvocation.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/Support/FileSystem.h"
+
+namespace Carbon::SemIR {
+
+// The result of compiling the C++ portion of a `File`, including both any
+// imported C++ headers and any inline C++ fragments.
+class CppFile {
+ public:
+  explicit CppFile(std::unique_ptr<clang::ASTUnit> ast_unit)
+      : ast_unit_(std::move(ast_unit)) {}
+
+  // Access to compilation options.
+  auto diagnostic_options() const -> const clang::DiagnosticOptions& {
+    return ast_unit_->getDiagnostics().getDiagnosticOptions();
+  }
+  auto lang_options() const -> const clang::LangOptions& {
+    return ast_unit_->getLangOpts();
+  }
+
+  // Access to Clang's compilation environment.
+  auto source_manager() -> clang::SourceManager& {
+    return ast_unit_->getSourceManager();
+  }
+  auto source_manager() const -> const clang::SourceManager& {
+    return ast_unit_->getSourceManager();
+  }
+  auto diagnostics() const -> clang::DiagnosticsEngine& {
+    return ast_unit_->getDiagnostics();
+  }
+
+  // Access to layers of Clang's C++ representation.
+  auto ast_context() -> clang::ASTContext& {
+    return ast_unit_->getASTContext();
+  }
+  auto ast_context() const -> const clang::ASTContext& {
+    return ast_unit_->getASTContext();
+  }
+  auto sema() -> clang::Sema& { return ast_unit_->getSema(); }
+
+  // Visit all top-level declarations in the file.
+  auto VisitLocalTopLevelDecls(
+      llvm::function_ref<auto(const clang::Decl*)->void> visitor) const -> void;
+
+ private:
+  std::unique_ptr<clang::ASTUnit> ast_unit_;
+};
+
+}  // namespace Carbon::SemIR
+
+#endif  // CARBON_TOOLCHAIN_SEM_IR_CPP_FILE_H_

+ 14 - 9
toolchain/sem_ir/diagnostic_loc_converter.cpp

@@ -40,9 +40,13 @@ namespace {
 class ClangImportCollector : public clang::DiagnosticRenderer {
  public:
   explicit ClangImportCollector(
-      const clang::LangOptions& lang_opts, clang::DiagnosticOptions& diag_opts,
+      const clang::LangOptions& lang_opts,
+      const clang::DiagnosticOptions& diag_opts,
       llvm::SmallVectorImpl<DiagnosticLocConverter::ImportLoc>* imports)
-      : DiagnosticRenderer(lang_opts, diag_opts), imports_(imports) {}
+      : DiagnosticRenderer(lang_opts,
+                           // Work around lack of const-correctness in Clang.
+                           const_cast<clang::DiagnosticOptions&>(diag_opts)),
+        imports_(imports) {}
 
   void emitDiagnosticMessage(clang::FullSourceLoc loc, clang::PresumedLoc ploc,
                              clang::DiagnosticsEngine::Level /*level*/,
@@ -126,15 +130,16 @@ auto DiagnosticLocConverter::ConvertWithImports(LocId loc_id,
 
   // Convert the C++ import locations.
   if (final_node_id.check_ir_id() == CheckIRId::Cpp) {
-    const clang::ASTUnit* ast = sem_ir_->clang_ast_unit();
+    const SemIR::CppFile* cpp_file = sem_ir_->cpp_file();
+    CARBON_CHECK(cpp_file, "Converting C++ location before C++ file is set");
+
     // Collect the location backtrace that Clang would use for an error here.
-    ClangImportCollector(ast->getLangOpts(),
-                         ast->getDiagnostics().getDiagnosticOptions(),
-                         &result.imports)
+    ClangImportCollector(cpp_file->lang_options(),
+                         cpp_file->diagnostic_options(), &result.imports)
         .emitDiagnostic(
             clang::FullSourceLoc(sem_ir_->clang_source_locs().Get(
                                      final_node_id.clang_source_loc_id()),
-                                 ast->getSourceManager()),
+                                 cpp_file->source_manager()),
             clang::DiagnosticsEngine::Error, "", {}, {});
   }
 
@@ -174,8 +179,8 @@ auto DiagnosticLocConverter::ConvertImpl(
   clang::SourceLocation clang_loc =
       sem_ir_->clang_source_locs().Get(clang_source_loc_id);
 
-  CARBON_CHECK(sem_ir_->clang_ast_unit());
-  const auto& src_mgr = sem_ir_->clang_ast_unit()->getSourceManager();
+  CARBON_CHECK(sem_ir_->cpp_file());
+  const auto& src_mgr = sem_ir_->cpp_file()->source_manager();
   clang::PresumedLoc presumed_loc = src_mgr.getPresumedLoc(clang_loc);
   if (presumed_loc.isInvalid()) {
     return Diagnostics::ConvertedLoc();

+ 3 - 4
toolchain/sem_ir/file.cpp

@@ -222,10 +222,9 @@ auto File::CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const
   mem_usage.Collect(MemUsage::ConcatLabel(label, "types_"), types_);
 }
 
-auto File::set_clang_ast_unit(clang::ASTUnit* clang_ast_unit) -> void {
-  clang_ast_unit_ = clang_ast_unit;
-  clang_mangle_context_.reset(
-      clang_ast_unit->getASTContext().createMangleContext());
+auto File::set_cpp_file(std::unique_ptr<SemIR::CppFile> cpp_file) -> void {
+  cpp_file_ = std::move(cpp_file);
+  clang_mangle_context_.reset(cpp_file_->ast_context().createMangleContext());
 }
 
 }  // namespace Carbon::SemIR

+ 11 - 13
toolchain/sem_ir/file.h

@@ -5,7 +5,6 @@
 #ifndef CARBON_TOOLCHAIN_SEM_IR_FILE_H_
 #define CARBON_TOOLCHAIN_SEM_IR_FILE_H_
 
-#include "clang/Frontend/ASTUnit.h"
 #include "common/error.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
@@ -21,6 +20,7 @@
 #include "toolchain/sem_ir/associated_constant.h"
 #include "toolchain/sem_ir/class.h"
 #include "toolchain/sem_ir/constant.h"
+#include "toolchain/sem_ir/cpp_file.h"
 #include "toolchain/sem_ir/cpp_global_var.h"
 #include "toolchain/sem_ir/cpp_overload_set.h"
 #include "toolchain/sem_ir/entity_name.h"
@@ -228,14 +228,12 @@ class File : public Printable<File> {
   auto import_ir_insts() const -> const ImportIRInstStore& {
     return import_ir_insts_;
   }
-  auto clang_ast_unit() -> clang::ASTUnit* { return clang_ast_unit_; }
-  auto clang_ast_unit() const -> const clang::ASTUnit* {
-    return clang_ast_unit_;
-  }
-  // TODO: When the AST can be created before creating `File`, initialize the
-  // pointer in the constructor and remove this function. This is part of
-  // https://github.com/carbon-language/carbon-lang/issues/4666
-  auto set_clang_ast_unit(clang::ASTUnit* clang_ast_unit) -> void;
+  auto cpp_file() -> SemIR::CppFile* { return cpp_file_.get(); }
+  auto cpp_file() const -> const SemIR::CppFile* { return cpp_file_.get(); }
+  // TODO: We should be able to create the initial C++ AST before creating the
+  // `File` and initialize the pointer in the constructor instead of using a
+  // setter.
+  auto set_cpp_file(std::unique_ptr<SemIR::CppFile> cpp_file) -> void;
   auto clang_mangle_context() -> clang::MangleContext* {
     return clang_mangle_context_.get();
   }
@@ -381,12 +379,12 @@ class File : public Printable<File> {
   // that are import-related.
   ImportIRInstStore import_ir_insts_;
 
-  // The Clang AST to use when looking up `Cpp` names. Null if there are no
-  // `Cpp` imports.
-  clang::ASTUnit* clang_ast_unit_ = nullptr;
+  // The C++ file to use when looking up `Cpp` names. Null if there are no `Cpp`
+  // imports.
+  std::unique_ptr<SemIR::CppFile> cpp_file_;
 
   // The Clang mangle context for the target in the ASTContext. Initialized
-  // together with `clang_ast_unit_`.
+  // together with `cpp_file_`.
   std::unique_ptr<clang::MangleContext> clang_mangle_context_;
 
   // Clang AST declarations pointing to the AST and their mapped Carbon

+ 4 - 5
toolchain/sem_ir/type_info.cpp

@@ -202,9 +202,8 @@ static auto PrintCppCompatLiteral(
     const File& file, clang::CanQualType clang::ASTContext::* qual_type_member,
     unsigned int carbon_bit_width, llvm::StringRef cpp_builtin_name,
     llvm::raw_ostream& out) -> bool {
-  if (file.clang_ast_unit()) {
-    const clang::ASTContext& ast_context =
-        file.clang_ast_unit()->getASTContext();
+  if (const auto* cpp_file = file.cpp_file()) {
+    const clang::ASTContext& ast_context = cpp_file->ast_context();
     if (ast_context.getIntWidth(ast_context.*qual_type_member) ==
         carbon_bit_width) {
       out << "Cpp." << cpp_builtin_name;
@@ -238,13 +237,13 @@ auto RecognizedTypeInfo::PrintLiteral(const File& file,
       return PrintCppCompatLiteral(file, &clang::ASTContext::UnsignedLongLongTy,
                                    64, "unsigned_long_long", out);
     case CppNullptrT:
-      if (file.clang_ast_unit()) {
+      if (file.cpp_file()) {
         out << "Cpp.nullptr_t";
         return true;
       }
       break;
     case CppVoidBase:
-      if (file.clang_ast_unit()) {
+      if (file.cpp_file()) {
         out << "Cpp.void";
         return true;
       }