Browse Source

C++ interop: Support more binary operators (#6017)

Already supported: `+`.
Newly supported: `-`, `*`, `/`, `%`, `&`, `|`, `^`, `<<`, `>>`, `==`,
`!=`, `<`, `>`, `<=`, `>=`.
Partially supported due to lack of reference support: `+=`, `-=`, `*=`,
`/=`, `%=`, `&=`, `|=`, `^=`.
Not supported due to lack of reference support: `<<=`, `>>=`.
Not supported (I think Carbon doesn't want overloading these): `&&`,
`||`.

C++ Interop Demo:

```c++
// my_number.h

class MyNumber {
 public:
  explicit MyNumber(int value) : value_(value) {}
  auto value() const -> int { return value_; }
  void set_value(int value) { value_ = value; }

 private:
  int value_;
};

// Arithmetic
auto operator+(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator-(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator*(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator/(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator%(MyNumber lhs, MyNumber rhs) -> MyNumber;

// Bitwise
auto operator&(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator|(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator^(MyNumber lhs, MyNumber rhs) -> MyNumber;
auto operator<<(MyNumber lhs, int shift) -> MyNumber;
auto operator>>(MyNumber lhs, int shift) -> MyNumber;

// Compound Arithmetic
auto operator+=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;
auto operator-=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;
auto operator*=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;
auto operator/=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;
auto operator%=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;

// Compound Bitwise
auto operator&=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;
auto operator|=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;
auto operator^=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull;

// Relational
auto operator==(MyNumber lhs, MyNumber rhs) -> bool;
auto operator!=(MyNumber lhs, MyNumber rhs) -> bool;
auto operator<(MyNumber lhs, MyNumber rhs) -> bool;
auto operator>(MyNumber lhs, MyNumber rhs) -> bool;
auto operator<=(MyNumber lhs, MyNumber rhs) -> bool;
auto operator>=(MyNumber lhs, MyNumber rhs) -> bool;
```

```c++
// my_number.cpp

#include "my_number.h"

// Arithmetic
auto operator+(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() + rhs.value());
}
auto operator-(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() - rhs.value());
}
auto operator*(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() * rhs.value());
}
auto operator/(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() / rhs.value());
}
auto operator%(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() % rhs.value());
}

// Bitwise
auto operator&(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() & rhs.value());
}
auto operator|(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() | rhs.value());
}
auto operator^(MyNumber lhs, MyNumber rhs) -> MyNumber {
  return MyNumber(lhs.value() ^ rhs.value());
}
auto operator<<(MyNumber lhs, int shift) -> MyNumber {
  return MyNumber(lhs.value() << shift);
}
auto operator>>(MyNumber lhs, int shift) -> MyNumber {
  return MyNumber(lhs.value() >> shift);
}

// Compound Arithmetic
auto operator+=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs + rhs);
}
auto operator-=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs - rhs);
}
auto operator*=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs * rhs);
}
auto operator/=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs / rhs);
}
auto operator%=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs % rhs);
}

// Compound Bitwise
auto operator&=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs & rhs);
}
auto operator|=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs | rhs);
}
auto operator^=(MyNumber* _Nonnull lhs, MyNumber rhs) -> MyNumber* _Nonnull {
  return &(*lhs = *lhs ^ rhs);
}

// Relational
auto operator==(MyNumber lhs, MyNumber rhs) -> bool {
  return lhs.value() == rhs.value();
}
auto operator!=(MyNumber lhs, MyNumber rhs) -> bool {
  return lhs.value() != rhs.value();
}
auto operator<(MyNumber lhs, MyNumber rhs) -> bool {
  return lhs.value() < rhs.value();
}
auto operator>(MyNumber lhs, MyNumber rhs) -> bool {
  return lhs.value() > rhs.value();
}
auto operator<=(MyNumber lhs, MyNumber rhs) -> bool {
  return lhs.value() <= rhs.value();
}
auto operator>=(MyNumber lhs, MyNumber rhs) -> bool {
  return lhs.value() >= rhs.value();
}
```

```carbon
// main.carbon

library "Main";

import Core library "io";
import Cpp library "my_number.h";

fn PrintBool(b: bool) {
  if (b) {
    Core.Print(1);
  } else {
    Core.Print(0);
  }
}

fn Run() -> i32 {
  // Arithmetic
  var num1: Cpp.MyNumber = Cpp.MyNumber.MyNumber(14);
  var num2: Cpp.MyNumber = Cpp.MyNumber.MyNumber(5);
  Core.Print(num1.value());
  Core.Print(num2.value());
  Core.Print((num1 + num2).value());
  Core.Print((num1 - num2).value());
  Core.Print((num1 * num2).value());
  Core.Print((num1 / num2).value());
  Core.Print((num1 % num2).value());

  // Bitwise
  var bits1: Cpp.MyNumber = Cpp.MyNumber.MyNumber(12);
  var bits2: Cpp.MyNumber = Cpp.MyNumber.MyNumber(10);
  Core.Print(bits1.value());
  Core.Print(bits2.value());
  Core.Print((bits1 & bits2).value());
  Core.Print((bits1 | bits2).value());
  Core.Print((bits1 ^ bits2).value());
  Core.Print((bits1 << 2).value());
  Core.Print((bits1 >> 1).value());

  // Compound Arithmetic
  var c: Cpp.MyNumber = Cpp.MyNumber.MyNumber(100);
  Core.Print(c.value());
  &c += Cpp.MyNumber.MyNumber(10);
  Core.Print(c.value());
  &c -= Cpp.MyNumber.MyNumber(20);
  Core.Print(c.value());
  &c *= Cpp.MyNumber.MyNumber(2);
  Core.Print(c.value());
  &c /= Cpp.MyNumber.MyNumber(6);
  Core.Print(c.value());
  &c %= Cpp.MyNumber.MyNumber(9);
  Core.Print(c.value());

  // Compound Bitwise
  &c |= Cpp.MyNumber.MyNumber(12);
  Core.Print(c.value());
  &c &= Cpp.MyNumber.MyNumber(7);
  Core.Print(c.value());
  &c ^= Cpp.MyNumber.MyNumber(10);
  Core.Print(c.value());

  // Relational
  var rel1: Cpp.MyNumber = Cpp.MyNumber.MyNumber(20);
  var rel2: Cpp.MyNumber = Cpp.MyNumber.MyNumber(30);
  var rel3: Cpp.MyNumber = Cpp.MyNumber.MyNumber(20);
  Core.Print(rel1.value());
  Core.Print(rel2.value());
  Core.Print(rel3.value());
  PrintBool(rel1 == rel3);
  PrintBool(rel1 != rel2);
  PrintBool(rel1 < rel2);
  PrintBool(rel2 > rel1);
  PrintBool(rel1 <= rel3);
  PrintBool(rel1 >= rel2);

  return 0;
}
```

```shell
$ clang -c my_number.cpp
$ bazel-bin/toolchain/carbon compile main.carbon
$ bazel-bin/toolchain/carbon link my_number.o main.o --output=demo
$ ./demo
14
5
19
9
70
2
4
12
10
8
14
6
48
6
100
110
90
180
30
3
15
7
13
20
30
20
1
1
1
1
1
0
```

Part of https://github.com/carbon-language/carbon-lang/issues/5995.
Boaz Brickner 7 months ago
parent
commit
ee42b2db93

+ 106 - 3
toolchain/check/import_cpp.cpp

@@ -2017,12 +2017,114 @@ auto ImportNameFromCpp(Context& context, SemIR::LocId loc_id,
                                  access);
 }
 
-static auto GetOperatorKind(Context& context, SemIR::LocId loc_id,
-                            llvm::StringLiteral interface_name)
+static auto GetClangOperatorKind(Context& context, SemIR::LocId loc_id,
+                                 llvm::StringLiteral interface_name,
+                                 llvm::StringLiteral op_name)
     -> std::optional<clang::OverloadedOperatorKind> {
+  // Arithmetic Operators.
   if (interface_name == "AddWith") {
+    CARBON_CHECK(op_name == "Op");
     return clang::OO_Plus;
   }
+  if (interface_name == "SubWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Minus;
+  }
+  if (interface_name == "MulWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Star;
+  }
+  if (interface_name == "DivWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Slash;
+  }
+  if (interface_name == "ModWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Percent;
+  }
+
+  // Bitwise Operators.
+  if (interface_name == "BitAndWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Amp;
+  }
+  if (interface_name == "BitOrWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Pipe;
+  }
+  if (interface_name == "BitXorWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_Caret;
+  }
+  if (interface_name == "LeftShiftWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_LessLess;
+  }
+  if (interface_name == "RightShiftWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_GreaterGreater;
+  }
+
+  // Compound Assignment Arithmetic Operators.
+  if (interface_name == "AddAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_PlusEqual;
+  }
+  if (interface_name == "SubAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_MinusEqual;
+  }
+  if (interface_name == "MulAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_StarEqual;
+  }
+  if (interface_name == "DivAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_SlashEqual;
+  }
+  if (interface_name == "ModAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_PercentEqual;
+  }
+
+  // Compound Assignment Bitwise Operators.
+  if (interface_name == "BitAndAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_AmpEqual;
+  }
+  if (interface_name == "BitOrAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_PipeEqual;
+  }
+  if (interface_name == "BitXorAssignWith") {
+    CARBON_CHECK(op_name == "Op");
+    return clang::OO_CaretEqual;
+  }
+  // TODO: Add support for `LeftShiftAssignWith` (`OO_LessLessEqual`) and
+  // `RightShiftAssignWith` (`OO_GreaterGreaterEqual`) when references are
+  // supported.
+
+  // Relational Operators.
+  if (interface_name == "EqWith") {
+    if (op_name == "Equal") {
+      return clang::OO_EqualEqual;
+    }
+    CARBON_CHECK(op_name == "NotEqual");
+    return clang::OO_ExclaimEqual;
+  }
+  if (interface_name == "OrderedWith") {
+    if (op_name == "Less") {
+      return clang::OO_Less;
+    }
+    if (op_name == "Greater") {
+      return clang::OO_Greater;
+    }
+    if (op_name == "LessOrEquivalent") {
+      return clang::OO_LessEqual;
+    }
+    CARBON_CHECK(op_name == "GreaterOrEquivalent");
+    return clang::OO_GreaterEqual;
+  }
 
   context.TODO(loc_id, llvm::formatv("Unsupported operator interface `{0}`",
                                      interface_name));
@@ -2038,7 +2140,8 @@ auto ImportOperatorFromCpp(Context& context, SemIR::LocId loc_id, Operator op)
         builder.Note(loc_id, InCppOperatorLookup, op.interface_name.str());
       });
 
-  auto op_kind = GetOperatorKind(context, loc_id, op.interface_name);
+  auto op_kind =
+      GetClangOperatorKind(context, loc_id, op.interface_name, op.op_name);
   if (!op_kind) {
     return SemIR::ScopeLookupResult::MakeNotFound();
   }

File diff suppressed because it is too large
+ 931 - 71
toolchain/check/testdata/interop/cpp/function/operators.carbon


Some files were not shown because too many files changed in this diff