Просмотр исходного кода

Support for user-defined Eq impls (#1730)

Related to #1549; user-defined equality is a prerequisite for user-defined comparison.
Thejaswi Kadur 3 лет назад
Родитель
Сommit
5665f32bf7

+ 2 - 0
common/fuzzing/carbon.proto

@@ -92,6 +92,8 @@ message IntrinsicExpression {
     Alloc = 2;
     Dealloc = 3;
     Rand = 4;
+    IntEq = 5;
+    StrEq = 6;
   }
   optional Intrinsic intrinsic = 1;
   optional TupleLiteralExpression argument = 2;

+ 6 - 0
common/fuzzing/proto_to_carbon.cpp

@@ -309,6 +309,12 @@ static auto ExpressionToCarbon(const Fuzzing::Expression& expression,
         case Fuzzing::IntrinsicExpression::Rand:
           out << "__intrinsic_rand";
           break;
+        case Fuzzing::IntrinsicExpression::IntEq:
+          out << "__intrinsic_int_eq";
+          break;
+        case Fuzzing::IntrinsicExpression::StrEq:
+          out << "__intrinsic_str_eq";
+          break;
       }
       TupleLiteralExpressionToCarbon(intrinsic.argument(), out);
     } break;

+ 8 - 1
explorer/ast/expression.cpp

@@ -30,7 +30,9 @@ auto IntrinsicExpression::FindIntrinsic(std::string_view name,
       {{"print", Intrinsic::Print},
        {"new", Intrinsic::Alloc},
        {"delete", Intrinsic::Dealloc},
-       {"rand", Intrinsic::Rand}});
+       {"rand", Intrinsic::Rand},
+       {"int_eq", Intrinsic::IntEq},
+       {"str_eq", Intrinsic::StrEq}});
   name.remove_prefix(std::strlen("__intrinsic_"));
   auto it = intrinsic_map.find(name);
   if (it == intrinsic_map.end()) {
@@ -191,6 +193,11 @@ void Expression::Print(llvm::raw_ostream& out) const {
         case IntrinsicExpression::Intrinsic::Rand:
           out << "rand";
           break;
+        case IntrinsicExpression::Intrinsic::IntEq:
+          out << "int_eq";
+          break;
+        case IntrinsicExpression::Intrinsic::StrEq:
+          out << "str_eq";
       }
       out << iexp.args();
       break;

+ 1 - 1
explorer/ast/expression.h

@@ -664,7 +664,7 @@ class ValueLiteral : public Expression {
 
 class IntrinsicExpression : public Expression {
  public:
-  enum class Intrinsic { Print, Alloc, Dealloc, Rand };
+  enum class Intrinsic { Print, Alloc, Dealloc, Rand, IntEq, StrEq };
 
   // Returns the enumerator corresponding to the intrinsic named `name`,
   // or raises a fatal compile error if there is no such enumerator.

+ 29 - 0
explorer/data/prelude.carbon

@@ -120,3 +120,32 @@ class Heap {
 }
 
 var heap: Heap = {};
+
+interface EqWith(U:! Type) {
+  fn Equal[me: Self](other: U) -> Bool;
+  // TODO: NotEqual with default impl
+}
+// TODO: constraint Eq { ... }
+
+
+// TODO: Simplify this once we have variadics
+impl forall [T2:! Type, U2:! Type, T1:! EqWith(T2), U1:! EqWith(U2)]
+    (T1, U1) as EqWith((T2, U2)) {
+  fn Equal[me: Self](other: (T2, U2)) -> Bool {
+    let (l1: T1, l2: U1) = me;
+    let (r1: T2, r2: U2) = other;
+    return l1 == r1 and l2 == r2;
+  }
+}
+
+impl i32 as EqWith(Self) {
+  fn Equal[me: Self](other: Self) -> Bool {
+    return __intrinsic_int_eq(me, other);
+  }
+}
+
+impl String as EqWith(Self) {
+  fn Equal[me: Self](other: Self) -> Bool {
+    return __intrinsic_str_eq(me, other);
+  }
+}

+ 6 - 0
explorer/fuzzing/ast_to_proto.cpp

@@ -242,6 +242,12 @@ static auto ExpressionToProto(const Expression& expression)
         case IntrinsicExpression::Intrinsic::Rand:
           intrinsic_proto->set_intrinsic(Fuzzing::IntrinsicExpression::Rand);
           break;
+        case IntrinsicExpression::Intrinsic::IntEq:
+          intrinsic_proto->set_intrinsic(Fuzzing::IntrinsicExpression::IntEq);
+          break;
+        case IntrinsicExpression::Intrinsic::StrEq:
+          intrinsic_proto->set_intrinsic(Fuzzing::IntrinsicExpression::StrEq);
+          break;
       }
       *intrinsic_proto->mutable_argument() =
           TupleLiteralExpressionToProto(intrinsic.args());

+ 11 - 4
explorer/interpreter/builtins.h

@@ -5,7 +5,9 @@
 #ifndef CARBON_EXPLORER_INTERPRETER_BUILTINS_H_
 #define CARBON_EXPLORER_INTERPRETER_BUILTINS_H_
 
+#include <array>
 #include <optional>
+#include <string_view>
 
 #include "common/error.h"
 #include "explorer/ast/declaration.h"
@@ -32,7 +34,10 @@ class Builtins {
     MulWith,
     ModWith,
 
-    Last = ModWith
+    // Comparison.
+    EqWith,
+
+    Last = EqWith
   };
   // TODO: In C++20, replace with `using enum Builtin;`.
   static constexpr Builtin As = Builtin::As;
@@ -42,6 +47,7 @@ class Builtins {
   static constexpr Builtin SubWith = Builtin::SubWith;
   static constexpr Builtin MulWith = Builtin::MulWith;
   static constexpr Builtin ModWith = Builtin::ModWith;
+  static constexpr Builtin EqWith = Builtin::EqWith;
 
   // Register a declaration that might be a builtin.
   void Register(Nonnull<const Declaration*> decl);
@@ -51,14 +57,15 @@ class Builtins {
       -> ErrorOr<Nonnull<const Declaration*>>;
 
   // Get the source name of a builtin.
-  static constexpr auto GetName(Builtin builtin) -> const char* {
+  static constexpr auto GetName(Builtin builtin) -> std::string_view {
     return BuiltinNames[static_cast<int>(builtin)];
   }
 
  private:
   static constexpr int NumBuiltins = static_cast<int>(Builtin::Last) + 1;
-  static constexpr const char* BuiltinNames[NumBuiltins] = {
-      "As", "ImplicitAs", "Negate", "AddWith", "SubWith", "MulWith", "ModWith"};
+  static constexpr std::array<std::string_view, NumBuiltins> BuiltinNames = {
+      "As",      "ImplicitAs", "Negate",  "AddWith",
+      "SubWith", "MulWith",    "ModWith", "EqWith"};
 
   std::optional<Nonnull<const Declaration*>> builtins_[NumBuiltins] = {};
 };

+ 19 - 3
explorer/interpreter/interpreter.cpp

@@ -203,8 +203,6 @@ auto Interpreter::EvalPrim(Operator op, Nonnull<const Value*> static_type,
     case Operator::Or:
       return arena_->New<BoolValue>(cast<BoolValue>(*args[0]).value() ||
                                     cast<BoolValue>(*args[1]).value());
-    case Operator::Eq:
-      return arena_->New<BoolValue>(ValueEqual(args[0], args[1], std::nullopt));
     case Operator::Ptr:
       return arena_->New<PointerType>(args[0]);
     case Operator::Deref:
@@ -214,7 +212,9 @@ auto Interpreter::EvalPrim(Operator op, Nonnull<const Value*> static_type,
     case Operator::Combine:
       return &cast<TypeOfConstraintType>(static_type)->constraint_type();
     case Operator::As:
-      return Convert(args[0], args[1], source_loc);
+    case Operator::Eq:
+      CARBON_FATAL() << "These operators should have been rewritten to "
+                        "interface method calls";
   }
 }
 
@@ -1211,6 +1211,22 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
           heap_.Deallocate(cast<PointerValue>(args.elements()[0])->address());
           return todo_.FinishAction(TupleValue::Empty());
         }
+        case IntrinsicExpression::Intrinsic::IntEq: {
+          const auto& args = cast<TupleValue>(*act.results()[0]).elements();
+          CARBON_CHECK(args.size() == 2);
+          auto lhs = cast<IntValue>(*args[0]).value();
+          auto rhs = cast<IntValue>(*args[1]).value();
+          auto result = arena_->New<BoolValue>(lhs == rhs);
+          return todo_.FinishAction(result);
+        }
+        case IntrinsicExpression::Intrinsic::StrEq: {
+          const auto& args = cast<TupleValue>(*act.results()[0]).elements();
+          CARBON_CHECK(args.size() == 2);
+          auto& lhs = cast<StringValue>(*args[0]).value();
+          auto& rhs = cast<StringValue>(*args[1]).value();
+          auto result = arena_->New<BoolValue>(lhs == rhs);
+          return todo_.FinishAction(result);
+        }
       }
     }
     case ExpressionKind::IntTypeLiteral: {

+ 43 - 5
explorer/interpreter/type_checker.cpp

@@ -2024,12 +2024,20 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
           op.set_static_type(arena_->New<BoolType>());
           op.set_value_category(ValueCategory::Let);
           return Success();
-        case Operator::Eq:
-          CARBON_RETURN_IF_ERROR(
-              ExpectExactType(e->source_loc(), "==", ts[0], ts[1], impl_scope));
-          op.set_static_type(arena_->New<BoolType>());
-          op.set_value_category(ValueCategory::Let);
+        case Operator::Eq: {
+          ErrorOr<Nonnull<Expression*>> converted = BuildBuiltinMethodCall(
+              impl_scope, op.arguments()[0],
+              BuiltinInterfaceName{Builtins::EqWith, ts[1]},
+              BuiltinMethodCall{"Equal", op.arguments()[1]});
+          if (!converted.ok()) {
+            // We couldn't find a matching `impl`.
+            return CompilationError(e->source_loc())
+                   << *ts[0] << " is not equality comparable with " << *ts[1]
+                   << " (" << converted.error().message() << ")";
+          }
+          op.set_rewritten_form(*converted);
           return Success();
+        }
         case Operator::Deref:
           CARBON_RETURN_IF_ERROR(
               ExpectPointerType(e->source_loc(), "*", ts[0]));
@@ -2261,6 +2269,36 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
           e->set_value_category(ValueCategory::Let);
           return Success();
         }
+        case IntrinsicExpression::Intrinsic::IntEq: {
+          if (args.size() != 2) {
+            return CompilationError(e->source_loc())
+                   << "__intrinsic_int_eq takes 2 arguments";
+          }
+          CARBON_RETURN_IF_ERROR(ExpectExactType(
+              e->source_loc(), "__intrinsic_int_eq argument 1",
+              arena_->New<IntType>(), &args[0]->static_type(), impl_scope));
+          CARBON_RETURN_IF_ERROR(ExpectExactType(
+              e->source_loc(), "__intrinsic_int_eq argument 2",
+              arena_->New<IntType>(), &args[1]->static_type(), impl_scope));
+          e->set_static_type(arena_->New<BoolType>());
+          e->set_value_category(ValueCategory::Let);
+          return Success();
+        }
+        case IntrinsicExpression::Intrinsic::StrEq: {
+          if (args.size() != 2) {
+            return CompilationError(e->source_loc())
+                   << "__intrinsic_str_eq takes 2 arguments";
+          }
+          CARBON_RETURN_IF_ERROR(ExpectExactType(
+              e->source_loc(), "__intrinsic_str_eq argument 1",
+              arena_->New<StringType>(), &args[0]->static_type(), impl_scope));
+          CARBON_RETURN_IF_ERROR(ExpectExactType(
+              e->source_loc(), "__intrinsic_str_eq argument 2",
+              arena_->New<StringType>(), &args[1]->static_type(), impl_scope));
+          e->set_static_type(arena_->New<BoolType>());
+          e->set_value_category(ValueCategory::Let);
+          return Success();
+        }
       }
     }
     case ExpressionKind::IntTypeLiteral:

+ 20 - 0
explorer/testdata/comparison/builtin_equality.carbon

@@ -0,0 +1,20 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+// CHECK: strings equal: 0
+// CHECK: ints equal: 1
+// CHECK: result: 0
+
+package ExplorerTest api;
+
+fn Main() -> i32 {
+  Print("strings equal: {0}", if "hello" == "world" then 1 else 0);
+  Print("ints equal: {0}", if 1 == 1 then 1 else 0);
+  return 0;
+}

+ 30 - 0
explorer/testdata/comparison/custom_equality.carbon

@@ -0,0 +1,30 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+// CHECK: structs equal: 0
+// CHECK: result: 0
+
+package ExplorerTest api;
+
+class MyType {
+  var value: i32;
+
+  impl as EqWith(Self) {
+    fn Equal[me: Self](other: Self) -> Bool {
+      return me.value == other.value;
+    }
+  }
+}
+
+fn Main() -> i32 {
+  let x: MyType = {.value = 1};
+  let y: MyType = {.value = 2};
+  Print("structs equal: {0}", if x == y then 1 else 0);
+  return 0;
+}

+ 25 - 0
explorer/testdata/comparison/fail_empty_struct.carbon

@@ -0,0 +1,25 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{not} %{explorer} %s 2>&1 2>&1 | %{FileCheck} %s
+// AUTOUPDATE: %{explorer} %s
+
+package ExplorerTest api;
+
+// TODO: This should work
+// CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/comparison/fail_empty_struct.carbon:[[@LINE+1]]: type error in call: '{}' is not implicitly convertible to 'Type'
+external impl {} as EqWith({}) {
+  fn Equal[me: Self](other: Self) -> Bool {
+    return true;
+  }
+}
+
+fn Main() -> i32 {
+  var empty: auto = {};
+  if (empty == {}) {
+    return 0;
+  } else {
+    return 1;
+  }
+}

+ 14 - 0
explorer/testdata/comparison/fail_no_impl.carbon

@@ -0,0 +1,14 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{not} %{explorer} %s 2>&1 2>&1 | %{FileCheck} %s
+// AUTOUPDATE: %{explorer} %s
+
+package ExplorerTest api;
+
+fn Main() -> i32 {
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/comparison/fail_no_impl.carbon:[[@LINE+1]]: i32 is not equality comparable with String (could not find implementation of interface EqWith(U = String) for i32)
+  Print("different types equal: {0}", if 1 == "1" then 1 else 0);
+  return 0;
+}

+ 0 - 3
explorer/testdata/struct/empty.carbon

@@ -14,9 +14,6 @@ package ExplorerTest api;
 fn Main() -> i32 {
   var empty: {} = {};
   empty = {};
-  if (not (empty == {})) {
-    return 1;
-  }
   match (empty) {
     case {} => {
       return 0;

+ 7 - 0
explorer/testdata/struct/equality.carbon

@@ -11,6 +11,13 @@
 
 package ExplorerTest api;
 
+// TODO: Implement this with some kind of reflection?
+external impl {.x: i32, .y: i32} as EqWith(Self) {
+  fn Equal[me: Self](other: Self) -> Bool {
+    return me.x == other.x and me.y == other.y;
+  }
+}
+
 fn Main() -> i32 {
   var t1: {.x: i32, .y: i32} = {.x = 5, .y = 2};
   var t2: {.x: i32, .y: i32} = {.x = 5, .y = 2};

+ 7 - 0
explorer/testdata/struct/equality_false.carbon

@@ -11,6 +11,13 @@
 
 package ExplorerTest api;
 
+// TODO: Implement this with some kind of reflection?
+external impl {.x: i32, .y: i32} as EqWith(Self) {
+  fn Equal[me: Self](other: Self) -> Bool {
+    return me.x == other.x and me.y == other.y;
+  }
+}
+
 fn Main() -> i32 {
   var t1: {.x: i32, .y: i32} = {.x = 5, .y = 2};
   var t2: {.x: i32, .y: i32} = {.x = 5, .y = 4};

+ 8 - 3
explorer/testdata/struct/fail_equality_type.carbon

@@ -7,12 +7,17 @@
 
 package ExplorerTest api;
 
+// TODO: Implement this with some kind of reflection?
+external impl {.x: i32, .y: i32} as EqWith(Self) {
+  fn Equal[me: Self](other: Self) -> Bool {
+    return me.x == other.x and me.y == other.y;
+  }
+}
+
 fn Main() -> i32 {
   var t1: {.x: i32, .y: i32} = {.x = 5, .y = 2};
   var t2: {.x: i32,} = {.x = 5,};
-  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/struct/fail_equality_type.carbon:[[@LINE+3]]: type error in ==
-  // CHECK: expected: {.x: i32, .y: i32}
-  // CHECK: actual: {.x: i32}
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/struct/fail_equality_type.carbon:[[@LINE+1]]: {.x: i32, .y: i32} is not equality comparable with {.x: i32} (could not find implementation of interface EqWith(U = {.x: i32}) for {.x: i32, .y: i32})
   if (t1 == t2) {
     return 1;
   } else {

+ 1 - 3
explorer/testdata/tuple/fail_equality_type.carbon

@@ -10,9 +10,7 @@ package ExplorerTest api;
 fn Main() -> i32 {
   var t1: (i32, i32) = (5, 2);
   var t2: (i32,) = (5,);
-  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/tuple/fail_equality_type.carbon:[[@LINE+3]]: type error in ==
-  // CHECK: expected: (i32, i32)
-  // CHECK: actual: (i32)
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/tuple/fail_equality_type.carbon:[[@LINE+1]]: (i32, i32) is not equality comparable with (i32) (could not find implementation of interface EqWith(U = (i32)) for (i32, i32))
   if (t1 == t2) {
     return 1;
   } else {