Просмотр исходного кода

Explorer: move subtyping logic to TypeChecker (#2484)

Addresses comments from this discussion: https://github.com/carbon-language/carbon-lang/pull/2460#discussion_r1046444433

Features:
* Move subtyping logic from Interpreter to TypeChecker, exposing subtyping as a series of accesses to `.base`.
    * Excludes function parameter conversion, which is still done in `Interpreter::Convert` because parameter conversion is handled differently.

Changes:
* Add new `class BaseAccessExpression : public MemberAccessExpression`, allowing rewrites
* Handle `BaseAccessExpression` expression type in Interpreter
* Move subtyping logic to `TypeChecker::ImplicitlyConvert`
Adrien Leravat 3 лет назад
Родитель
Commit
026c4b9dc3

+ 1 - 0
explorer/ast/ast_rtti.txt

@@ -61,6 +61,7 @@ abstract class Expression : AstNode;
   abstract class MemberAccessExpression : Expression;
     class SimpleMemberAccessExpression : MemberAccessExpression;
     class CompoundMemberAccessExpression : MemberAccessExpression;
+    class BaseAccessExpression : MemberAccessExpression;
   class IndexExpression : Expression;
   class IntTypeLiteral : Expression;
   class ContinuationTypeLiteral : Expression;

+ 6 - 0
explorer/ast/expression.cpp

@@ -181,6 +181,11 @@ void Expression::Print(llvm::raw_ostream& out) const {
       out << access.object() << ".(" << access.path() << ")";
       break;
     }
+    case ExpressionKind::BaseAccessExpression: {
+      const auto& access = cast<BaseAccessExpression>(*this);
+      out << access.object() << ".base";
+      break;
+    }
     case ExpressionKind::TupleLiteral: {
       out << "(";
       llvm::ListSeparator sep;
@@ -336,6 +341,7 @@ void Expression::PrintID(llvm::raw_ostream& out) const {
     case ExpressionKind::IndexExpression:
     case ExpressionKind::SimpleMemberAccessExpression:
     case ExpressionKind::CompoundMemberAccessExpression:
+    case ExpressionKind::BaseAccessExpression:
     case ExpressionKind::IfExpression:
     case ExpressionKind::WhereExpression:
     case ExpressionKind::BuiltinConvertExpression:

+ 36 - 0
explorer/ast/expression.h

@@ -419,6 +419,28 @@ class IndexExpression : public Expression {
   Nonnull<Expression*> offset_;
 };
 
+// An access to the base-class element of `object`, printed as `object.base`.
+// This node does not correspond to source syntax: name resolution and
+// `TypeCheckExp` both treat encountering it as a fatal error, so it is only
+// expected to be created as a rewrite during type checking.
+class BaseAccessExpression : public MemberAccessExpression {
+ public:
+  // Creates an access to `base` on `object`. The static type is taken from
+  // `base` and the value category is set to `Let` here in the constructor.
+  explicit BaseAccessExpression(SourceLocation source_loc,
+                                Nonnull<Expression*> object,
+                                Nonnull<const BaseElement*> base)
+      : MemberAccessExpression(AstNodeKind::BaseAccessExpression, source_loc,
+                               object),
+        base_(base) {
+    set_static_type(&base->type());
+    set_value_category(ValueCategory::Let);
+  }
+
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromBaseAccessExpression(node->kind());
+  }
+
+  // Returns the base element that this expression accesses.
+  auto element() const -> const BaseElement& { return *base_; }
+
+ private:
+  // The base element passed to the constructor.
+  const Nonnull<const BaseElement*> base_;
+};
+
 class IntLiteral : public Expression {
  public:
   explicit IntLiteral(SourceLocation source_loc, int value)
@@ -975,8 +997,22 @@ class BuiltinConvertExpression : public Expression {
     return source_expression_;
   }
 
+  // Sets the rewritten form of this expression. Can only be called during type
+  // checking, and at most once: a second call fails the CARBON_CHECK below.
+  auto set_rewritten_form(Nonnull<const Expression*> rewritten_form) -> void {
+    CARBON_CHECK(!rewritten_form_.has_value()) << "rewritten form set twice";
+    rewritten_form_ = rewritten_form;
+  }
+
+  // Returns the rewritten form of this expression, or `std::nullopt` if none
+  // has been set. A rewritten form can be used to prepare the conversion during
+  // type checking; when one is present, the interpreter evaluates it instead of
+  // performing the builtin conversion itself.
+  auto rewritten_form() const -> std::optional<Nonnull<const Expression*>> {
+    return rewritten_form_;
+  }
+
  private:
   Nonnull<Expression*> source_expression_;
+  std::optional<Nonnull<const Expression*>> rewritten_form_;
 };
 
 // An expression whose semantics have not been implemented. This can be used

+ 1 - 0
explorer/fuzzing/ast_to_proto.cpp

@@ -110,6 +110,7 @@ static auto ExpressionToProto(const Expression& expression)
     -> Fuzzing::Expression {
   Fuzzing::Expression expression_proto;
   switch (expression.kind()) {
+    case ExpressionKind::BaseAccessExpression:
     case ExpressionKind::ValueLiteral: {
       // This does not correspond to source syntax.
       break;

+ 1 - 0
explorer/interpreter/BUILD

@@ -220,6 +220,7 @@ cc_library(
         ":dictionary",
         ":interpreter",
         ":pattern_analysis",
+        "//common:check",
         "//common:error",
         "//common:ostream",
         "//explorer/ast",

+ 32 - 0
explorer/interpreter/interpreter.cpp

@@ -453,6 +453,18 @@ auto Interpreter::StepLvalue() -> ErrorOr<Success> {
         return todo_.FinishAction(arena_->New<LValue>(field));
       }
     }
+    case ExpressionKind::BaseAccessExpression: {
+      const auto& access = cast<BaseAccessExpression>(exp);
+      if (act.pos() == 0) {
+        // Get LValue for expression.
+        return todo_.Spawn(std::make_unique<LValAction>(&access.object()));
+      } else {
+        // Append `.base` element to the address, and return the new LValue.
+        Address object = cast<LValue>(*act.results()[0]).address();
+        Address base = object.ElementAddress(&access.element());
+        return todo_.FinishAction(arena_->New<LValue>(base));
+      }
+    }
     case ExpressionKind::IndexExpression: {
       if (act.pos() == 0) {
         //    { {e[i] :: C, E, F} :: S, H}
@@ -884,6 +896,8 @@ auto Interpreter::Convert(Nonnull<const Value*> value,
       CARBON_CHECK(pointee->kind() == Value::Kind::NominalClassValue)
           << "Unexpected pointer type";
 
+      // Conversion logic for subtyping for function arguments only.
+      // TODO: Drop when able to rewrite subtyping in TypeChecker for arguments.
       const auto* dest_ptr = cast<PointerType>(destination_type);
       std::optional<Nonnull<const NominalClassValue*>> class_subobj =
           cast<NominalClassValue>(pointee);
@@ -1268,6 +1282,21 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
         }
       }
     }
+    case ExpressionKind::BaseAccessExpression: {
+      const auto& access = cast<BaseAccessExpression>(exp);
+      if (act.pos() == 0) {
+        return todo_.Spawn(
+            std::make_unique<ExpressionAction>(&access.object()));
+      } else {
+        ElementPath::Component base_elt(&access.element(), std::nullopt,
+                                        std::nullopt);
+        const Value* value = act.results()[0];
+        CARBON_ASSIGN_OR_RETURN(Nonnull<const Value*> base_value,
+                                value->GetElement(arena_, ElementPath(base_elt),
+                                                  exp.source_loc(), value));
+        return todo_.FinishAction(base_value);
+      }
+    }
     case ExpressionKind::IdentifierExpression: {
       CARBON_CHECK(act.pos() == 0);
       const auto& ident = cast<IdentifierExpression>(exp);
@@ -1578,6 +1607,9 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
     }
     case ExpressionKind::BuiltinConvertExpression: {
       const auto& convert_expr = cast<BuiltinConvertExpression>(exp);
+      if (auto rewrite = convert_expr.rewritten_form()) {
+        return todo_.ReplaceWith(std::make_unique<ExpressionAction>(*rewrite));
+      }
       if (act.pos() == 0) {
         return todo_.Spawn(std::make_unique<ExpressionAction>(
             convert_expr.source_expression()));

+ 1 - 0
explorer/interpreter/resolve_names.cpp

@@ -275,6 +275,7 @@ static auto ResolveNames(Expression& expression,
       break;
     case ExpressionKind::ValueLiteral:
     case ExpressionKind::BuiltinConvertExpression:
+    case ExpressionKind::BaseAccessExpression:
       CARBON_FATAL() << "should not exist before type checking";
     case ExpressionKind::UnimplementedExpression:
       return ProgramError(expression.source_loc()) << "Unimplemented";

+ 1 - 0
explorer/interpreter/resolve_unformed.cpp

@@ -138,6 +138,7 @@ static auto ResolveUnformed(Nonnull<const Expression*> expression,
     case ExpressionKind::ValueLiteral:
     case ExpressionKind::IndexExpression:
     case ExpressionKind::CompoundMemberAccessExpression:
+    case ExpressionKind::BaseAccessExpression:
     case ExpressionKind::IfExpression:
     case ExpressionKind::WhereExpression:
     case ExpressionKind::StructTypeLiteral:

+ 43 - 1
explorer/interpreter/type_checker.cpp

@@ -16,6 +16,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "common/check.h"
 #include "common/error.h"
 #include "common/ostream.h"
 #include "explorer/ast/declaration.h"
@@ -610,6 +611,32 @@ auto TypeChecker::IsImplicitlyConvertible(
          impl_scope.Resolve(*iface_type, source, source_loc, *this).ok();
 }
 
+// Builds the expression converting `source`, whose type is `src_ptr`, to the
+// base-class pointer type `dest_ptr`, by wrapping it in one
+// `BaseAccessExpression` per inheritance level between the two pointee
+// classes. Both pointees must be nominal class types, and the destination
+// class must be reachable through the source class's chain of bases;
+// violations are internal errors reported with CARBON_CHECK rather than
+// through the `ErrorOr` result. If the two classes are already equal, `source`
+// is returned unchanged.
+auto TypeChecker::BuildSubtypeConversion(Nonnull<Expression*> source,
+                                         Nonnull<const PointerType*> src_ptr,
+                                         Nonnull<const PointerType*> dest_ptr)
+    -> ErrorOr<Nonnull<const Expression*>> {
+  const auto* src_class = dyn_cast<NominalClassType>(&src_ptr->pointee_type());
+  const auto* dest_class =
+      dyn_cast<NominalClassType>(&dest_ptr->pointee_type());
+  // Validate the casts before anything dereferences their results, so that a
+  // bad pointee produces this diagnostic instead of a null dereference.
+  CARBON_CHECK(src_class && dest_class)
+      << "Invalid source or destination pointee";
+
+  // Walk up the inheritance chain from the source class, adding one `.base`
+  // access per level until the destination class is reached.
+  Nonnull<Expression*> last_expr = source;
+  const auto* cur_class = src_class;
+  while (!TypeEqual(cur_class, dest_class, std::nullopt)) {
+    const auto base_class = cur_class->base();
+    // Running out of bases means `dest_class` is not an ancestor.
+    CARBON_CHECK(base_class) << "Invalid subtyping conversion";
+    auto* base_expr = arena_->New<BaseAccessExpression>(
+        source->source_loc(), last_expr,
+        arena_->New<BaseElement>(arena_->New<PointerType>(*base_class)));
+    last_expr = base_expr;
+    cur_class = *base_class;
+  }
+  // NOTE(review): `last_expr` starts as `source` and is only ever reassigned
+  // to a freshly allocated node, so this check cannot fire. If the intent is
+  // to reject a no-op conversion, it should compare against `source` instead
+  // -- confirm with callers before tightening.
+  CARBON_CHECK(last_expr) << "Error, no conversion was needed";
+  return last_expr;
+}
+
 auto TypeChecker::ImplicitlyConvert(std::string_view context,
                                     const ImplScope& impl_scope,
                                     Nonnull<Expression*> source,
@@ -671,7 +698,21 @@ auto TypeChecker::ImplicitlyConvert(std::string_view context,
     }
 
     // Perform the builtin conversion.
-    return arena_->New<BuiltinConvertExpression>(source, destination);
+    auto* convert_expr =
+        arena_->New<BuiltinConvertExpression>(source, destination);
+
+    // For subtyping, rewrite into successive `.base` accesses.
+    if (isa<PointerType>(source_type) && isa<PointerType>(destination) &&
+        cast<PointerType>(destination)->pointee_type().kind() ==
+            Value::Kind::NominalClassType) {
+      CARBON_ASSIGN_OR_RETURN(
+          const auto* rewrite,
+          BuildSubtypeConversion(source, cast<PointerType>(source_type),
+                                 cast<PointerType>(destination)))
+      convert_expr->set_rewritten_form(rewrite);
+    }
+
+    return convert_expr;
   }
 
   ErrorOr<Nonnull<Expression*>> converted = BuildBuiltinMethodCall(
@@ -2385,6 +2426,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
   switch (e->kind()) {
     case ExpressionKind::ValueLiteral:
     case ExpressionKind::BuiltinConvertExpression:
+    case ExpressionKind::BaseAccessExpression:
       CARBON_FATAL() << "attempting to type check node " << *e
                      << " generated during type checking";
     case ExpressionKind::IndexExpression: {

+ 7 - 0
explorer/interpreter/type_checker.h

@@ -383,6 +383,13 @@ class TypeChecker {
   auto ExpectNonPlaceholderType(SourceLocation source_loc,
                                 Nonnull<const Value*> type) -> ErrorOr<Success>;
 
+  // Build and return class subtyping conversion expression, converting from
+  // `src_ptr` to `dest_ptr`.
+  auto BuildSubtypeConversion(Nonnull<Expression*> source,
+                              Nonnull<const PointerType*> src_ptr,
+                              Nonnull<const PointerType*> dest_ptr)
+      -> ErrorOr<Nonnull<const Expression*>>;
+
   // Determine whether `type1` and `type2` are considered to be the same type
   // in the given scope. This is true if they're structurally identical or if
   // there is an equality relation in scope that specifies that they are the

+ 10 - 4
explorer/interpreter/value.cpp

@@ -199,10 +199,16 @@ static auto GetElement(Nonnull<Arena*> arena, Nonnull<const Value*> v,
       }
     }
     case ElementKind::BaseElement:
-      if (const auto* class_value = dyn_cast<NominalClassValue>(v)) {
-        return GetBaseElement(class_value, source_loc);
-      } else {
-        CARBON_FATAL() << "Invalid value for base element";
+      switch (v->kind()) {
+        case Value::Kind::NominalClassValue:
+          return GetBaseElement(cast<NominalClassValue>(v), source_loc);
+        case Value::Kind::PointerValue: {
+          const auto* ptr = cast<PointerValue>(v);
+          return arena->New<PointerValue>(
+              ptr->address().ElementAddress(path_comp.element()));
+        }
+        default:
+          CARBON_FATAL() << "Invalid value for base element";
       }
   }
 }

+ 44 - 0
explorer/testdata/class/class_subtyping_argument.carbon

@@ -0,0 +1,44 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// AUTOUPDATE
+// RUN: %{explorer-run}
+// RUN: %{explorer-run-trace}
+// CHECK:STDOUT: Foo(c1): 1
+// CHECK:STDOUT: Foo(c2): 1
+// CHECK:STDOUT: Foo(d): 1
+// CHECK:STDOUT: Foo(&e): 1
+// CHECK:STDOUT: result: 0
+
+package ExplorerTest api;
+
+base class C {
+  var val: i32;
+}
+
+base class D extends C {
+  var val: i32;
+}
+
+class E extends D {
+  var val: i32;
+}
+
+fn Foo(c: C*) -> i32 {
+  return (*c).val;
+}
+
+fn Main() -> i32 {
+  // `E` extends `D`, which extends `C`; the nested `.base` initializers set
+  // `val` to 3, 2 and 1 for the `E`, `D` and `C` parts respectively.
+  var e: E = { .val = 3, .base = {.val = 2,.base = {.val = 1}}};
+  // Initializations that convert a derived-class pointer to a base pointer.
+  var d: D* = &e;
+  var c1: C* = &e;
+  var c2: C* = d;
+
+  // `Foo` takes a `C*` and returns its `val`. `c1` and `c2` already have that
+  // type, while `d` and `&e` have to be converted at the call site. Every call
+  // is expected to print 1, the `val` of the `C` part.
+  Print("Foo(c1): {0}", Foo(c1));
+  Print("Foo(c2): {0}", Foo(c2));
+  Print("Foo(d): {0}", Foo(d));
+  Print("Foo(&e): {0}", Foo(&e));
+
+  return 0;
+}

+ 0 - 0
explorer/testdata/class/class_subtyping.carbon → explorer/testdata/class/class_subtyping_basic.carbon


+ 43 - 0
explorer/testdata/class/class_subtyping_multiple.carbon

@@ -0,0 +1,43 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// AUTOUPDATE
+// RUN: %{explorer-run}
+// RUN: %{explorer-run-trace}
+// CHECK:STDOUT: (*c1).val: 1
+// CHECK:STDOUT: (*c2).val: 1
+// CHECK:STDOUT: (*d).val: 2
+// CHECK:STDOUT: e.val: 3
+// CHECK:STDOUT: result: 0
+
+package ExplorerTest api;
+
+base class C {
+  var val: i32;
+}
+
+base class D extends C {
+  var val: i32;
+}
+
+class E extends D {
+  var val: i32;
+}
+
+// NOTE(review): `Foo` is not called anywhere in this file; it looks like a
+// leftover from class_subtyping_argument.carbon -- confirm whether it can be
+// removed.
+fn Foo(c: C*) -> i32 {
+  return (*c).val;
+}
+
+fn Main() -> i32 {
+  // `E` extends `D`, which extends `C`; the nested `.base` initializers set
+  // `val` to 3, 2 and 1 for the `E`, `D` and `C` parts respectively.
+  var e: E = { .val = 3, .base = {.val = 2,.base = {.val = 1}}};
+  // Conversions across one level (`E*` to `D*`, `D*` to `C*`) and across two
+  // levels at once (`E*` to `C*`).
+  var d: D* = &e;
+  var c1: C* = &e;
+  var c2: C* = d;
+  // Each access is expected to read the `val` declared by the class the
+  // pointer is typed as: 1 through a `C*`, 2 through a `D*`, 3 on `e` itself.
+  Print("(*c1).val: {0}", (*c1).val);
+  Print("(*c2).val: {0}", (*c2).val);
+  Print("(*d).val: {0}", (*d).val);
+  Print("e.val: {0}", e.val);
+
+  return 0;
+}

+ 0 - 0
explorer/testdata/class/abstract_class_subtyping.carbon → explorer/testdata/class/non_virtual_dispatch_abstract.carbon