Bläddra i källkod

Introduce FindIfOrNull() FindIfOrNone() and Contains() (#5322)

`FindIfOrNull` returns a pointer to the element in the range if it's
found, and nullptr otherwise. `FindIfOrNone` returns a copy of the
element in the range if it's found, and `T::None` (for a range of
elements of type `T`) otherwise. `Contains` returns a bool indicating
whether the element in the range is found.

These functions replace `llvm::find()` and `llvm::find_if()` when you
want a single answer back instead of an iterator. This avoids the need
to check against `end()`, allowing the return condition to be tested as
a standard bool.

We replace uses of `find()` and `find_if()` that did not require an
iterator with these new helpers.

Note that the return type of `FindIfOrNull` is a pointer since we can
not write `optional<T&>`, which must be tested for null. If the null
check is omitted, UB occurs and the resulting code may end up with an
incorrect pointer (https://crbug.com/40153300) into the range (or
elsewhere), rather than a null dereference. And this would be very
confusing to debug. Hopefully debug builds and sanitizers keep this from
being an issue we sink a bunch of time into debugging.
Dana Jansens 1 år sedan
förälder
incheckning
9a6c74f0cd

+ 22 - 0
common/BUILD

@@ -170,6 +170,28 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "find",
+    hdrs = ["find.h"],
+    deps = [
+        ":check",
+        ":ostream",
+        ":raw_string_ostream",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_test(
+    name = "find_test",
+    size = "small",
+    srcs = ["find_test.cpp"],
+    deps = [
+        ":find",
+        "//testing/base:gtest_main",
+        "@googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "hashing",
     srcs = ["hashing.cpp"],

+ 92 - 0
common/find.h

@@ -0,0 +1,92 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_COMMON_FIND_H_
+#define CARBON_COMMON_FIND_H_
+
+#include <concepts>
+#include <type_traits>
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace Carbon {
+
+namespace Internal {
+
+template <typename Range>
+using RangePointerType = typename std::iterator_traits<decltype(std::begin(
+    std::declval<Range>()))>::pointer;
+
+template <typename Range>
+using RangeValueType = typename std::iterator_traits<decltype(std::begin(
+    std::declval<Range>()))>::value_type;
+
+template <typename Range, typename Pred>
+concept IsValidFindPredicate =
+    requires(const RangeValueType<Range>& elem, Pred pred) {
+      { pred(elem) } -> std::convertible_to<bool>;
+    };
+
+template <typename A, typename B>
+concept IsComparable = requires(const A& a, const B& b) {
+  { a == b } -> std::convertible_to<bool>;
+};
+
+template <typename Range>
+concept RangeValueHasNoneType = requires {
+  { RangeValueType<Range>::None } -> std::convertible_to<RangeValueType<Range>>;
+};
+
+}  // namespace Internal
+
+// Finds a value in the given `range` by testing the `predicate`. Returns a
+// pointer to the value from the range on success, and nullptr if nothing is
+// found.
+//
+// This is similar to `std::find_if()` but returns a pointer to the value
+// instead of an iterator that must be tested against `end()`.
+template <typename Range, typename Pred>
+  requires Internal::IsValidFindPredicate<Range, Pred>
+constexpr auto FindIfOrNull(Range&& range, Pred predicate)
+    -> Internal::RangePointerType<Range> {
+  auto it = llvm::find_if(range, predicate);
+  if (it != range.end()) {
+    return std::addressof(*it);
+  } else {
+    return nullptr;
+  }
+}
+
+// Finds a value in the given `range` by testing the `predicate` and returns a
+// copy of it. If no match is found, returns `T::None` where the input range is
+// over values of type `T`.
+template <typename Range, typename Pred>
+  requires Internal::IsValidFindPredicate<Range, Pred> &&
+           Internal::RangeValueHasNoneType<Range> &&
+           std::copy_constructible<Internal::RangeValueType<Range>>
+constexpr auto FindIfOrNone(Range&& range, Pred predicate)
+    -> Internal::RangeValueType<Range> {
+  auto it = llvm::find_if(range, predicate);
+  if (it != range.end()) {
+    return *it;
+  } else {
+    return Internal::RangeValueType<Range>::None;
+  }
+}
+
+// Finds a value in the given `range` by comparing to `query`. Returns a
+// pointer to the value from the range on success, and nullptr if nothing is
+// found.
+//
+// This is similar to `std::find_if()` but returns a pointer to the value
+// instead of an iterator that must be tested against `end()`.
+template <typename Range, typename Query = Internal::RangeValueType<Range>>
+  requires Internal::IsComparable<Query, Internal::RangeValueType<Range>>
+constexpr auto Contains(Range&& range, const Query& query) -> bool {
+  return llvm::find(range, query) != range.end();
+}
+
+}  // namespace Carbon
+
+#endif  // CARBON_COMMON_FIND_H_

+ 75 - 0
common/find_test.cpp

@@ -0,0 +1,75 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "common/find.h"
+
+#include <gtest/gtest.h>
+
+#include <vector>
+
+namespace Carbon {
+namespace {
+
+struct NoneType {
+  static const NoneType None;
+  int i;
+
+  friend auto operator==(NoneType, NoneType) -> bool = default;
+};
+
+const NoneType NoneType::None = {.i = -1};
+
+TEST(FindTest, ReturnType) {
+  const std::vector<int> c;
+  std::vector<int> m;
+
+  auto pred = [](int) { return true; };
+  static_assert(std::same_as<decltype(FindIfOrNull(c, pred)), const int*>);
+  static_assert(std::same_as<decltype(FindIfOrNull(m, pred)), int*>);
+}
+
+TEST(FindTest, FindIfOrNull) {
+  auto make_pred = [](int query) {
+    return [=](int elem) { return query == elem; };
+  };
+
+  std::vector<int> empty;
+  EXPECT_EQ(FindIfOrNull(empty, make_pred(0)), nullptr);
+
+  std::vector<int> range = {1, 2};
+  EXPECT_EQ(FindIfOrNull(range, make_pred(0)), nullptr);
+  // NOLINTNEXTLINE(readability-container-data-pointer)
+  EXPECT_EQ(FindIfOrNull(range, make_pred(1)), &range[0]);
+  EXPECT_EQ(FindIfOrNull(range, make_pred(2)), &range[1]);
+  EXPECT_EQ(FindIfOrNull(range, make_pred(3)), nullptr);
+}
+
+TEST(FindTest, FindIfOrNone) {
+  auto make_pred = [](NoneType query) {
+    return [=](NoneType elem) { return query == elem; };
+  };
+
+  std::vector<NoneType> empty;
+  EXPECT_EQ(FindIfOrNone(empty, make_pred(NoneType{0})).i, -1);
+
+  std::vector<NoneType> range = {NoneType{1}, NoneType{2}};
+  EXPECT_EQ(FindIfOrNone(range, make_pred(NoneType{0})).i, -1);
+  EXPECT_EQ(FindIfOrNone(range, make_pred(NoneType{1})).i, 1);
+  EXPECT_EQ(FindIfOrNone(range, make_pred(NoneType{2})).i, 2);
+  EXPECT_EQ(FindIfOrNone(range, make_pred(NoneType{3})).i, -1);
+}
+
+TEST(FindTest, Contains) {
+  std::vector<int> empty;
+  EXPECT_EQ(Contains(empty, 0), false);
+
+  std::vector<int> range = {1, 2};
+  EXPECT_EQ(Contains(range, 0), false);
+  EXPECT_EQ(Contains(range, 1), true);
+  EXPECT_EQ(Contains(range, 2), true);
+  EXPECT_EQ(Contains(range, 3), false);
+}
+
+}  // namespace
+}  // namespace Carbon

+ 1 - 0
testing/file_test/BUILD

@@ -43,6 +43,7 @@ cc_library(
         "//common:check",
         "//common:error",
         "//common:exe_path",
+        "//common:find",
         "//common:init_llvm",
         "//common:ostream",
         "//common:raw_string_ostream",

+ 6 - 6
testing/file_test/test_file.cpp

@@ -11,6 +11,7 @@
 
 #include "common/check.h"
 #include "common/error.h"
+#include "common/find.h"
 #include "common/raw_string_ostream.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/JSON.h"
@@ -132,14 +133,13 @@ static auto AutoFillDidOpenParams(llvm::json::Object* params,
   }
 
   CARBON_ASSIGN_OR_RETURN(auto file_path, ExtractFilePathFromUri(*uri));
-  const auto* split_it =
-      llvm::find_if(splits, [&](const TestFile::Split& split) {
-        return split.filename == file_path;
-      });
-  if (split_it == splits.end()) {
+  const auto* split = FindIfOrNull(splits, [&](const TestFile::Split& split) {
+    return split.filename == file_path;
+  });
+  if (!split) {
     return ErrorBuilder() << "No split found for uri: " << *uri;
   }
-  attr_it->second = split_it->content;
+  attr_it->second = split->content;
   return Success();
 }
 

+ 1 - 0
toolchain/check/BUILD

@@ -171,6 +171,7 @@ cc_library(
         ":pointer_dereference",
         "//common:check",
         "//common:error",
+        "//common:find",
         "//common:map",
         "//common:ostream",
         "//common:variant_helpers",

+ 4 - 9
toolchain/check/handle_function.cpp

@@ -5,6 +5,7 @@
 #include <optional>
 #include <utility>
 
+#include "common/find.h"
 #include "toolchain/base/kind_switch.h"
 #include "toolchain/check/context.h"
 #include "toolchain/check/control_flow.h"
@@ -90,15 +91,9 @@ static auto FindSelfPattern(Context& context,
     -> SemIR::InstId {
   auto implicit_param_patterns =
       context.inst_blocks().GetOrEmpty(implicit_param_patterns_id);
-  if (const auto* i = llvm::find_if(implicit_param_patterns,
-                                    [&](auto implicit_param_id) {
-                                      return SemIR::IsSelfPattern(
-                                          context.sem_ir(), implicit_param_id);
-                                    });
-      i != implicit_param_patterns.end()) {
-    return *i;
-  }
-  return SemIR::InstId::None;
+  return FindIfOrNone(implicit_param_patterns, [&](auto implicit_param_id) {
+    return SemIR::IsSelfPattern(context.sem_ir(), implicit_param_id);
+  });
 }
 
 // Diagnoses issues with the modifiers, removing modifiers that shouldn't be

+ 1 - 0
toolchain/parse/BUILD

@@ -153,6 +153,7 @@ cc_library(
         ":node_kind",
         "//common:check",
         "//common:error",
+        "//common:find",
         "//common:ostream",
         "//common:struct_reflection",
         "//toolchain/base:value_store",

+ 2 - 1
toolchain/parse/extract.cpp

@@ -9,6 +9,7 @@
 #include <utility>
 
 #include "common/error.h"
+#include "common/find.h"
 #include "common/struct_reflection.h"
 #include "toolchain/parse/tree.h"
 #include "toolchain/parse/tree_and_subtrees.h"
@@ -210,7 +211,7 @@ auto NodeExtractor::MatchesNodeIdOneOf(
       *trace_ << "\n";
     }
     return false;
-  } else if (llvm::find(kinds, node_kind) == kinds.end()) {
+  } else if (!Contains(kinds, node_kind)) {
     if (trace_) {
       *trace_ << "NodeIdOneOf error: wrong kind " << node_kind << ", expected ";
       trace_kinds();

+ 1 - 0
toolchain/testing/BUILD

@@ -49,6 +49,7 @@ cc_library(
     testonly = 1,
     hdrs = ["coverage_helper.h"],
     deps = [
+        "//common:find",
         "//common:set",
         "@googletest//:gtest",
         "@llvm-project//llvm:Support",

+ 2 - 1
toolchain/testing/coverage_helper.h

@@ -10,6 +10,7 @@
 #include <fstream>
 #include <string>
 
+#include "common/find.h"
 #include "common/set.h"
 #include "llvm/ADT/StringExtras.h"
 #include "re2/re2.h"
@@ -52,7 +53,7 @@ auto TestKindCoverage(const std::string& manifest_path,
 
   llvm::SmallVector<llvm::StringRef> missing_kinds;
   for (auto kind : kinds) {
-    if (llvm::find(untested_kinds, kind) != untested_kinds.end()) {
+    if (Contains(untested_kinds, kind)) {
       EXPECT_FALSE(covered_kinds.Erase(kind.name()))
           << "Kind " << kind
           << " has coverage even though none was expected. If this has "