Browse Source

Fix crash discovered by fuzzer. (#5100)

If the vtable of a base class is erroneous, generate an erroneous vtable
for the derived class too.
Richard Smith 1 year ago
parent
commit
59003a5d4c

+ 58 - 47
toolchain/check/handle_class.cpp

@@ -666,6 +666,62 @@ static auto AddStructTypeFields(
   return fields_id;
 }
 
+// Builds and returns a vtable for the current class. Assumes that the virtual
+// functions for the class are listed as the top element of the `vtable_stack`.
+static auto BuildVtable(Context& context, Parse::NodeId node_id,
+                        SemIR::InstId base_vtable_id) -> SemIR::InstId {
+  llvm::SmallVector<SemIR::InstId> vtable;
+  if (base_vtable_id.has_value()) {
+    LoadImportRef(context, base_vtable_id);
+    auto canonical_base_vtable_id =
+        context.constant_values().GetConstantInstId(base_vtable_id);
+    if (canonical_base_vtable_id == SemIR::ErrorInst::SingletonInstId) {
+      return SemIR::ErrorInst::SingletonInstId;
+    }
+    auto base_vtable_inst_block = context.inst_blocks().Get(
+        context.insts()
+            .GetAs<SemIR::Vtable>(canonical_base_vtable_id)
+            .virtual_functions_id);
+    // TODO: Avoid quadratic search. Perhaps build a map from `NameId` to the
+    // elements of the top of `vtable_stack`.
+    for (auto fn_decl_id : base_vtable_inst_block) {
+      auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
+      const auto& fn = context.functions().Get(fn_decl.function_id);
+      for (auto override_fn_decl_id :
+           context.vtable_stack().PeekCurrentBlockContents()) {
+        auto override_fn_decl =
+            context.insts().GetAs<SemIR::FunctionDecl>(override_fn_decl_id);
+        const auto& override_fn =
+            context.functions().Get(override_fn_decl.function_id);
+        if (override_fn.virtual_modifier ==
+                SemIR::FunctionFields::VirtualModifier::Impl &&
+            override_fn.name_id == fn.name_id) {
+          // TODO: Support generic base classes, rather than passing
+          // `SpecificId::None`.
+          CheckFunctionTypeMatches(context, override_fn, fn,
+                                   SemIR::SpecificId::None,
+                                   /*check_syntax=*/false,
+                                   /*check_self=*/false);
+          fn_decl_id = override_fn_decl_id;
+        }
+      }
+      vtable.push_back(fn_decl_id);
+    }
+  }
+
+  for (auto inst_id : context.vtable_stack().PeekCurrentBlockContents()) {
+    auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
+    const auto& fn = context.functions().Get(fn_decl.function_id);
+    if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
+      vtable.push_back(inst_id);
+    }
+  }
+  return AddInst<SemIR::Vtable>(
+      context, node_id,
+      {.type_id = GetSingletonType(context, SemIR::VtableType::SingletonInstId),
+       .virtual_functions_id = context.inst_blocks().Add(vtable)});
+}
+
 // Checks that the specified finished class definition is valid and builds and
 // returns a corresponding complete type witness instruction.
 static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
@@ -709,54 +765,9 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
   }
 
   if (class_info.is_dynamic) {
-    llvm::SmallVector<SemIR::InstId> vtable;
-    if (!defining_vptr) {
-      LoadImportRef(context, base_class_info->vtable_id);
-      auto base_vtable_id = context.constant_values().GetConstantInstId(
-          base_class_info->vtable_id);
-      auto base_vtable_inst_block =
-          context.inst_blocks().Get(context.insts()
-                                        .GetAs<SemIR::Vtable>(base_vtable_id)
-                                        .virtual_functions_id);
-      // TODO: Avoid quadratic search. Perhaps build a map from `NameId` to the
-      // elements of the top of `vtable_stack`.
-      for (auto fn_decl_id : base_vtable_inst_block) {
-        auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
-        const auto& fn = context.functions().Get(fn_decl.function_id);
-        for (auto override_fn_decl_id :
-             context.vtable_stack().PeekCurrentBlockContents()) {
-          auto override_fn_decl =
-              context.insts().GetAs<SemIR::FunctionDecl>(override_fn_decl_id);
-          const auto& override_fn =
-              context.functions().Get(override_fn_decl.function_id);
-          if (override_fn.virtual_modifier ==
-                  SemIR::FunctionFields::VirtualModifier::Impl &&
-              override_fn.name_id == fn.name_id) {
-            // TODO: Support generic base classes, rather than passing
-            // `SpecificId::None`.
-            CheckFunctionTypeMatches(context, override_fn, fn,
-                                     SemIR::SpecificId::None,
-                                     /*check_syntax=*/false,
-                                     /*check_self=*/false);
-            fn_decl_id = override_fn_decl_id;
-          }
-        }
-        vtable.push_back(fn_decl_id);
-      }
-    }
-
-    for (auto inst_id : context.vtable_stack().PeekCurrentBlockContents()) {
-      auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
-      const auto& fn = context.functions().Get(fn_decl.function_id);
-      if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
-        vtable.push_back(inst_id);
-      }
-    }
-    class_info.vtable_id = AddInst<SemIR::Vtable>(
+    class_info.vtable_id = BuildVtable(
         context, node_id,
-        {.type_id =
-             GetSingletonType(context, SemIR::VtableType::SingletonInstId),
-         .virtual_functions_id = context.inst_blocks().Add(vtable)});
+        defining_vptr ? SemIR::InstId::None : base_class_info->vtable_id);
   }
 
   return AddInst<SemIR::CompleteTypeWitness>(

+ 127 - 0
toolchain/check/testdata/class/no_prelude/fail_error_recovery.carbon

@@ -0,0 +1,127 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/class/no_prelude/fail_error_recovery.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/class/no_prelude/fail_error_recovery.carbon
+
+// --- fail_virtual_fn_in_invalid_context.carbon
+
+// CHECK:STDERR: fail_virtual_fn_in_invalid_context.carbon:[[@LINE+4]]:10: error: name `error_not_found` not found [NameNotFound]
+// CHECK:STDERR: fn F(N:! error_not_found) {
+// CHECK:STDERR:          ^~~~~~~~~~~~~~~
+// CHECK:STDERR:
+fn F(N:! error_not_found) {
+  base class C {
+    virtual fn Foo[self: Self]() {}
+  }
+
+  base class D {
+    extend base: C;
+  }
+}
+
+// CHECK:STDOUT: --- fail_virtual_fn_in_invalid_context.carbon
+// CHECK:STDOUT:
+// CHECK:STDOUT: constants {
+// CHECK:STDOUT:   %N.patt: <error> = symbolic_binding_pattern N, 0 [symbolic]
+// CHECK:STDOUT:   %F.type: type = fn_type @F [concrete]
+// CHECK:STDOUT:   %F: %F.type = struct_value () [concrete]
+// CHECK:STDOUT:   %C: type = class_type @C [concrete]
+// CHECK:STDOUT:   %ptr.454: type = ptr_type <vtable> [concrete]
+// CHECK:STDOUT:   %struct_type.vptr: type = struct_type {.<vptr>: %ptr.454} [concrete]
+// CHECK:STDOUT:   %complete_type.513: <witness> = complete_type_witness %struct_type.vptr [concrete]
+// CHECK:STDOUT:   %D: type = class_type @D [concrete]
+// CHECK:STDOUT:   %struct_type.base: type = struct_type {.base: %C} [concrete]
+// CHECK:STDOUT:   %complete_type.1d0: <witness> = complete_type_witness %struct_type.base [concrete]
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: file {
+// CHECK:STDOUT:   package: <namespace> = namespace [concrete] {
+// CHECK:STDOUT:     .error_not_found = <poisoned>
+// CHECK:STDOUT:     .F = %F.decl
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %F.decl: %F.type = fn_decl @F [concrete = constants.%F] {
+// CHECK:STDOUT:     %N.patt.loc6_6.1: <error> = symbolic_binding_pattern N, 0 [symbolic = %N.patt.loc6_6.2 (constants.%N.patt)]
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     %error_not_found.ref: <error> = name_ref error_not_found, <error> [concrete = <error>]
+// CHECK:STDOUT:     %N: <error> = bind_symbolic_name N, 0 [concrete = <error>]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: generic class @C(@F.%N: <error>) {
+// CHECK:STDOUT: !definition:
+// CHECK:STDOUT:
+// CHECK:STDOUT:   class {
+// CHECK:STDOUT:     %Foo.decl: <error> = fn_decl @Foo [concrete = <error>] {
+// CHECK:STDOUT:       %self.patt: <error> = binding_pattern self
+// CHECK:STDOUT:       %self.param_patt: <error> = value_param_pattern %self.patt, call_param0 [concrete = <error>]
+// CHECK:STDOUT:     } {
+// CHECK:STDOUT:       %self.param: <error> = value_param call_param0
+// CHECK:STDOUT:       %Self.ref: <error> = name_ref Self, <error> [concrete = <error>]
+// CHECK:STDOUT:       %self: <error> = bind_name self, %self.param
+// CHECK:STDOUT:     }
+// CHECK:STDOUT:     %.loc9: <vtable> = vtable (%Foo.decl) [concrete = <error>]
+// CHECK:STDOUT:     %complete_type: <witness> = complete_type_witness %struct_type.vptr [concrete = constants.%complete_type.513]
+// CHECK:STDOUT:     complete_type_witness = %complete_type
+// CHECK:STDOUT:
+// CHECK:STDOUT:   !members:
+// CHECK:STDOUT:     .Self = <error>
+// CHECK:STDOUT:     .Foo = %Foo.decl
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: generic class @D(@F.%N: <error>) {
+// CHECK:STDOUT: !definition:
+// CHECK:STDOUT:
+// CHECK:STDOUT:   class {
+// CHECK:STDOUT:     %C.ref: type = name_ref C, @F.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:     %.loc12: <error> = base_decl %C.ref, element0 [concrete]
+// CHECK:STDOUT:     %complete_type: <witness> = complete_type_witness %struct_type.base [concrete = constants.%complete_type.1d0]
+// CHECK:STDOUT:     complete_type_witness = %complete_type
+// CHECK:STDOUT:
+// CHECK:STDOUT:   !members:
+// CHECK:STDOUT:     .Self = <error>
+// CHECK:STDOUT:     .C = <poisoned>
+// CHECK:STDOUT:     .base = %.loc12
+// CHECK:STDOUT:     extend %C.ref
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: generic fn @F(%N: <error>) {
+// CHECK:STDOUT:   %N.patt.loc6_6.2: <error> = symbolic_binding_pattern N, 0 [symbolic = %N.patt.loc6_6.2 (constants.%N.patt)]
+// CHECK:STDOUT:
+// CHECK:STDOUT: !definition:
+// CHECK:STDOUT:
+// CHECK:STDOUT:   fn(%N.patt.loc6_6.1: <error>) {
+// CHECK:STDOUT:   !entry:
+// CHECK:STDOUT:     %C.decl: type = class_decl @C [concrete = constants.%C] {} {}
+// CHECK:STDOUT:     %D.decl: type = class_decl @D [concrete = constants.%D] {} {}
+// CHECK:STDOUT:     return
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: generic virtual fn @Foo(@F.%N: <error>) {
+// CHECK:STDOUT: !definition:
+// CHECK:STDOUT:
+// CHECK:STDOUT:   virtual fn[%self.param_patt: <error>]() {
+// CHECK:STDOUT:   !entry:
+// CHECK:STDOUT:     return
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: specific @F(<error>) {
+// CHECK:STDOUT:   %N.patt.loc6_6.2 => <error>
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: specific @C(<error>) {
+// CHECK:STDOUT: !definition:
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: specific @Foo(<error>) {}
+// CHECK:STDOUT:
+// CHECK:STDOUT: specific @D(<error>) {}
+// CHECK:STDOUT: