Просмотр исходного кода

Add `Value` decomposition and use it to implement `Substitute` (#2389)

Add a generic mechanism to decompose a `Value` and rebuild it, and use that to implement `Substitute`'s recursive transformation of values instead of a hand-rolled decomposition. This means `Substitute` now covers all kinds of values, whereas previously it used to be unable to transform some values, and should be less work to add new kinds of value.

We can use the same mechanism for various other things: structural dumping of values, equality comparisons, and value instantiation in the interpreter would all benefit from this. But in this change I'm just switching `Substitute` to this as a first step.
Richard Smith 3 лет назад
Родитель
Сommit
43283cb516

+ 26 - 0
explorer/README.md

@@ -57,6 +57,32 @@ builders in [`error_builders.h`](common/error_builders.h). Errors caused by bugs
 in `explorer` itself should be reported with
 [`CHECK` or `FATAL`](../common/check.h).
 
+### `Decompose` functions
+
+Many of explorer's data structures provide a `Decompose` method, which allows
+simple data types to be generically decomposed into their fields. The
+`Decompose` function for a type takes a function and calls it with the fields of
+that type. For example:
+
+```
+class MyType {
+ public:
+  MyType(Type1 arg1, Type2 arg2) : arg1_(arg1), arg2_(arg2) {}
+
+  template <typename F>
+  auto Decompose(F f) const { return f(arg1_, arg2_); }
+
+ private:
+  Type1 arg1_;
+  Type2 arg2_;
+};
+```
+
+Where possible, a value equivalent to the original value should be created by
+passing the given arguments to the constructor of the type. For example,
+`my_value.Decompose([](auto ...args) { return MyType(args...); })` should
+recreate the original value.
+
 ## Example Programs (Regression Tests)
 
 The [`testdata/`](testdata/) subdirectory includes some example programs with

+ 5 - 0
explorer/ast/bindings.h

@@ -55,6 +55,11 @@ class Bindings {
   Bindings(BindingMap args, NoWitnessesTag /*unused*/)
       : args_(std::move(args)) {}
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(args_, witnesses_);
+  }
+
   // Add a value, and perhaps a witness, for a generic binding.
   void Add(Nonnull<const GenericBinding*> binding, Nonnull<const Value*> value,
            std::optional<Nonnull<const Value*>> witness);

+ 8 - 0
explorer/ast/element.cpp

@@ -44,6 +44,14 @@ auto NamedElement::declaration() const
   return std::nullopt;
 }
 
+auto NamedElement::struct_member() const
+    -> std::optional<Nonnull<const NamedValue*>> {
+  if (const auto* member = element_.dyn_cast<const NamedValue*>()) {
+    return member;
+  }
+  return std::nullopt;
+}
+
 void NamedElement::Print(llvm::raw_ostream& out) const { out << name(); }
 
 // Prints the Element

+ 47 - 0
explorer/ast/element.h

@@ -20,6 +20,14 @@ class Value;
 
 // A NamedValue represents a value with a name, such as a single struct field.
 struct NamedValue {
+  NamedValue(std::string name, Nonnull<const Value*> value)
+      : name(std::move(name)), value(value) {}
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(name, value);
+  }
+
   // The field name.
   std::string name;
 
@@ -37,6 +45,11 @@ class Element {
  public:
   virtual ~Element() = default;
 
+  // Call `f` on this value, cast to its most-derived type. `R` specifies the
+  // expected return type of `f`.
+  template <typename R, typename F>
+  auto Visit(F f) const -> R;
+
   // Prints the Member
   virtual void Print(llvm::raw_ostream& out) const = 0;
 
@@ -63,6 +76,15 @@ class NamedElement : public Element {
   explicit NamedElement(Nonnull<const Declaration*> declaration);
   explicit NamedElement(Nonnull<const NamedValue*> struct_member);
 
+  template <typename F>
+  auto Decompose(F f) const {
+    if (auto decl = declaration()) {
+      return f(*decl);
+    } else {
+      return f(*struct_member());
+    }
+  }
+
   // Prints the element's name
   void Print(llvm::raw_ostream& out) const override;
 
@@ -77,6 +99,8 @@ class NamedElement : public Element {
   auto name() const -> std::string_view;
   // A declaration of the member, if any exists.
   auto declaration() const -> std::optional<Nonnull<const Declaration*>>;
+  // A name and type pair, if this is a struct member.
+  auto struct_member() const -> std::optional<Nonnull<const NamedValue*>>;
 
  private:
   const llvm::PointerUnion<Nonnull<const Declaration*>,
@@ -92,6 +116,11 @@ class PositionalElement : public Element {
   explicit PositionalElement(int index, Nonnull<const Value*> type)
       : Element(ElementKind::PositionalElement), index_(index), type_(type) {}
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(index_, type_);
+  }
+
   // Prints the element
   void Print(llvm::raw_ostream& out) const override;
 
@@ -118,6 +147,11 @@ class BaseElement : public Element {
   explicit BaseElement(Nonnull<const Value*> type)
       : Element(ElementKind::BaseElement), type_(type) {}
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(type_);
+  }
+
   // Prints the Member
   void Print(llvm::raw_ostream& out) const override;
 
@@ -133,6 +167,19 @@ class BaseElement : public Element {
  private:
   const Nonnull<const Value*> type_;
 };
+
+template <typename R, typename F>
+auto Element::Visit(F f) const -> R {
+  switch (kind()) {
+    case ElementKind::NamedElement:
+      return f(static_cast<const NamedElement*>(this));
+    case ElementKind::PositionalElement:
+      return f(static_cast<const PositionalElement*>(this));
+    case ElementKind::BaseElement:
+      return f(static_cast<const BaseElement*>(this));
+  }
+}
+
 }  // namespace Carbon
 
 #endif  // CARBON_EXPLORER_AST_ELEMENT_H_

+ 4 - 0
explorer/interpreter/BUILD

@@ -15,6 +15,10 @@ cc_library(
     hdrs = [
         "action.h",
         "value.h",
+        "value_transform.h",
+    ],
+    textual_hdrs = [
+        "value_kinds.def",
     ],
     # Exposed to resolve `member_test` dependencies.
     visibility = ["//explorer/ast:__pkg__"],

+ 2 - 2
explorer/interpreter/interpreter.cpp

@@ -264,7 +264,7 @@ auto Interpreter::CreateStruct(const std::vector<FieldInitializer>& fields,
   CARBON_CHECK(fields.size() == values.size());
   std::vector<NamedValue> elements;
   for (size_t i = 0; i < fields.size(); ++i) {
-    elements.push_back({.name = fields[i].name(), .value = values[i]});
+    elements.push_back({fields[i].name(), values[i]});
   }
 
   return arena_->New<StructValue>(std::move(elements));
@@ -767,7 +767,7 @@ auto Interpreter::Convert(Nonnull<const Value*> value,
             CARBON_ASSIGN_OR_RETURN(
                 Nonnull<const Value*> val,
                 Convert(*old_value, field_type, source_loc));
-            new_elements.push_back({.name = field_name, .value = val});
+            new_elements.push_back({field_name, val});
           }
           return arena_->New<StructValue>(std::move(new_elements));
         }

+ 142 - 252
explorer/interpreter/type_checker.cpp

@@ -28,6 +28,7 @@
 #include "explorer/interpreter/interpreter.h"
 #include "explorer/interpreter/pattern_analysis.h"
 #include "explorer/interpreter/value.h"
+#include "explorer/interpreter/value_transform.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -436,8 +437,7 @@ auto TypeChecker::FieldTypes(const NominalClassType& class_type) const
         const auto& var = cast<VariableDeclaration>(*m);
         Nonnull<const Value*> field_type =
             Substitute(class_type.bindings(), &var.binding().static_type());
-        field_types.push_back(
-            {.name = var.binding().name(), .value = field_type});
+        field_types.push_back({var.binding().name(), field_type});
         break;
       }
       default:
@@ -933,11 +933,11 @@ auto TypeChecker::ArgumentDeduction::Deduce(Nonnull<const Value*> param,
       };
       for (size_t i = 0; i < param_struct.fields().size(); ++i) {
         NamedValue param_field = param_struct.fields()[i];
-        NamedValue arg_field;
+        std::optional<NamedValue> arg_field;
         if (allow_implicit_conversion) {
           if (std::optional<NamedValue> maybe_arg_field =
                   FindField(arg_struct.fields(), param_field.name)) {
-            arg_field = *maybe_arg_field;
+            arg_field = maybe_arg_field;
           } else {
             return diagnose_missing_field(arg_struct, param_field, true);
           }
@@ -946,13 +946,13 @@ auto TypeChecker::ArgumentDeduction::Deduce(Nonnull<const Value*> param,
             return diagnose_missing_field(arg_struct, param_field, true);
           }
           arg_field = arg_struct.fields()[i];
-          if (param_field.name != arg_field.name) {
+          if (param_field.name != arg_field->name) {
             return ProgramError(source_loc_)
                    << "mismatch in field names, `" << param_field.name
-                   << "` != `" << arg_field.name << "`";
+                   << "` != `" << arg_field->name << "`";
           }
         }
-        CARBON_RETURN_IF_ERROR(Deduce(param_field.value, arg_field.value,
+        CARBON_RETURN_IF_ERROR(Deduce(param_field.value, arg_field->value,
                                       allow_implicit_conversion));
       }
       if (param_struct.fields().size() != arg_struct.fields().size()) {
@@ -1786,259 +1786,149 @@ auto TypeChecker::RebuildValue(Nonnull<const Value*> value) const
   return SubstituteImpl(Bindings(), value);
 }
 
-auto TypeChecker::SubstituteImpl(const Bindings& bindings,
-                                 Nonnull<const Value*> type) const
-    -> Nonnull<const Value*> {
-  auto substitute_into_bindings =
-      [&](Nonnull<const Bindings*> inner_bindings) -> Nonnull<const Bindings*> {
-    BindingMap values;
-    for (const auto& [name, value] : inner_bindings->args()) {
-      values[name] = SubstituteImpl(bindings, value);
-    }
-    ImplWitnessMap witnesses;
-    for (const auto& [name, value] : inner_bindings->witnesses()) {
-      witnesses[name] = SubstituteImpl(bindings, value);
-    }
-    if (values == inner_bindings->args() &&
-        witnesses == inner_bindings->witnesses()) {
-      return inner_bindings;
-    }
-    return arena_->New<Bindings>(std::move(values), std::move(witnesses));
-  };
-
-  switch (type->kind()) {
-    case Value::Kind::VariableType: {
-      auto it = bindings.args().find(&cast<VariableType>(*type).binding());
-      if (it == bindings.args().end()) {
-        if (trace_stream_) {
-          **trace_stream_ << "substitution: no value for binding " << *type
-                          << ", leaving alone\n";
-        }
-        return type;
-      } else {
-        return it->second;
-      }
-    }
-    case Value::Kind::AssociatedConstant: {
-      const auto& assoc = cast<AssociatedConstant>(*type);
-      Nonnull<const Value*> base = SubstituteImpl(bindings, &assoc.base());
-      const auto* interface =
-          cast<InterfaceType>(SubstituteImpl(bindings, &assoc.interface()));
-      // If we're substituting into an associated constant, we may now be able
-      // to rewrite it to a concrete value.
-      if (auto rewritten_value =
-              LookupRewriteInTypeOf(base, interface, &assoc.constant())) {
-        return (*rewritten_value)->converted_replacement;
-      }
-      const auto* witness =
-          cast<Witness>(SubstituteImpl(bindings, &assoc.witness()));
-      witness = RefineWitness(witness, base, interface);
-      if (auto rewritten_value =
-              LookupRewriteInWitness(witness, interface, &assoc.constant())) {
-        return (*rewritten_value)->converted_replacement;
-      }
-      return arena_->New<AssociatedConstant>(base, interface, &assoc.constant(),
-                                             witness);
-    }
-    case Value::Kind::TupleType:
-    case Value::Kind::TupleValue: {
-      std::vector<Nonnull<const Value*>> elts;
-      for (const auto& elt : cast<TupleValueBase>(*type).elements()) {
-        elts.push_back(SubstituteImpl(bindings, elt));
-      }
-      if (isa<TupleType>(type)) {
-        return arena_->New<TupleType>(std::move(elts));
-      } else {
-        return arena_->New<TupleValue>(std::move(elts));
-      }
-    }
-    case Value::Kind::StructType: {
-      std::vector<NamedValue> fields;
-      for (const auto& [name, value] : cast<StructType>(*type).fields()) {
-        const auto* new_type = SubstituteImpl(bindings, value);
-        fields.push_back({name, new_type});
+class TypeChecker::SubstituteTransform
+    : public ValueTransform<SubstituteTransform> {
+ public:
+  SubstituteTransform(Nonnull<const TypeChecker*> type_checker,
+                      const Bindings& bindings)
+      : ValueTransform(type_checker->arena_),
+        type_checker_(type_checker),
+        bindings_(bindings) {}
+
+  using ValueTransform::operator();
+
+  // Replace a `VariableType` with its binding value if available.
+  auto operator()(Nonnull<const VariableType*> var_type)
+      -> Nonnull<const Value*> {
+    auto it = bindings_.args().find(&var_type->binding());
+    if (it == bindings_.args().end()) {
+      if (auto& trace_stream = type_checker_->trace_stream_) {
+        **trace_stream << "substitution: no value for binding " << *var_type
+                       << ", leaving alone\n";
       }
-      return arena_->New<StructType>(std::move(fields));
+      return var_type;
+    } else {
+      return it->second;
     }
-    case Value::Kind::FunctionType: {
-      const auto& fn_type = cast<FunctionType>(*type);
-      SubstitutedGenericBindings subst_bindings(this, bindings);
-
-      // Apply substitution to into generic parameters and deduced bindings.
-      std::vector<FunctionType::GenericParameter> generic_parameters;
-      for (const FunctionType::GenericParameter& gp :
-           fn_type.generic_parameters()) {
-        generic_parameters.push_back(
-            {.index = gp.index,
-             .binding =
-                 subst_bindings.SubstituteIntoGenericBinding(gp.binding)});
-      }
-      std::vector<Nonnull<const GenericBinding*>> deduced_bindings;
-      for (Nonnull<const GenericBinding*> gb : fn_type.deduced_bindings()) {
-        deduced_bindings.push_back(
-            subst_bindings.SubstituteIntoGenericBinding(gb));
-      }
+  }
 
-      // Apply substitution to parameter and return types and create the new
-      // function type.
-      const auto* param =
-          SubstituteImpl(subst_bindings.bindings(), &fn_type.parameters());
-      const auto* ret =
-          SubstituteImpl(subst_bindings.bindings(), &fn_type.return_type());
-      return arena_->New<FunctionType>(
-          param, std::move(generic_parameters), ret,
-          std::move(deduced_bindings),
-          std::move(subst_bindings).TakeImplBindings());
-    }
-    case Value::Kind::PointerType: {
-      return arena_->New<PointerType>(
-          SubstituteImpl(bindings, &cast<PointerType>(*type).type()));
-    }
-    case Value::Kind::NominalClassType: {
-      const auto& class_type = cast<NominalClassType>(*type);
-      auto base_type = class_type.base();
-      if (base_type.has_value()) {
-        base_type = cast<NominalClassType>(
-            SubstituteImpl(base_type.value()->bindings(), base_type.value()));
+  // Replace a `BindingWitness` with its binding value if available.
+  auto operator()(Nonnull<const BindingWitness*> witness)
+      -> Nonnull<const Value*> {
+    auto it = bindings_.witnesses().find(witness->binding());
+    if (it == bindings_.witnesses().end()) {
+      if (auto& trace_stream = type_checker_->trace_stream_) {
+        **trace_stream << "substitution: no value for binding " << *witness
+                       << ", leaving alone\n";
       }
-      Nonnull<const NominalClassType*> new_class_type =
-          arena_->New<NominalClassType>(
-              &class_type.declaration(),
-              substitute_into_bindings(&class_type.bindings()), base_type);
-      return new_class_type;
-    }
-    case Value::Kind::InterfaceType: {
-      const auto& iface_type = cast<InterfaceType>(*type);
-      Nonnull<const InterfaceType*> new_iface_type = arena_->New<InterfaceType>(
-          &iface_type.declaration(),
-          substitute_into_bindings(&iface_type.bindings()));
-      return new_iface_type;
+      return witness;
+    } else {
+      return it->second;
     }
-    case Value::Kind::NamedConstraintType: {
-      const auto& constraint_type = cast<NamedConstraintType>(*type);
-      Nonnull<const NamedConstraintType*> new_constraint_type =
-          arena_->New<NamedConstraintType>(
-              &constraint_type.declaration(),
-              substitute_into_bindings(&constraint_type.bindings()));
-      return new_constraint_type;
-    }
-    case Value::Kind::ConstraintType: {
-      const auto& constraint = cast<ConstraintType>(*type);
-      if (auto it = bindings.args().find(constraint.self_binding());
-          it != bindings.args().end()) {
-        // This happens when we substitute into the parameter type of a
-        // function that takes a `T:! Constraint` parameter. In this case we
-        // produce the new type-of-type of the replacement type.
-        Nonnull<const Value*> type_of_type;
-        if (const auto* var_type = dyn_cast<VariableType>(it->second)) {
-          type_of_type = &var_type->binding().static_type();
-        } else if (const auto* assoc_type =
-                       dyn_cast<AssociatedConstant>(it->second)) {
-          type_of_type = GetTypeForAssociatedConstant(assoc_type);
-        } else {
-          type_of_type = arena_->New<TypeType>();
-        }
-        if (trace_stream_) {
-          **trace_stream_ << "substitution: self of constraint " << constraint
-                          << " is substituted, new type of type is "
-                          << *type_of_type << "\n";
-        }
-        // TODO: Should we keep any part of the old constraint -- rewrites,
-        // equality constraints, etc?
-        return type_of_type;
-      }
-      ConstraintTypeBuilder builder(arena_,
-                                    constraint.self_binding()->source_loc());
-      builder.AddAndSubstitute(*this, &constraint, builder.GetSelfType(),
-                               builder.GetSelfWitness(), bindings,
-                               /*add_lookup_contexts=*/true);
-      Nonnull<const ConstraintType*> new_constraint =
-          std::move(builder).Build();
-      if (trace_stream_) {
-        **trace_stream_ << "substitution: " << constraint << " => "
-                        << *new_constraint << "\n";
-      }
-      return new_constraint;
-    }
-    case Value::Kind::ImplWitness: {
-      const auto& witness = cast<ImplWitness>(*type);
-      return arena_->New<ImplWitness>(
-          &witness.declaration(),
-          substitute_into_bindings(&witness.bindings()));
-    }
-    case Value::Kind::BindingWitness: {
-      auto it =
-          bindings.witnesses().find(cast<BindingWitness>(*type).binding());
-      if (it == bindings.witnesses().end()) {
-        if (trace_stream_) {
-          **trace_stream_ << "substitution: no value for binding " << *type
-                          << ", leaving alone\n";
-        }
-        return type;
+  }
+
+  // For an associated constant, look for a rewrite.
+  auto operator()(Nonnull<const AssociatedConstant*> assoc)
+      -> Nonnull<const Value*> {
+    Nonnull<const Value*> base = Transform(&assoc->base());
+    Nonnull<const InterfaceType*> interface = Transform(&assoc->interface());
+    // If we're substituting into an associated constant, we may now be able
+    // to rewrite it to a concrete value.
+    if (auto rewritten_value = type_checker_->LookupRewriteInTypeOf(
+            base, interface, &assoc->constant())) {
+      return (*rewritten_value)->converted_replacement;
+    }
+    const auto* witness = cast<Witness>(Transform(&assoc->witness()));
+    witness = type_checker_->RefineWitness(witness, base, interface);
+    if (auto rewritten_value = type_checker_->LookupRewriteInWitness(
+            witness, interface, &assoc->constant())) {
+      return (*rewritten_value)->converted_replacement;
+    }
+    return type_checker_->arena_->New<AssociatedConstant>(
+        base, interface, &assoc->constant(), witness);
+  }
+
+  // Rebuilding a function type needs special handling to build new bindings.
+  // TODO: This is probably not specific to substitution, and would apply to
+  // other transforms too.
+  auto operator()(Nonnull<const FunctionType*> fn_type)
+      -> Nonnull<const FunctionType*> {
+    SubstitutedGenericBindings subst_bindings(type_checker_, bindings_);
+
+    // Apply substitution to into generic parameters and deduced bindings.
+    std::vector<FunctionType::GenericParameter> generic_parameters;
+    for (const FunctionType::GenericParameter& gp :
+         fn_type->generic_parameters()) {
+      generic_parameters.push_back(
+          {.index = gp.index,
+           .binding = subst_bindings.SubstituteIntoGenericBinding(gp.binding)});
+    }
+    std::vector<Nonnull<const GenericBinding*>> deduced_bindings;
+    for (Nonnull<const GenericBinding*> gb : fn_type->deduced_bindings()) {
+      deduced_bindings.push_back(
+          subst_bindings.SubstituteIntoGenericBinding(gb));
+    }
+
+    // Apply substitution to parameter and return types and create the new
+    // function type.
+    const auto* param = type_checker_->SubstituteImpl(subst_bindings.bindings(),
+                                                      &fn_type->parameters());
+    const auto* ret = type_checker_->SubstituteImpl(subst_bindings.bindings(),
+                                                    &fn_type->return_type());
+    return type_checker_->arena_->New<FunctionType>(
+        param, std::move(generic_parameters), ret, std::move(deduced_bindings),
+        std::move(subst_bindings).TakeImplBindings());
+  }
+
+  // Substituting into a `ConstraintType` needs special handling if we replace
+  // its self type.
+  auto operator()(Nonnull<const ConstraintType*> constraint)
+      -> Nonnull<const Value*> {
+    if (auto it = bindings_.args().find(constraint->self_binding());
+        it != bindings_.args().end()) {
+      // This happens when we substitute into the parameter type of a
+      // function that takes a `T:! Constraint` parameter. In this case we
+      // produce the new type-of-type of the replacement type.
+      Nonnull<const Value*> type_of_type;
+      if (const auto* var_type = dyn_cast<VariableType>(it->second)) {
+        type_of_type = &var_type->binding().static_type();
+      } else if (const auto* assoc_type =
+                     dyn_cast<AssociatedConstant>(it->second)) {
+        type_of_type = type_checker_->GetTypeForAssociatedConstant(assoc_type);
       } else {
-        return it->second;
+        type_of_type = type_checker_->arena_->New<TypeType>();
       }
-    }
-    case Value::Kind::ConstraintWitness: {
-      const auto& witness = cast<ConstraintWitness>(*type);
-      std::vector<Nonnull<const Witness*>> witnesses;
-      witnesses.reserve(witness.witnesses().size());
-      for (const auto* witness : witness.witnesses()) {
-        witnesses.push_back(cast<Witness>(SubstituteImpl(bindings, witness)));
+      if (auto& trace_stream = type_checker_->trace_stream_) {
+        **trace_stream << "substitution: self of constraint " << *constraint
+                       << " is substituted, new type of type is "
+                       << *type_of_type << "\n";
       }
-      return arena_->New<ConstraintWitness>(std::move(witnesses));
-    }
-    case Value::Kind::ConstraintImplWitness: {
-      const auto& witness = cast<ConstraintImplWitness>(*type);
-      return ConstraintImplWitness::Make(
-          arena_,
-          cast<Witness>(SubstituteImpl(bindings, witness.constraint_witness())),
-          witness.index());
+      // TODO: Should we keep any part of the old constraint -- rewrites,
+      // equality constraints, etc?
+      return type_of_type;
+    }
+    ConstraintTypeBuilder builder(type_checker_->arena_,
+                                  constraint->self_binding()->source_loc());
+    builder.AddAndSubstitute(*type_checker_, constraint, builder.GetSelfType(),
+                             builder.GetSelfWitness(), bindings_,
+                             /*add_lookup_contexts=*/true);
+    Nonnull<const ConstraintType*> new_constraint = std::move(builder).Build();
+    if (auto& trace_stream = type_checker_->trace_stream_) {
+      **trace_stream << "substitution: " << *constraint << " => "
+                     << *new_constraint << "\n";
     }
-    case Value::Kind::StaticArrayType:
-    case Value::Kind::ChoiceType:
-    case Value::Kind::MixinPseudoType:
-      // TODO: These can contain bindings. We should substitute into them.
-      return type;
-    case Value::Kind::AutoType:
-    case Value::Kind::IntType:
-    case Value::Kind::BoolType:
-    case Value::Kind::TypeType:
-    case Value::Kind::ContinuationType:
-    case Value::Kind::StringType:
-      // These types cannot contain bindings or witnesses.
-      return type;
-    case Value::Kind::TypeOfMixinPseudoType:
-    case Value::Kind::TypeOfParameterizedEntityName:
-    case Value::Kind::TypeOfMemberName:
-      // TODO: We should substitute into the value and produce a new type of
-      // type for it.
-      return type;
-    case Value::Kind::ParameterizedEntityName:
-    case Value::Kind::MemberName:
-    case Value::Kind::FunctionValue:
-    case Value::Kind::DestructorValue:
-    case Value::Kind::BoundMethodValue:
-    case Value::Kind::StructValue:
-    case Value::Kind::NominalClassValue:
-    case Value::Kind::AlternativeValue:
-    case Value::Kind::BindingPlaceholderValue:
-    case Value::Kind::AddrValue:
-    case Value::Kind::AlternativeConstructorValue:
-    case Value::Kind::ContinuationValue:
-      // This can happen when substituting into the arguments of a class or
-      // interface.
-      // TODO: Implement substitution for these cases.
-      return type;
-    case Value::Kind::IntValue:
-    case Value::Kind::BoolValue:
-    case Value::Kind::PointerValue:
-    case Value::Kind::LValue:
-    case Value::Kind::StringValue:
-    case Value::Kind::UninitializedValue:
-      // These values cannot contain bindings or witnesses.
-      return type;
+    return new_constraint;
   }
+
+ private:
+  Nonnull<const TypeChecker*> type_checker_;
+  const Bindings& bindings_;
+};
+
+auto TypeChecker::SubstituteImpl(const Bindings& bindings,
+                                 Nonnull<const Value*> type) const
+    -> Nonnull<const Value*> {
+  return SubstituteTransform(this, bindings).Transform(type);
 }
 
 auto TypeChecker::RefineWitness(Nonnull<const Witness*> witness,
@@ -2543,7 +2433,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
         CARBON_ASSIGN_OR_RETURN(
             Nonnull<const Value*> type,
             TypeCheckTypeExp(&arg.expression(), impl_scope));
-        fields.push_back({.name = arg.name(), .value = type});
+        fields.push_back({arg.name(), type});
       }
       struct_type.set_static_type(arena_->New<TypeType>());
       struct_type.set_value_category(ValueCategory::Let);
@@ -5303,7 +5193,7 @@ auto TypeChecker::DeclareChoiceDeclaration(Nonnull<ChoiceDeclaration*> choice,
     CARBON_ASSIGN_OR_RETURN(auto signature,
                             TypeCheckTypeExp(&alternative->signature(),
                                              *scope_info.innermost_scope));
-    alternatives.push_back({.name = alternative->name(), .value = signature});
+    alternatives.push_back({alternative->name(), signature});
   }
   choice->set_members(alternatives);
   if (choice->type_params().has_value()) {

+ 1 - 0
explorer/interpreter/type_checker.h

@@ -95,6 +95,7 @@ class TypeChecker {
  private:
   class ConstraintTypeBuilder;
   class SubstitutedGenericBindings;
+  class SubstituteTransform;
   class ArgumentDeduction;
 
   // Information about the currently enclosing scopes.

+ 260 - 46
explorer/interpreter/value.h

@@ -26,6 +26,16 @@ namespace Carbon {
 class Action;
 class AssociatedConstant;
 
+// A trait type that describes how to allocate an instance of `T` in an arena.
+// Returns the created object, which is not required to be of type `T`.
+template <typename T>
+struct AllocateTrait {
+  template <typename... Args>
+  static auto New(Nonnull<Arena*> arena, Args&&... args) -> Nonnull<const T*> {
+    return arena->New<T>(std::forward<Args>(args)...);
+  }
+};
+
 // Abstract base class of all AST nodes representing values.
 //
 // Value and its derived classes support LLVM-style RTTI, including
@@ -37,56 +47,18 @@ class AssociatedConstant;
 class Value {
  public:
   enum class Kind {
-    IntValue,
-    FunctionValue,
-    DestructorValue,
-    BoundMethodValue,
-    PointerValue,
-    LValue,
-    BoolValue,
-    StructValue,
-    NominalClassValue,
-    AlternativeValue,
-    TupleValue,
-    UninitializedValue,
-    ImplWitness,
-    BindingWitness,
-    ConstraintWitness,
-    ConstraintImplWitness,
-    IntType,
-    BoolType,
-    TypeType,
-    FunctionType,
-    PointerType,
-    AutoType,
-    StructType,
-    NominalClassType,
-    TupleType,
-    MixinPseudoType,
-    InterfaceType,
-    NamedConstraintType,
-    ConstraintType,
-    ChoiceType,
-    ContinuationType,  // The type of a continuation.
-    VariableType,      // e.g., generic type parameters.
-    AssociatedConstant,
-    ParameterizedEntityName,
-    MemberName,
-    BindingPlaceholderValue,
-    AddrValue,
-    AlternativeConstructorValue,
-    ContinuationValue,  // A first-class continuation value.
-    StringType,
-    StringValue,
-    TypeOfMixinPseudoType,
-    TypeOfParameterizedEntityName,
-    TypeOfMemberName,
-    StaticArrayType,
+#define CARBON_VALUE_KIND(kind) kind,
+#include "explorer/interpreter/value_kinds.def"
   };
 
   Value(const Value&) = delete;
   auto operator=(const Value&) -> Value& = delete;
 
+  // Call `f` on this value, cast to its most-derived type. `R` specifies the
+  // expected return type of `f`.
+  template <typename R, typename F>
+  auto Visit(F f) const -> R;
+
   void Print(llvm::raw_ostream& out) const;
   LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
 
@@ -155,6 +127,11 @@ class IntValue : public Value {
     return value->kind() == Kind::IntValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(value_);
+  }
+
   auto value() const -> int { return value_; }
 
  private:
@@ -177,6 +154,11 @@ class FunctionValue : public Value {
     return value->kind() == Kind::FunctionValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_);
+  }
+
   auto declaration() const -> const FunctionDeclaration& {
     return *declaration_;
   }
@@ -204,6 +186,11 @@ class DestructorValue : public Value {
     return value->kind() == Kind::DestructorValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_);
+  }
+
   auto declaration() const -> const DestructorDeclaration& {
     return *declaration_;
   }
@@ -233,6 +220,11 @@ class BoundMethodValue : public Value {
     return value->kind() == Kind::BoundMethodValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, receiver_, bindings_);
+  }
+
   auto declaration() const -> const FunctionDeclaration& {
     return *declaration_;
   }
@@ -263,6 +255,11 @@ class LValue : public Value {
     return value->kind() == Kind::LValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(value_);
+  }
+
   auto address() const -> const Address& { return value_; }
 
  private:
@@ -279,6 +276,11 @@ class PointerValue : public Value {
     return value->kind() == Kind::PointerValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(value_);
+  }
+
   auto address() const -> const Address& { return value_; }
 
  private:
@@ -294,6 +296,11 @@ class BoolValue : public Value {
     return value->kind() == Kind::BoolValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(value_);
+  }
+
   auto value() const -> bool { return value_; }
 
  private:
@@ -312,6 +319,11 @@ class StructValue : public Value {
     return value->kind() == Kind::StructValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(elements_);
+  }
+
   auto elements() const -> llvm::ArrayRef<NamedValue> { return elements_; }
 
   // Returns the value of the field named `name` in this struct, or
@@ -339,6 +351,11 @@ class NominalClassValue : public Value {
     return value->kind() == Kind::NominalClassValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(type_, inits_, base_);
+  }
+
   auto type() const -> const Value& { return *type_; }
   auto inits() const -> const Value& { return *inits_; }
   auto base() const -> std::optional<Nonnull<const NominalClassValue*>> {
@@ -364,6 +381,11 @@ class AlternativeConstructorValue : public Value {
     return value->kind() == Kind::AlternativeConstructorValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(alt_name_, choice_name_);
+  }
+
   auto alt_name() const -> const std::string& { return alt_name_; }
   auto choice_name() const -> const std::string& { return choice_name_; }
 
@@ -386,6 +408,11 @@ class AlternativeValue : public Value {
     return value->kind() == Kind::AlternativeValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(alt_name_, choice_name_, argument_);
+  }
+
   auto alt_name() const -> const std::string& { return alt_name_; }
   auto choice_name() const -> const std::string& { return choice_name_; }
   auto argument() const -> const Value& { return *argument_; }
@@ -414,6 +441,11 @@ class TupleValueBase : public Value {
            value->kind() == Kind::TupleType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(elements_);
+  }
+
  private:
   std::vector<Nonnull<const Value*>> elements_;
 };
@@ -470,6 +502,11 @@ class BindingPlaceholderValue : public Value {
     return value->kind() == Kind::BindingPlaceholderValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return value_node_ ? f(*value_node_) : f();
+  }
+
   auto value_node() const -> const std::optional<ValueNodeView>& {
     return value_node_;
   }
@@ -488,6 +525,11 @@ class AddrValue : public Value {
     return value->kind() == Kind::AddrValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(pattern_);
+  }
+
   auto pattern() const -> const Value& { return *pattern_; }
 
  private:
@@ -504,6 +546,11 @@ class UninitializedValue : public Value {
     return value->kind() == Kind::UninitializedValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(pattern_);
+  }
+
   auto pattern() const -> const Value& { return *pattern_; }
 
  private:
@@ -518,6 +565,11 @@ class IntType : public Value {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::IntType;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f();
+  }
 };
 
 // The bool type.
@@ -528,6 +580,11 @@ class BoolType : public Value {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::BoolType;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f();
+  }
 };
 
 // A type type.
@@ -538,6 +595,11 @@ class TypeType : public Value {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::TypeType;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f();
+  }
 };
 
 // A function type.
@@ -571,6 +633,12 @@ class FunctionType : public Value {
     return value->kind() == Kind::FunctionType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(parameters_, generic_parameters_, return_type_, deduced_bindings_,
+             impl_bindings_);
+  }
+
   // The type of the function parameter tuple.
   auto parameters() const -> const Value& { return *parameters_; }
   // Parameters that use a generic `:!` binding at the top level.
@@ -609,6 +677,11 @@ class PointerType : public Value {
     return value->kind() == Kind::PointerType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(type_);
+  }
+
   auto type() const -> const Value& { return *type_; }
 
  private:
@@ -623,6 +696,11 @@ class AutoType : public Value {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::AutoType;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f();
+  }
 };
 
 // A struct type.
@@ -637,6 +715,11 @@ class StructType : public Value {
     return value->kind() == Kind::StructType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(fields_);
+  }
+
   auto fields() const -> llvm::ArrayRef<NamedValue> { return fields_; }
 
  private:
@@ -644,7 +727,6 @@ class StructType : public Value {
 };
 
 // A class type.
-// TODO: Consider splitting this class into several classes.
 class NominalClassType : public Value {
  public:
   // Construct a non-generic class type.
@@ -669,6 +751,11 @@ class NominalClassType : public Value {
     return value->kind() == Kind::NominalClassType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_, base_);
+  }
+
   auto declaration() const -> const ClassDeclaration& { return *declaration_; }
 
   auto bindings() const -> const Bindings& { return *bindings_; }
@@ -713,6 +800,11 @@ class MixinPseudoType : public Value {
     return value->kind() == Kind::MixinPseudoType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_);
+  }
+
   auto declaration() const -> const MixinDeclaration& { return *declaration_; }
 
   auto bindings() const -> const Bindings& { return *bindings_; }
@@ -766,6 +858,11 @@ class InterfaceType : public Value {
     return value->kind() == Kind::InterfaceType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_);
+  }
+
   auto declaration() const -> const InterfaceDeclaration& {
     return *declaration_;
   }
@@ -797,6 +894,11 @@ class NamedConstraintType : public Value {
     return value->kind() == Kind::NamedConstraintType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_);
+  }
+
   auto declaration() const -> const ConstraintDeclaration& {
     return *declaration_;
   }
@@ -885,6 +987,12 @@ class ConstraintType : public Value {
     return value->kind() == Kind::ConstraintType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(self_binding_, impl_constraints_, equality_constraints_,
+             rewrite_constraints_, lookup_contexts_);
+  }
+
   auto self_binding() const -> Nonnull<const GenericBinding*> {
     return self_binding_;
   }
@@ -951,6 +1059,12 @@ class ImplWitness : public Witness {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::ImplWitness;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_);
+  }
+
   auto declaration() const -> const ImplDeclaration& { return *declaration_; }
 
   auto bindings() const -> const Bindings& { return *bindings_; }
@@ -977,6 +1091,11 @@ class BindingWitness : public Witness {
     return value->kind() == Kind::BindingWitness;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(binding_);
+  }
+
   auto binding() const -> Nonnull<const ImplBinding*> { return binding_; }
 
  private:
@@ -994,6 +1113,11 @@ class ConstraintWitness : public Witness {
     return value->kind() == Kind::ConstraintWitness;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(witnesses_);
+  }
+
   auto witnesses() const -> llvm::ArrayRef<Nonnull<const Witness*>> {
     return witnesses_;
   }
@@ -1033,6 +1157,11 @@ class ConstraintImplWitness : public Witness {
     return value->kind() == Kind::ConstraintImplWitness;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(constraint_witness_, index_);
+  }
+
   // Get the witness for the complete `ConstraintType`.
   auto constraint_witness() const -> Nonnull<const Witness*> {
     return constraint_witness_;
@@ -1046,6 +1175,16 @@ class ConstraintImplWitness : public Witness {
   int index_;
 };
 
+// Allocate a `ConstraintImplWitness` using the custom `Make` function.
+template <>
+struct AllocateTrait<ConstraintImplWitness> {
+  template <typename... Args>
+  static auto New(Nonnull<Arena*> arena, Args&&... args)
+      -> Nonnull<const Witness*> {
+    return ConstraintImplWitness::Make(arena, std::forward<Args>(args)...);
+  }
+};
+
 // A choice type.
 class ChoiceType : public Value {
  public:
@@ -1059,6 +1198,11 @@ class ChoiceType : public Value {
     return value->kind() == Kind::ChoiceType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, bindings_);
+  }
+
   auto name() const -> const std::string& { return declaration_->name(); }
 
   // Returns the parameter types of the alternative with the given name,
@@ -1089,6 +1233,11 @@ class ContinuationType : public Value {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::ContinuationType;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f();
+  }
 };
 
 // A variable type.
@@ -1101,6 +1250,11 @@ class VariableType : public Value {
     return value->kind() == Kind::VariableType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(binding_);
+  }
+
   auto binding() const -> const GenericBinding& { return *binding_; }
 
  private:
@@ -1122,6 +1276,11 @@ class ParameterizedEntityName : public Value {
     return value->kind() == Kind::ParameterizedEntityName;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(declaration_, params_);
+  }
+
   auto declaration() const -> const Declaration& { return *declaration_; }
   auto params() const -> const TuplePattern& { return *params_; }
 
@@ -1152,6 +1311,11 @@ class MemberName : public Value {
     return value->kind() == Kind::MemberName;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(base_type_, interface_, member_);
+  }
+
   // Prints the member name or identifier.
   void Print(llvm::raw_ostream& out) const { member_.Print(out); }
 
@@ -1194,6 +1358,11 @@ class AssociatedConstant : public Value {
     return value->kind() == Kind::AssociatedConstant;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(base_, interface_, constant_, witness_);
+  }
+
   // The type for which we denote an associated constant.
   auto base() const -> const Value& { return *base_; }
 
@@ -1260,6 +1429,11 @@ class ContinuationValue : public Value {
     return value->kind() == Kind::ContinuationValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(stack_);
+  }
+
   // The todo stack of the suspended continuation. Note that this provides
   // mutable access, even when *this is const, because of the reference-like
   // semantics of ContinuationValue.
@@ -1277,6 +1451,11 @@ class StringType : public Value {
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::StringType;
   }
+
+  template <typename F>
+  auto Decompose(F f) const {
+    return f();
+  }
 };
 
 // A string value.
@@ -1289,6 +1468,11 @@ class StringValue : public Value {
     return value->kind() == Kind::StringValue;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(value_);
+  }
+
   auto value() const -> const std::string& { return value_; }
 
  private:
@@ -1304,6 +1488,11 @@ class TypeOfMixinPseudoType : public Value {
     return value->kind() == Kind::TypeOfMixinPseudoType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(mixin_type_);
+  }
+
   auto mixin_type() const -> const MixinPseudoType& { return *mixin_type_; }
 
  private:
@@ -1323,6 +1512,11 @@ class TypeOfParameterizedEntityName : public Value {
     return value->kind() == Kind::TypeOfParameterizedEntityName;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(name_);
+  }
+
   auto name() const -> const ParameterizedEntityName& { return *name_; }
 
  private:
@@ -1346,6 +1540,11 @@ class TypeOfMemberName : public Value {
     return value->kind() == Kind::TypeOfMemberName;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(member_);
+  }
+
   // TODO: consider removing this or moving it elsewhere in the AST,
   // since it's arguably part of the expression value rather than its type.
   auto member() const -> NamedElement { return member_; }
@@ -1370,6 +1569,11 @@ class StaticArrayType : public Value {
     return value->kind() == Kind::StaticArrayType;
   }
 
+  template <typename F>
+  auto Decompose(F f) const {
+    return f(element_type_, size_);
+  }
+
   auto element_type() const -> const Value& { return *element_type_; }
   auto size() const -> size_t { return size_; }
 
@@ -1378,6 +1582,16 @@ class StaticArrayType : public Value {
   size_t size_;
 };
 
+template <typename R, typename F>
+auto Value::Visit(F f) const -> R {
+  switch (kind()) {
+#define CARBON_VALUE_KIND(kind) \
+  case Kind::kind:              \
+    return f(static_cast<const kind*>(this));
+#include "explorer/interpreter/value_kinds.def"
+  }
+}
+
 }  // namespace Carbon
 
 #endif  // CARBON_EXPLORER_INTERPRETER_VALUE_H_

+ 59 - 0
explorer/interpreter/value_kinds.def

@@ -0,0 +1,59 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// This .def file expands the CARBON_VALUE_KIND macro once for each kind of
+// Value. The macro should be defined as taking a single argument, which is the
+// name of the Value type.
+
+#ifndef CARBON_VALUE_KIND
+#error #define CARBON_VALUE_KIND(kind) before including this header
+#endif
+
+CARBON_VALUE_KIND(IntValue)
+CARBON_VALUE_KIND(FunctionValue)
+CARBON_VALUE_KIND(DestructorValue)
+CARBON_VALUE_KIND(BoundMethodValue)
+CARBON_VALUE_KIND(PointerValue)
+CARBON_VALUE_KIND(LValue)
+CARBON_VALUE_KIND(BoolValue)
+CARBON_VALUE_KIND(StructValue)
+CARBON_VALUE_KIND(NominalClassValue)
+CARBON_VALUE_KIND(AlternativeValue)
+CARBON_VALUE_KIND(TupleValue)
+CARBON_VALUE_KIND(UninitializedValue)
+CARBON_VALUE_KIND(ImplWitness)
+CARBON_VALUE_KIND(BindingWitness)
+CARBON_VALUE_KIND(ConstraintWitness)
+CARBON_VALUE_KIND(ConstraintImplWitness)
+CARBON_VALUE_KIND(IntType)
+CARBON_VALUE_KIND(BoolType)
+CARBON_VALUE_KIND(TypeType)
+CARBON_VALUE_KIND(FunctionType)
+CARBON_VALUE_KIND(PointerType)
+CARBON_VALUE_KIND(AutoType)
+CARBON_VALUE_KIND(StructType)
+CARBON_VALUE_KIND(NominalClassType)
+CARBON_VALUE_KIND(TupleType)
+CARBON_VALUE_KIND(MixinPseudoType)
+CARBON_VALUE_KIND(InterfaceType)
+CARBON_VALUE_KIND(NamedConstraintType)
+CARBON_VALUE_KIND(ConstraintType)
+CARBON_VALUE_KIND(ChoiceType)
+CARBON_VALUE_KIND(ContinuationType)
+CARBON_VALUE_KIND(VariableType)
+CARBON_VALUE_KIND(AssociatedConstant)
+CARBON_VALUE_KIND(ParameterizedEntityName)
+CARBON_VALUE_KIND(MemberName)
+CARBON_VALUE_KIND(BindingPlaceholderValue)
+CARBON_VALUE_KIND(AddrValue)
+CARBON_VALUE_KIND(AlternativeConstructorValue)
+CARBON_VALUE_KIND(ContinuationValue)
+CARBON_VALUE_KIND(StringType)
+CARBON_VALUE_KIND(StringValue)
+CARBON_VALUE_KIND(TypeOfMixinPseudoType)
+CARBON_VALUE_KIND(TypeOfParameterizedEntityName)
+CARBON_VALUE_KIND(TypeOfMemberName)
+CARBON_VALUE_KIND(StaticArrayType)
+
+#undef CARBON_VALUE_KIND

+ 166 - 0
explorer/interpreter/value_transform.h

@@ -0,0 +1,166 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CARBON_EXPLORER_INTERPRETER_VALUE_TRANSFORM_H_
+#define CARBON_EXPLORER_INTERPRETER_VALUE_TRANSFORM_H_
+
+#include "explorer/interpreter/value.h"
+
+namespace Carbon {
+
+// A no-op visitor used to implement `IsRecursivelyTransformable`. The
+// `operator()` function returns `true_type` if it's called with arguments that
+// can be used to construct `T`, and `false_type` otherwise.
+template <typename T>
+struct IsRecursivelyTransformableVisitor {
+  template <typename... Args>
+  auto operator()(Args&&... args)
+      -> std::integral_constant<bool, std::is_constructible_v<T, Args...>>;
+};
+
+// A type trait that indicates whether `T` is transformable. A transformable
+// type provides a function
+//
+// template<typename F> void Decompose(F f) const;
+//
+// that takes a callable `f` and passes it an argument list that can be passed
+// to the constructor of `T` to create an equivalent value.
+template <typename T, typename = std::true_type>
+constexpr bool IsRecursivelyTransformable = false;
+template <typename T>
+constexpr bool IsRecursivelyTransformable<
+    T, decltype(std::declval<const T>().Decompose(
+           IsRecursivelyTransformableVisitor<T>{}))> = true;
+
+// Base class for transforms of visitable data types.
+template <typename Derived>
+class TransformBase {
+ public:
+  TransformBase(Nonnull<Arena*> arena) : arena_(arena) {}
+
+  template <typename T>
+  auto Transform(T&& v) -> decltype(auto) {
+    return static_cast<Derived&>(*this)(std::forward<T>(v));
+  }
+
+  // Transformable values are recursively transformed by default.
+  template <typename T,
+            std::enable_if_t<IsRecursivelyTransformable<T>, void*> = nullptr>
+  auto operator()(const T& value) -> T {
+    return value.Decompose([&](auto&&... elements) {
+      return T{Transform(decltype(elements)(elements))...};
+    });
+  }
+
+  // Transformable pointers are recursively transformed and reallocated by
+  // default.
+  template <typename T,
+            std::enable_if_t<IsRecursivelyTransformable<T>, void*> = nullptr>
+  auto operator()(Nonnull<const T*> value) -> auto{
+    return value->Decompose([&](auto&&... elements) {
+      return AllocateTrait<T>::New(arena_,
+                                   Transform(decltype(elements)(elements))...);
+    });
+  }
+
+  // Fundamental types like `int` are assumed to not need transformation.
+  template <typename T>
+  auto operator()(const T& v) -> std::enable_if_t<std::is_fundamental_v<T>, T> {
+    return v;
+  }
+  auto operator()(const std::string& str) -> const std::string& { return str; }
+
+  // Transform `optional<T>` by transforming the `T` if it's present.
+  template <typename T>
+  auto operator()(const std::optional<T>& v) -> std::optional<T> {
+    if (!v) {
+      return std::nullopt;
+    }
+    return Transform(*v);
+  }
+
+  // Transform `vector<T>` by transforming its elements.
+  template <typename T>
+  auto operator()(const std::vector<T>& vec) -> std::vector<T> {
+    std::vector<T> result;
+    result.reserve(vec.size());
+    for (auto& value : vec) {
+      result.push_back(Transform(value));
+    }
+    return result;
+  }
+
+  // Transform `map<T, U>` by transforming its keys and values.
+  template <typename T, typename U>
+  auto operator()(const std::map<T, U>& map) -> std::map<T, U> {
+    std::map<T, U> result;
+    for (auto& [key, value] : map) {
+      result.insert({Transform(key), Transform(value)});
+    }
+    return result;
+  }
+
+ private:
+  Nonnull<Arena*> arena_;
+};
+
+// Base class for transforms of `Value`s.
+template <typename Derived>
+class ValueTransform : public TransformBase<Derived> {
+ public:
+  using TransformBase<Derived>::TransformBase;
+  using TransformBase<Derived>::operator();
+
+  // Leave references to AST nodes alone by default.
+  // The 'int = 0' parameter avoids this function hiding the `operator()(const
+  // T*)` in the base class. We can remove this once we start using a compiler
+  // that implements P1787R6.
+  template <typename NodeT>
+  auto operator()(Nonnull<const NodeT*> node, int = 0)
+      -> std::enable_if_t<std::is_base_of_v<AstNode, NodeT>,
+                          Nonnull<const NodeT*>> {
+    return node;
+  }
+
+  auto operator()(Nonnull<ContinuationValue::StackFragment*> stack_fragment)
+      -> Nonnull<ContinuationValue::StackFragment*> {
+    return stack_fragment;
+  }
+
+  auto operator()(Address addr) -> Address { return addr; }
+
+  auto operator()(ValueNodeView value_node) -> ValueNodeView {
+    return value_node;
+  }
+
+  // For a type that provides a `Visit` function to visit the most-derived
+  // object, visit and transform that most-derived object.
+  template <typename R, typename T>
+  auto TransformDerived(Nonnull<const T*> value) -> R {
+    return value->template Visit<R>([&](const auto* derived_value) {
+      using DerivedType = std::remove_pointer_t<decltype(derived_value)>;
+      static_assert(IsRecursivelyTransformable<DerivedType>);
+      return this->Transform(derived_value);
+    });
+  }
+
+  // For values, dispatch on the value kind and recursively transform.
+  auto operator()(Nonnull<const Value*> value) -> Nonnull<const Value*> {
+    return TransformDerived<Nonnull<const Value*>>(value);
+  }
+
+  // Provide a more precise type from transforming a `Witness`.
+  auto operator()(Nonnull<const Witness*> value) -> Nonnull<const Witness*> {
+    return llvm::cast<Witness>(this->Transform(llvm::cast<Value>(value)));
+  }
+
+  // For elements, dispatch on the element kind and recursively transform.
+  auto operator()(Nonnull<const Element*> elem) -> Nonnull<const Element*> {
+    return TransformDerived<Nonnull<const Element*>>(elem);
+  }
+};
+
+}  // namespace Carbon
+
+#endif  // CARBON_EXPLORER_INTERPRETER_VALUE_TRANSFORM_H_