Explorar el Código

Remove VariantMatch; use CARBON_KIND_SWITCH for std::variants (#5437)

Teach CARBON_KIND_SWITCH to handle mutable lvalues and rvalues, and
CARBON_KIND to forward along rvalues so that it's possible to write
`case CARBON_KIND(const T& t)`, `case CARBON_KIND(T& t)`, and `case
CARBON_KIND(T&& t)`, depending on the type that was passed to
CARBON_KIND_SWITCH.

Replace all uses of VariantMatch with their equivalent of a switch using
CARBON_KIND_SWITCH, and remove the VariantMatch helper from the
codebase.
Dana Jansens hace 11 meses
padre
commit
517c4d3c20

+ 0 - 9
common/BUILD

@@ -553,15 +553,6 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "variant_helpers",
-    hdrs = ["variant_helpers.h"],
-    deps = [
-        ":error",
-        "@llvm-project//llvm:Support",
-    ],
-)
-
 # The base version source file only uses non-stamped parts of the version
 # information so we expand it once here without any stamping.
 expand_version_build_info(

+ 0 - 42
common/variant_helpers.h

@@ -1,42 +0,0 @@
-// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
-// Exceptions. See /LICENSE for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef CARBON_COMMON_VARIANT_HELPERS_H_
-#define CARBON_COMMON_VARIANT_HELPERS_H_
-
-#include <variant>
-
-#include "common/error.h"
-#include "llvm/ADT/StringRef.h"
-
-namespace Carbon {
-
-namespace Internal {
-
-// Form an overload set from a list of functions. For example:
-//
-// ```
-// auto overloaded = Overload{[] (int) {}, [] (float) {}};
-// ```
-template <typename... Fs>
-struct Overload : Fs... {
-  using Fs::operator()...;
-};
-template <typename... Fs>
-Overload(Fs...) -> Overload<Fs...>;
-
-}  // namespace Internal
-
-// Pattern-match against the type of the value stored in the variant `V`. Each
-// element of `fs` should be a function that takes one or more of the variant
-// values in `V`.
-template <typename V, typename... Fs>
-auto VariantMatch(V&& v, Fs&&... fs) -> decltype(auto) {
-  return std::visit(Internal::Overload{std::forward<Fs&&>(fs)...},
-                    std::forward<V&&>(v));
-}
-
-}  // namespace Carbon
-
-#endif  // CARBON_COMMON_VARIANT_HELPERS_H_

+ 16 - 13
toolchain/base/kind_switch.h

@@ -190,7 +190,8 @@ template <typename SwitchT, typename CaseFnT>
 consteval auto ForCase() -> auto {
   using CaseT = llvm::function_traits<CaseFnT>::template arg_t<0>;
   if constexpr (IsStdVariant<SwitchT>) {
-    return CaseValueOfTypeInStdVariant<CaseT, SwitchT>;
+    using NoRefCaseT = std::remove_cvref_t<CaseT>;
+    return CaseValueOfTypeInStdVariant<NoRefCaseT, SwitchT>;
   } else {
     using KindT = llvm::function_traits<
         decltype(&std::remove_cvref_t<SwitchT>::kind)>::result_t;
@@ -204,7 +205,8 @@ template <typename CaseFnT, typename SwitchT>
 auto Cast(SwitchT&& kind_switch_value) -> decltype(auto) {
   using CaseT = llvm::function_traits<CaseFnT>::template arg_t<0>;
   if constexpr (IsStdVariant<SwitchT>) {
-    return std::get<CaseT>(kind_switch_value);
+    using NoRefCaseT = std::remove_cvref_t<CaseT>;
+    return std::get<NoRefCaseT>(std::forward<SwitchT>(kind_switch_value));
   } else {
     return kind_switch_value.template As<CaseT>();
   }
@@ -217,9 +219,9 @@ auto Cast(SwitchT&& kind_switch_value) -> decltype(auto) {
 }  // namespace Carbon::Internal::Kind
 
 // Produces a switch statement on value.kind().
-#define CARBON_KIND_SWITCH(value)                            \
-  switch (                                                   \
-      const auto& carbon_internal_kind_switch_value = value; \
+#define CARBON_KIND_SWITCH(value)                       \
+  switch (                                              \
+      auto&& carbon_internal_kind_switch_value = value; \
       ::Carbon::Internal::Kind::SwitchOn(carbon_internal_kind_switch_value))
 
 // Produces a case-compatible block of code that also instantiates a local typed
@@ -228,14 +230,15 @@ auto Cast(SwitchT&& kind_switch_value) -> decltype(auto) {
 // This uses `if` to scope the variable, and provides a dangling `else` in order
 // to prevent accidental `else` use. The label allows `:` to follow the macro
 // name, making it look more like a typical `case`.
-#define CARBON_KIND(typed_variable_decl)                                \
-  ::Carbon::Internal::Kind::ForCase<                                    \
-      decltype(carbon_internal_kind_switch_value),                      \
-      decltype([]([[maybe_unused]] typed_variable_decl) {})>()          \
-      : if (typed_variable_decl = ::Carbon::Internal::Kind::Cast<       \
-                decltype([]([[maybe_unused]] typed_variable_decl) {})>( \
-                carbon_internal_kind_switch_value);                     \
-            false) {}                                                   \
+#define CARBON_KIND(typed_variable_decl)                                   \
+  ::Carbon::Internal::Kind::ForCase<                                       \
+      decltype(carbon_internal_kind_switch_value),                         \
+      decltype([]([[maybe_unused]] typed_variable_decl) {})>()             \
+      : if (typed_variable_decl = ::Carbon::Internal::Kind::Cast<          \
+                decltype([]([[maybe_unused]] typed_variable_decl) {})>(    \
+                std::forward<decltype(carbon_internal_kind_switch_value)>( \
+                    carbon_internal_kind_switch_value));                   \
+            false) {}                                                      \
   else [[maybe_unused]] CARBON_INTERNAL_KIND_LABEL(__LINE__)
 
 #endif  // CARBON_TOOLCHAIN_BASE_KIND_SWITCH_H_

+ 0 - 1
toolchain/check/BUILD

@@ -178,7 +178,6 @@ cc_library(
         "//common:find",
         "//common:map",
         "//common:ostream",
-        "//common:variant_helpers",
         "//common:vlog",
         "//toolchain/base:kind_switch",
         "//toolchain/base:pretty_stack_trace_function",

+ 18 - 17
toolchain/check/deferred_definition_worklist.cpp

@@ -7,8 +7,8 @@
 #include <algorithm>
 #include <optional>
 
-#include "common/variant_helpers.h"
 #include "common/vlog.h"
+#include "toolchain/base/kind_switch.h"
 #include "toolchain/check/handle.h"
 
 namespace Carbon::Check {
@@ -106,22 +106,23 @@ auto DeferredDefinitionWorklist::SuspendFinishedScopeAndPush(Context& context)
 auto DeferredDefinitionWorklist::Pop(
     llvm::function_ref<auto(Task&&)->void> handle_fn) -> void {
   if (vlog_stream_) {
-    VariantMatch(
-        worklist_.back(),
-        [&](CheckSkippedDefinition& definition) {
-          CARBON_VLOG("{0}Handle CheckSkippedDefinition {1}\n", VlogPrefix,
-                      definition.definition_index.index);
-        },
-        [&](EnterDeferredDefinitionScope& enter) {
-          CARBON_CHECK(enter.in_deferred_definition_scope);
-          CARBON_VLOG("{0}Handle EnterDeferredDefinitionScope (nested)\n",
-                      VlogPrefix);
-        },
-        [&](LeaveDeferredDefinitionScope& leave) {
-          bool nested = leave.in_deferred_definition_scope;
-          CARBON_VLOG("{0}Handle LeaveDeferredDefinitionScope {1}\n",
-                      VlogPrefix, nested ? "(nested)" : "(non-nested)");
-        });
+    CARBON_KIND_SWITCH(worklist_.back()) {
+      case CARBON_KIND(const CheckSkippedDefinition& definition):
+        CARBON_VLOG("{0}Handle CheckSkippedDefinition {1}\n", VlogPrefix,
+                    definition.definition_index.index);
+        break;
+      case CARBON_KIND(const EnterDeferredDefinitionScope& enter):
+        CARBON_CHECK(enter.in_deferred_definition_scope);
+        CARBON_VLOG("{0}Handle EnterDeferredDefinitionScope (nested)\n",
+                    VlogPrefix);
+        break;
+      case CARBON_KIND(const LeaveDeferredDefinitionScope& leave): {
+        bool nested = leave.in_deferred_definition_scope;
+        CARBON_VLOG("{0}Handle LeaveDeferredDefinitionScope {1}\n", VlogPrefix,
+                    nested ? "(nested)" : "(non-nested)");
+        break;
+      }
+    }
   }
 
   handle_fn(std::move(worklist_.back()));

+ 1 - 1
toolchain/lex/BUILD

@@ -192,7 +192,7 @@ cc_library(
         ":token_kind",
         ":tokenized_buffer",
         "//common:check",
-        "//common:variant_helpers",
+        "//toolchain/base:kind_switch",
         "//toolchain/base:shared_value_stores",
         "//toolchain/diagnostics:diagnostic_emitter",
         "//toolchain/source:source_buffer",

+ 19 - 21
toolchain/lex/lex.cpp

@@ -10,10 +10,10 @@
 #include <utility>
 
 #include "common/check.h"
-#include "common/variant_helpers.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Compiler.h"
+#include "toolchain/base/kind_switch.h"
 #include "toolchain/base/shared_value_stores.h"
 #include "toolchain/lex/character_set.h"
 #include "toolchain/lex/helpers.h"
@@ -1112,26 +1112,24 @@ auto Lexer::LexNumericLiteral(llvm::StringRef source_text, ssize_t& position)
   int token_size = literal->text().size();
   position += token_size;
 
-  return VariantMatch(
-      literal->ComputeValue(emitter_),
-      [&](NumericLiteral::IntValue&& value) {
-        return LexTokenWithPayload(TokenKind::IntLiteral,
-                                   buffer_.value_stores_->ints()
-                                       .AddUnsigned(std::move(value.value))
-                                       .AsTokenPayload(),
-                                   byte_offset);
-      },
-      [&](NumericLiteral::RealValue&& value) {
-        auto real_id = buffer_.value_stores_->reals().Add(Real{
-            .mantissa = value.mantissa,
-            .exponent = value.exponent,
-            .is_decimal = (value.radix == NumericLiteral::Radix::Decimal)});
-        return LexTokenWithPayload(TokenKind::RealLiteral, real_id.index,
-                                   byte_offset);
-      },
-      [&](NumericLiteral::UnrecoverableError) {
-        return LexTokenWithPayload(TokenKind::Error, token_size, byte_offset);
-      });
+  CARBON_KIND_SWITCH(literal->ComputeValue(emitter_)) {
+    case CARBON_KIND(NumericLiteral::IntValue && value):
+      return LexTokenWithPayload(TokenKind::IntLiteral,
+                                 buffer_.value_stores_->ints()
+                                     .AddUnsigned(std::move(value.value))
+                                     .AsTokenPayload(),
+                                 byte_offset);
+    case CARBON_KIND(NumericLiteral::RealValue && value): {
+      auto real_id = buffer_.value_stores_->reals().Add(
+          Real{.mantissa = value.mantissa,
+               .exponent = value.exponent,
+               .is_decimal = (value.radix == NumericLiteral::Radix::Decimal)});
+      return LexTokenWithPayload(TokenKind::RealLiteral, real_id.index,
+                                 byte_offset);
+    }
+    case CARBON_KIND(NumericLiteral::UnrecoverableError _):
+      return LexTokenWithPayload(TokenKind::Error, token_size, byte_offset);
+  }
 }
 
 auto Lexer::LexStringLiteral(llvm::StringRef source_text, ssize_t& position)

+ 0 - 1
toolchain/sem_ir/BUILD

@@ -140,7 +140,6 @@ cc_library(
         ":typed_insts",
         "//common:check",
         "//common:raw_string_ostream",
-        "//common:variant_helpers",
         "//toolchain/base:kind_switch",
         "@llvm-project//llvm:Support",
     ],

+ 53 - 36
toolchain/sem_ir/stringify.cpp

@@ -10,7 +10,6 @@
 #include <variant>
 
 #include "common/raw_string_ostream.h"
-#include "common/variant_helpers.h"
 #include "toolchain/base/kind_switch.h"
 #include "toolchain/sem_ir/entity_with_params_base.h"
 #include "toolchain/sem_ir/ids.h"
@@ -122,25 +121,38 @@ class StepStack {
   // safe.
   auto PushArray(llvm::ArrayRef<PushItem> items) -> void {
     for (auto item : llvm::reverse(items)) {
-      VariantMatch(
-          item, [&](InstId inst_id) { PushInstId(inst_id); },
-          [&](llvm::StringRef string) { PushString(string); },
-          [&](NameId name_id) { PushNameId(name_id); },
-          [&](ElementIndex element_index) { PushElementIndex(element_index); },
-          [&](QualifiedNameItem qualified_name) {
-            PushQualifiedName(qualified_name.first, qualified_name.second);
-          },
-          [&](EntityNameItem entity_name) {
-            PushEntityName(entity_name.first, entity_name.second);
-          },
-          [&](EntityNameId entity_name_id) {
-            PushEntityNameId(entity_name_id);
-          },
-          [&](TypeId type_id) { PushTypeId(type_id); },
-          [&](SpecificInterface specific_interface) {
-            PushSpecificInterface(specific_interface);
-          },
-          [&](llvm::ListSeparator* sep) { PushString(*sep); });
+      CARBON_KIND_SWITCH(item) {
+        case CARBON_KIND(InstId inst_id):
+          PushInstId(inst_id);
+          break;
+        case CARBON_KIND(llvm::StringRef string):
+          PushString(string);
+          break;
+        case CARBON_KIND(NameId name_id):
+          PushNameId(name_id);
+          break;
+        case CARBON_KIND(ElementIndex element_index):
+          PushElementIndex(element_index);
+          break;
+        case CARBON_KIND(QualifiedNameItem qualified_name):
+          PushQualifiedName(qualified_name.first, qualified_name.second);
+          break;
+        case CARBON_KIND(EntityNameItem entity_name):
+          PushEntityName(entity_name.first, entity_name.second);
+          break;
+        case CARBON_KIND(EntityNameId entity_name_id):
+          PushEntityNameId(entity_name_id);
+          break;
+        case CARBON_KIND(TypeId type_id):
+          PushTypeId(type_id);
+          break;
+        case CARBON_KIND(SpecificInterface specific_interface):
+          PushSpecificInterface(specific_interface);
+          break;
+        case CARBON_KIND(llvm::ListSeparator * sep):
+          PushString(*sep);
+          break;
+      }
     }
   }
 
@@ -641,28 +653,33 @@ static auto Stringify(const File& sem_ir, StepStack& step_stack)
   Stringifier stringifier(&sem_ir, &step_stack, &out);
 
   while (!step_stack.empty()) {
-    auto step = step_stack.Pop();
-
-    VariantMatch(
-        step,
-        [&](InstId inst_id) {
-          if (!inst_id.has_value()) {
-            out << "<invalid>";
-            return;
-          }
-          auto untyped_inst = sem_ir.insts().Get(inst_id);
-          CARBON_KIND_SWITCH(untyped_inst) {
+    CARBON_KIND_SWITCH(step_stack.Pop()) {
+      case CARBON_KIND(InstId inst_id): {
+        if (!inst_id.has_value()) {
+          out << "<invalid>";
+          break;
+        }
+        auto untyped_inst = sem_ir.insts().Get(inst_id);
+        CARBON_KIND_SWITCH(untyped_inst) {
 #define CARBON_SEM_IR_INST_KIND(InstT)              \
   case CARBON_KIND(InstT typed_inst): {             \
     stringifier.StringifyInst(inst_id, typed_inst); \
     break;                                          \
   }
 #include "toolchain/sem_ir/inst_kind.def"
-          }
-        },
-        [&](llvm::StringRef string) { out << string; },
-        [&](NameId name_id) { out << sem_ir.names().GetFormatted(name_id); },
-        [&](ElementIndex element_index) { out << element_index.index; });
+        }
+        break;
+      }
+      case CARBON_KIND(llvm::StringRef string):
+        out << string;
+        break;
+      case CARBON_KIND(NameId name_id):
+        out << sem_ir.names().GetFormatted(name_id);
+        break;
+      case CARBON_KIND(ElementIndex element_index):
+        out << element_index.index;
+        break;
+    }
   }
 
   return out.TakeStr();