Просмотр исходного кода

Rearrange name poisoning logic to do a little less work. (#4766)

Insert the poison at the same time we do the name lookup to avoid doing
two hash table lookups into each scope. This adds a bit of complication
because import logic now needs to cope with importing a name that is
already poisoned, but the complexity seems worthwhile to reduce the
number of name lookups performed.

This incidentally fixes a bug where we wouldn't poison any name scopes
if we found the name in an enclosing lexical scope, leading to one extra
diagnostic in existing tests.

Part of #4622
Richard Smith 1 год назад
Родитель
Сommit
6bc36b045f

+ 13 - 19
toolchain/check/context.cpp

@@ -365,7 +365,8 @@ auto Context::LookupNameInDecl(SemIR::LocId loc_id, SemIR::NameId name_id,
     //    // Error, no `F` in `B`.
     //    fn B.F() {}
     auto result = LookupNameInExactScope(loc_id, name_id, scope_id,
-                                         name_scopes().Get(scope_id));
+                                         name_scopes().Get(scope_id),
+                                         /*is_being_declared=*/true);
     return {result.inst_id, result.is_poisoned};
   }
 }
@@ -381,8 +382,6 @@ auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
       scope_stack().LookupInLexicalScopes(name_id);
 
   // Walk the non-lexical scopes and perform lookups into each of them.
-  // Collect scopes to poison this name when it's found.
-  llvm::SmallVector<LookupScope> scopes_to_poison;
   for (auto [index, lookup_scope_id, specific_id] :
        llvm::reverse(non_lexical_scopes)) {
     if (auto non_lexical_result =
@@ -390,17 +389,8 @@ auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
                                 LookupScope{.name_scope_id = lookup_scope_id,
                                             .specific_id = specific_id},
                                 /*required=*/false);
-        !non_lexical_result.is_poisoned) {
-      if (non_lexical_result.inst_id.is_valid()) {
-        // Poison the scopes for this name.
-        for (const auto [scope_id, specific_id] : scopes_to_poison) {
-          name_scopes().Get(scope_id).AddPoison(name_id);
-        }
-
-        return non_lexical_result;
-      }
-      scopes_to_poison.push_back(
-          {.name_scope_id = lookup_scope_id, .specific_id = specific_id});
+        non_lexical_result.inst_id.is_valid()) {
+      return non_lexical_result;
     }
   }
 
@@ -423,12 +413,16 @@ auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
 
 auto Context::LookupNameInExactScope(SemIRLoc loc, SemIR::NameId name_id,
                                      SemIR::NameScopeId scope_id,
-                                     const SemIR::NameScope& scope)
+                                     SemIR::NameScope& scope,
+                                     bool is_being_declared)
     -> LookupNameInExactScopeResult {
-  if (auto entry_id = scope.Lookup(name_id)) {
+  if (auto entry_id = is_being_declared ? scope.Lookup(name_id)
+                                        : scope.LookupOrPoison(name_id)) {
     auto entry = scope.GetEntry(*entry_id);
     if (!entry.is_poisoned) {
       LoadImportRef(*this, entry.inst_id);
+    } else if (is_being_declared) {
+      entry.inst_id = SemIR::InstId::Invalid;
     }
     return {entry.inst_id, entry.access_kind, entry.is_poisoned};
   }
@@ -593,7 +587,7 @@ auto Context::LookupQualifiedName(SemIR::LocId loc_id, SemIR::NameId name_id,
       has_error = true;
       continue;
     }
-    const auto& name_scope = name_scopes().Get(scope_id);
+    auto& name_scope = name_scopes().Get(scope_id);
     has_error |= name_scope.has_error();
 
     auto [scope_result_id, access_kind, is_poisoned] =
@@ -612,7 +606,7 @@ auto Context::LookupQualifiedName(SemIR::LocId loc_id, SemIR::NameId name_id,
       });
     }
 
-    if (!is_poisoned && (!scope_result_id.is_valid() || is_access_prohibited)) {
+    if (!scope_result_id.is_valid() || is_access_prohibited) {
       // If nothing is found in this scope or if we encountered an invalid
       // access, look in its extended scopes.
       const auto& extended = name_scope.extended_scopes();
@@ -657,7 +651,7 @@ auto Context::LookupQualifiedName(SemIR::LocId loc_id, SemIR::NameId name_id,
     result.is_poisoned = is_poisoned;
   }
 
-  if (required && (!result.inst_id.is_valid() || result.is_poisoned)) {
+  if (required && !result.inst_id.is_valid()) {
     if (!has_error) {
       if (prohibited_accesses.empty()) {
         DiagnoseMemberNameNotFound(loc_id, name_id, lookup_scopes);

+ 9 - 3
toolchain/check/context.h

@@ -242,11 +242,17 @@ class Context {
 
   // Performs a name lookup in a specified scope, returning the referenced
   // instruction. Does not look into extended scopes. Returns an invalid
-  // instruction if the name is poisoned or not found.
-  // TODO: Return the poisoning instruction if poisoned.
+  // instruction if the name is not found.
+  //
+  // If `is_being_declared` is false, then this is a regular name lookup, and
+  // the name will be poisoned if not found so that later lookups will fail; a
+  // poisoned name will be treated as if it is not declared. Otherwise, this is
+  // a lookup for a name being declared, so the name will not be poisoned, but
+  // poison will be returned if it's already been looked up.
   auto LookupNameInExactScope(SemIRLoc loc, SemIR::NameId name_id,
                               SemIR::NameScopeId scope_id,
-                              const SemIR::NameScope& scope)
+                              SemIR::NameScope& scope,
+                              bool is_being_declared = false)
       -> LookupNameInExactScopeResult;
 
   // Appends the lookup scopes corresponding to `base_const_id` to `*scopes`.

+ 14 - 11
toolchain/check/import.cpp

@@ -110,16 +110,17 @@ static auto AddNamespace(Context& context, SemIR::TypeId namespace_type_id,
       SemIR::InstId::Invalid, SemIR::AccessKind::Public);
   if (!inserted) {
     const auto& prev_entry = parent_scope->GetEntry(entry_id);
-    CARBON_CHECK(!prev_entry.is_poisoned);
-    auto prev_inst_id = prev_entry.inst_id;
-    if (auto namespace_inst =
-            context.insts().TryGetAs<SemIR::Namespace>(prev_inst_id)) {
-      if (diagnose_duplicate_namespace) {
-        auto import_id = make_import_id();
-        CARBON_CHECK(import_id.is_valid());
-        context.DiagnoseDuplicateName(import_id, prev_inst_id);
+    if (!prev_entry.is_poisoned) {
+      auto prev_inst_id = prev_entry.inst_id;
+      if (auto namespace_inst =
+              context.insts().TryGetAs<SemIR::Namespace>(prev_inst_id)) {
+        if (diagnose_duplicate_namespace) {
+          auto import_id = make_import_id();
+          CARBON_CHECK(import_id.is_valid());
+          context.DiagnoseDuplicateName(import_id, prev_inst_id);
+        }
+        return {namespace_inst->name_scope_id, prev_inst_id, true};
       }
-      return {namespace_inst->name_scope_id, prev_inst_id, true};
     }
   }
 
@@ -148,9 +149,11 @@ static auto AddNamespace(Context& context, SemIR::TypeId namespace_type_id,
   parent_scope = &context.name_scopes().Get(parent_scope_id);
 
   // Diagnose if there's a name conflict, but still produce the namespace to
-  // supersede the name conflict in order to avoid repeat diagnostics.
+  // supersede the name conflict in order to avoid repeat diagnostics. Names are
+  // poisoned optimistically by name lookup before checking for imports, so we
+  // may be overwriting a poisoned entry here.
   auto& entry = parent_scope->GetEntry(entry_id);
-  if (!inserted) {
+  if (!inserted && !entry.is_poisoned) {
     context.DiagnoseDuplicateName(namespace_id, entry.inst_id);
     entry.access_kind = SemIR::AccessKind::Public;
   }

+ 97 - 15
toolchain/check/testdata/function/declaration/no_prelude/name_poisoning.carbon

@@ -293,21 +293,49 @@ class X {
   }
 }
 
-// --- fail_no_poison_when_lookup_fails.carbon
+// --- fail_poison_when_lookup_fails.carbon
 
 library "[[@TEST_NAME]]";
 
 namespace N;
-// Here we fail to find C so we don't poison anything.
-// CHECK:STDERR: fail_no_poison_when_lookup_fails.carbon:[[@LINE+3]]:11: error: name `C` not found [NameNotFound]
+// CHECK:STDERR: fail_poison_when_lookup_fails.carbon:[[@LINE+5]]:11: error: name `C` not found [NameNotFound]
 // CHECK:STDERR: fn N.F(x: C);
 // CHECK:STDERR:           ^
+// CHECK:STDERR:
+// CHECK:STDERR: fail_poison_when_lookup_fails.carbon: error: name used before it was declared [NameUseBeforeDecl]
 fn N.F(x: C);
 
-// No failures below because nothing was poisoned.
+// TODO: We should ideally only produce one diagnostic here.
+// CHECK:STDERR: fail_poison_when_lookup_fails.carbon:[[@LINE+5]]:1: note: declared here [NameUseBeforeDeclNote]
+// CHECK:STDERR: class C {}
+// CHECK:STDERR: ^~~~~~~~~
+// CHECK:STDERR:
+// CHECK:STDERR: fail_poison_when_lookup_fails.carbon: error: name used before it was declared [NameUseBeforeDecl]
 class C {}
+// CHECK:STDERR: fail_poison_when_lookup_fails.carbon:[[@LINE+4]]:1: note: declared here [NameUseBeforeDeclNote]
+// CHECK:STDERR: class N.C {}
+// CHECK:STDERR: ^~~~~~~~~~~
+// CHECK:STDERR:
 class N.C {}
 
+// --- fail_poison_with_lexical_result.carbon
+// CHECK:STDERR: fail_poison_with_lexical_result.carbon: error: name used before it was declared [NameUseBeforeDecl]
+
+library "[[@TEST_NAME]]";
+
+fn F() {
+  class A {}
+
+  class B {
+    var v: A;
+
+    // CHECK:STDERR: fail_poison_with_lexical_result.carbon:[[@LINE+3]]:5: note: declared here [NameUseBeforeDeclNote]
+    // CHECK:STDERR:     class A {}
+    // CHECK:STDERR:     ^~~~~~~~~
+    class A {}
+  }
+}
+
 // CHECK:STDOUT: --- no_poison.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {
@@ -1158,25 +1186,23 @@ class N.C {}
 // CHECK:STDOUT:
 // CHECK:STDOUT: specific @B(constants.%Self) {}
 // CHECK:STDOUT:
-// CHECK:STDOUT: --- fail_no_poison_when_lookup_fails.carbon
+// CHECK:STDOUT: --- fail_poison_when_lookup_fails.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {
 // CHECK:STDOUT:   %F.type: type = fn_type @F [template]
 // CHECK:STDOUT:   %F: %F.type = struct_value () [template]
-// CHECK:STDOUT:   %C.f79: type = class_type @C.1 [template]
+// CHECK:STDOUT:   %.a95: type = class_type @.1 [template]
 // CHECK:STDOUT:   %empty_struct_type: type = struct_type {} [template]
 // CHECK:STDOUT:   %complete_type: <witness> = complete_type_witness %empty_struct_type [template]
-// CHECK:STDOUT:   %C.9f4: type = class_type @C.2 [template]
+// CHECK:STDOUT:   %.fb7: type = class_type @.2 [template]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: file {
 // CHECK:STDOUT:   package: <namespace> = namespace [template] {
 // CHECK:STDOUT:     .N = %N
-// CHECK:STDOUT:     .C = %C.decl.loc12
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %N: <namespace> = namespace [template] {
 // CHECK:STDOUT:     .F = %F.decl
-// CHECK:STDOUT:     .C = %C.decl.loc13
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %F.decl: %F.type = fn_decl @F [template = constants.%F] {
 // CHECK:STDOUT:     %x.patt: <error> = binding_pattern x
@@ -1186,25 +1212,81 @@ class N.C {}
 // CHECK:STDOUT:     %C.ref: <error> = name_ref C, <error> [template = <error>]
 // CHECK:STDOUT:     %x: <error> = bind_name x, %x.param
 // CHECK:STDOUT:   }
-// CHECK:STDOUT:   %C.decl.loc12: type = class_decl @C.1 [template = constants.%C.f79] {} {}
-// CHECK:STDOUT:   %C.decl.loc13: type = class_decl @C.2 [template = constants.%C.9f4] {} {}
+// CHECK:STDOUT:   %.decl.loc18: type = class_decl @.1 [template = constants.%.a95] {} {}
+// CHECK:STDOUT:   %.decl.loc23: type = class_decl @.2 [template = constants.%.fb7] {} {}
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: class @C.1 {
+// CHECK:STDOUT: class @.1 {
 // CHECK:STDOUT:   %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
 // CHECK:STDOUT:
 // CHECK:STDOUT: !members:
-// CHECK:STDOUT:   .Self = constants.%C.f79
+// CHECK:STDOUT:   .Self = constants.%.a95
 // CHECK:STDOUT:   complete_type_witness = %complete_type
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: class @C.2 {
+// CHECK:STDOUT: class @.2 {
 // CHECK:STDOUT:   %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
 // CHECK:STDOUT:
 // CHECK:STDOUT: !members:
-// CHECK:STDOUT:   .Self = constants.%C.9f4
+// CHECK:STDOUT:   .Self = constants.%.fb7
 // CHECK:STDOUT:   complete_type_witness = %complete_type
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: fn @F(%x.param_patt: <error>);
 // CHECK:STDOUT:
+// CHECK:STDOUT: --- fail_poison_with_lexical_result.carbon
+// CHECK:STDOUT:
+// CHECK:STDOUT: constants {
+// CHECK:STDOUT:   %F.type: type = fn_type @F [template]
+// CHECK:STDOUT:   %F: %F.type = struct_value () [template]
+// CHECK:STDOUT:   %A: type = class_type @A [template]
+// CHECK:STDOUT:   %empty_struct_type: type = struct_type {} [template]
+// CHECK:STDOUT:   %complete_type.357: <witness> = complete_type_witness %empty_struct_type [template]
+// CHECK:STDOUT:   %B: type = class_type @B [template]
+// CHECK:STDOUT:   %B.elem: type = unbound_element_type %B, %A [template]
+// CHECK:STDOUT:   %.96d: type = class_type @.1 [template]
+// CHECK:STDOUT:   %struct_type.v: type = struct_type {.v: %A} [template]
+// CHECK:STDOUT:   %complete_type.57e: <witness> = complete_type_witness %struct_type.v [template]
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: file {
+// CHECK:STDOUT:   package: <namespace> = namespace [template] {
+// CHECK:STDOUT:     .F = %F.decl
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %F.decl: %F.type = fn_decl @F [template = constants.%F] {} {}
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: class @A {
+// CHECK:STDOUT:   %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type.357]
+// CHECK:STDOUT:
+// CHECK:STDOUT: !members:
+// CHECK:STDOUT:   .Self = constants.%A
+// CHECK:STDOUT:   complete_type_witness = %complete_type
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: class @B {
+// CHECK:STDOUT:   %.loc9: %B.elem = field_decl v, element0 [template]
+// CHECK:STDOUT:   %.decl: type = class_decl @.1 [template = constants.%.96d] {} {}
+// CHECK:STDOUT:   %complete_type: <witness> = complete_type_witness %struct_type.v [template = constants.%complete_type.57e]
+// CHECK:STDOUT:
+// CHECK:STDOUT: !members:
+// CHECK:STDOUT:   .Self = constants.%B
+// CHECK:STDOUT:   .v = %.loc9
+// CHECK:STDOUT:   complete_type_witness = %complete_type
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: class @.1 {
+// CHECK:STDOUT:   %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type.357]
+// CHECK:STDOUT:
+// CHECK:STDOUT: !members:
+// CHECK:STDOUT:   .Self = constants.%.96d
+// CHECK:STDOUT:   complete_type_witness = %complete_type
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @F() {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   %A.decl: type = class_decl @A [template = constants.%A] {} {}
+// CHECK:STDOUT:   %B.decl: type = class_decl @B [template = constants.%B] {} {}
+// CHECK:STDOUT:   return
+// CHECK:STDOUT: }
+// CHECK:STDOUT:

+ 12 - 8
toolchain/check/testdata/interface/no_prelude/import_access.carbon

@@ -101,12 +101,17 @@ fn F(i: Test.ForwardWithDef) {}
 
 impl package Test library "[[@TEST_NAME]]";
 
-// CHECK:STDERR: fail_todo_forward.impl.carbon:[[@LINE+4]]:9: error: name `Forward` not found [NameNotFound]
+// CHECK:STDERR: fail_todo_forward.impl.carbon:[[@LINE+5]]:9: error: name `Forward` not found [NameNotFound]
 // CHECK:STDERR: fn F(i: Forward*) {}
 // CHECK:STDERR:         ^~~~~~~
 // CHECK:STDERR:
+// CHECK:STDERR: fail_todo_forward.impl.carbon: error: name used before it was declared [NameUseBeforeDecl]
 fn F(i: Forward*) {}
 
+// CHECK:STDERR: fail_todo_forward.impl.carbon:[[@LINE+4]]:1: note: declared here [NameUseBeforeDeclNote]
+// CHECK:STDERR: interface Forward {}
+// CHECK:STDERR: ^~~~~~~~~~~~~~~~~~~
+// CHECK:STDERR:
 interface Forward {}
 
 // --- fail_local_forward.carbon
@@ -406,14 +411,13 @@ private interface Redecl {}
 // CHECK:STDOUT: constants {
 // CHECK:STDOUT:   %F.type: type = fn_type @F [template]
 // CHECK:STDOUT:   %F: %F.type = struct_value () [template]
-// CHECK:STDOUT:   %Forward.type: type = facet_type <@Forward> [template]
-// CHECK:STDOUT:   %Self: %Forward.type = bind_symbolic_name Self, 0 [symbolic]
+// CHECK:STDOUT:   %.type: type = facet_type <@.1> [template]
+// CHECK:STDOUT:   %Self: %.type = bind_symbolic_name Self, 0 [symbolic]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: file {
 // CHECK:STDOUT:   package: <namespace> = namespace [template] {
 // CHECK:STDOUT:     .F = %F.decl
-// CHECK:STDOUT:     .Forward = %Forward.decl
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %Test.import = import Test
 // CHECK:STDOUT:   %default.import = import <invalid>
@@ -422,17 +426,17 @@ private interface Redecl {}
 // CHECK:STDOUT:     %i.param_patt: <error> = value_param_pattern %i.patt, runtime_param0
 // CHECK:STDOUT:   } {
 // CHECK:STDOUT:     %i.param: <error> = value_param runtime_param0
-// CHECK:STDOUT:     %.loc8: type = splice_block %ptr [template = <error>] {
+// CHECK:STDOUT:     %.loc9: type = splice_block %ptr [template = <error>] {
 // CHECK:STDOUT:       %Forward.ref: <error> = name_ref Forward, <error> [template = <error>]
 // CHECK:STDOUT:       %ptr: type = ptr_type <error> [template = <error>]
 // CHECK:STDOUT:     }
 // CHECK:STDOUT:     %i: <error> = bind_name i, %i.param
 // CHECK:STDOUT:   }
-// CHECK:STDOUT:   %Forward.decl: type = interface_decl @Forward [template = constants.%Forward.type] {} {}
+// CHECK:STDOUT:   %.decl: type = interface_decl @.1 [template = constants.%.type] {} {}
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: interface @Forward {
-// CHECK:STDOUT:   %Self: %Forward.type = bind_symbolic_name Self, 0 [symbolic = constants.%Self]
+// CHECK:STDOUT: interface @.1 {
+// CHECK:STDOUT:   %Self: %.type = bind_symbolic_name Self, 0 [symbolic = constants.%Self]
 // CHECK:STDOUT:
 // CHECK:STDOUT: !members:
 // CHECK:STDOUT:   .Self = %Self

+ 0 - 1
toolchain/check/testdata/namespace/merging_with_indirections.carbon

@@ -168,7 +168,6 @@ fn Run() {
 // CHECK:STDOUT:   }
 // CHECK:STDOUT:   %Other: <namespace> = namespace file.%Other.import, [template] {
 // CHECK:STDOUT:     .F = %import_ref.f04
-// CHECK:STDOUT:     .NS1 = %NS1
 // CHECK:STDOUT:     import Other//b
 // CHECK:STDOUT:     import Other//a
 // CHECK:STDOUT:   }

+ 17 - 11
toolchain/sem_ir/name_scope.cpp

@@ -33,8 +33,7 @@ auto NameScope::Print(llvm::raw_ostream& out) const -> void {
 }
 
 auto NameScope::AddRequired(Entry name_entry) -> void {
-  CARBON_CHECK(!name_entry.is_poisoned,
-               "Cannot add a poisoned name: {0}. Use AddPoison()",
+  CARBON_CHECK(!name_entry.is_poisoned, "Cannot add a poisoned name: {0}.",
                name_entry.name_id);
   auto add_name = [&] {
     EntryId index(names_.size());
@@ -42,8 +41,13 @@ auto NameScope::AddRequired(Entry name_entry) -> void {
     return index;
   };
   auto result = name_map_.Insert(name_entry.name_id, add_name);
-  CARBON_CHECK(result.is_inserted(), "Failed to add required name: {0}",
-               name_entry.name_id);
+  if (!result.is_inserted()) {
+    // A required name can overwrite poison.
+    auto& name = names_[result.value().index];
+    CARBON_CHECK(name.is_poisoned, "Failed to add required name: {0}",
+                 name_entry.name_id);
+    name = name_entry;
+  }
 }
 
 auto NameScope::LookupOrAdd(SemIR::NameId name_id, InstId inst_id,
@@ -59,14 +63,16 @@ auto NameScope::LookupOrAdd(SemIR::NameId name_id, InstId inst_id,
   return {true, EntryId(names_.size() - 1)};
 }
 
-auto NameScope::AddPoison(NameId name_id) -> void {
+auto NameScope::LookupOrPoison(NameId name_id) -> std::optional<EntryId> {
   auto insert_result = name_map_.Insert(name_id, EntryId(names_.size()));
-  CARBON_CHECK(insert_result.is_inserted(),
-               "Trying to poison an existing name: {0}", name_id);
-  names_.push_back({.name_id = name_id,
-                    .inst_id = InstId::Invalid,
-                    .access_kind = AccessKind::Public,
-                    .is_poisoned = true});
+  if (insert_result.is_inserted()) {
+    names_.push_back({.name_id = name_id,
+                      .inst_id = InstId::Invalid,
+                      .access_kind = AccessKind::Public,
+                      .is_poisoned = true});
+    return std::nullopt;
+  }
+  return insert_result.value();
 }
 
 auto NameScopeStore::GetInstIfValid(NameScopeId scope_id) const

+ 10 - 7
toolchain/sem_ir/name_scope.h

@@ -62,19 +62,22 @@ class NameScope : public Printable<NameScope> {
     return lookup.value();
   }
 
-  // Adds a new name known to not exist. Must not be poisoned.
+  // Adds a new name that is known to not exist. The new entry is not allowed to
+  // be poisoned. An existing poisoned entry can be overwritten.
   auto AddRequired(Entry name_entry) -> void;
 
-  // If the given name already exists, return true with the EntryId; the entry
-  // might be poisoned. Otherwise, adds the name using inst_id and access_kind
-  // and returns false with the new EntryId.
+  // Searches for the given name. If found, including if a poisoned entry is
+  // found, returns true with the existing EntryId. Otherwise, adds the name
+  // using inst_id and access_kind and returns false with the new EntryId.
   //
-  // This cannot be used to add poisoned entries; use AddPoison instead.
+  // This cannot be used to add poisoned entries; use LookupOrPoison instead.
   auto LookupOrAdd(SemIR::NameId name_id, InstId inst_id,
                    AccessKind access_kind) -> std::pair<bool, EntryId>;
 
-  // Adds a new poisoned name.
-  auto AddPoison(NameId name_id) -> void;
+  // Searches for the given name. If found, including if a poisoned entry is
+  // found, returns the corresponding EntryId. Otherwise, returns nullopt and
+  // poisons the name so it can't be declared later.
+  auto LookupOrPoison(NameId name_id) -> std::optional<EntryId>;
 
   auto extended_scopes() const -> llvm::ArrayRef<InstId> {
     return extended_scopes_;

+ 79 - 2
toolchain/sem_ir/name_scope_test.cpp

@@ -86,6 +86,49 @@ TEST(NameScope, Lookup) {
   EXPECT_EQ(lookup, std::nullopt);
 }
 
+TEST(NameScope, LookupOrPoison) {
+  int id = 0;
+
+  InstId scope_inst_id(++id);
+  NameId scope_name_id(++id);
+  NameScopeId parent_scope_id(++id);
+  NameScope name_scope(scope_inst_id, scope_name_id, parent_scope_id);
+
+  NameScope::Entry entry1 = {.name_id = NameId(++id),
+                             .inst_id = InstId(++id),
+                             .access_kind = AccessKind::Public};
+  name_scope.AddRequired(entry1);
+
+  NameScope::Entry entry2 = {.name_id = NameId(++id),
+                             .inst_id = InstId(++id),
+                             .access_kind = AccessKind::Protected};
+  name_scope.AddRequired(entry2);
+
+  NameScope::Entry entry3 = {.name_id = NameId(++id),
+                             .inst_id = InstId(++id),
+                             .access_kind = AccessKind::Private};
+  name_scope.AddRequired(entry3);
+
+  auto lookup = name_scope.LookupOrPoison(entry1.name_id);
+  ASSERT_NE(lookup, std::nullopt);
+  EXPECT_THAT(static_cast<NameScope&>(name_scope).GetEntry(*lookup),
+              NameScopeEntryEquals(entry1));
+  EXPECT_THAT(static_cast<const NameScope&>(name_scope).GetEntry(*lookup),
+              NameScopeEntryEquals(entry1));
+
+  lookup = name_scope.LookupOrPoison(entry2.name_id);
+  ASSERT_NE(lookup, std::nullopt);
+  EXPECT_THAT(name_scope.GetEntry(*lookup), NameScopeEntryEquals(entry2));
+
+  lookup = name_scope.LookupOrPoison(entry3.name_id);
+  ASSERT_NE(lookup, std::nullopt);
+  EXPECT_THAT(name_scope.GetEntry(*lookup), NameScopeEntryEquals(entry3));
+
+  NameId unknown_name_id(++id);
+  lookup = name_scope.LookupOrPoison(unknown_name_id);
+  EXPECT_EQ(lookup, std::nullopt);
+}
+
 TEST(NameScope, LookupOrAdd) {
   int id = 0;
 
@@ -155,7 +198,7 @@ TEST(NameScope, Poison) {
   NameScope name_scope(scope_inst_id, scope_name_id, parent_scope_id);
 
   NameId poison1(++id);
-  name_scope.AddPoison(poison1);
+  EXPECT_EQ(name_scope.LookupOrPoison(poison1), std::nullopt);
   EXPECT_THAT(name_scope.entries(),
               ElementsAre(NameScopeEntryEquals(
                   NameScope::Entry({.name_id = poison1,
@@ -164,7 +207,7 @@ TEST(NameScope, Poison) {
                                     .is_poisoned = true}))));
 
   NameId poison2(++id);
-  name_scope.AddPoison(poison2);
+  EXPECT_EQ(name_scope.LookupOrPoison(poison2), std::nullopt);
   EXPECT_THAT(name_scope.entries(),
               ElementsAre(NameScopeEntryEquals(NameScope::Entry(
                               {.name_id = poison1,
@@ -187,6 +230,40 @@ TEST(NameScope, Poison) {
                                              .is_poisoned = true})));
 }
 
+TEST(NameScope, AddRequiredAfterPoison) {
+  int id = 0;
+
+  InstId scope_inst_id(++id);
+  NameId scope_name_id(++id);
+  NameScopeId parent_scope_id(++id);
+  NameScope name_scope(scope_inst_id, scope_name_id, parent_scope_id);
+
+  NameId name_id(++id);
+  InstId inst_id(++id);
+
+  EXPECT_EQ(name_scope.LookupOrPoison(name_id), std::nullopt);
+  EXPECT_THAT(name_scope.entries(),
+              ElementsAre(NameScopeEntryEquals(
+                  NameScope::Entry({.name_id = name_id,
+                                    .inst_id = InstId::Invalid,
+                                    .access_kind = AccessKind::Public,
+                                    .is_poisoned = true}))));
+
+  NameScope::Entry entry = {.name_id = name_id,
+                            .inst_id = inst_id,
+                            .access_kind = AccessKind::Private};
+  name_scope.AddRequired(entry);
+
+  auto lookup = name_scope.LookupOrPoison(name_id);
+  ASSERT_NE(lookup, std::nullopt);
+  EXPECT_THAT(
+      name_scope.GetEntry(*lookup),
+      NameScopeEntryEquals(NameScope::Entry({.name_id = name_id,
+                                             .inst_id = inst_id,
+                                             .access_kind = AccessKind::Private,
+                                             .is_poisoned = false})));
+}
+
 TEST(NameScope, ExtendedScopes) {
   int id = 0;