瀏覽代碼

Initial support for generic parameters that are introduced in a function parameter list (#1247)

In order for this to work, we need to always use the argument deduction code path for function calls, instead of only using it when there is a `[...]` list, which means that argument deduction now needs to support implicit conversion.

Co-authored-by: Jon Meow <jperkins@google.com>
Richard Smith 4 年之前
父節點
當前提交
c29f56e667

+ 3 - 3
explorer/ast/declaration.cpp

@@ -198,9 +198,9 @@ auto FunctionDeclaration::Create(
                << "illegal AST node in implicit parameter list";
     }
   }
-  return arena->New<FunctionDeclaration>(source_loc, name, resolved_params,
-                                         me_pattern, param_pattern, return_term,
-                                         body);
+  return arena->New<FunctionDeclaration>(source_loc, name,
+                                         std::move(resolved_params), me_pattern,
+                                         param_pattern, return_term, body);
 }
 
 void FunctionDeclaration::PrintDepth(int depth, llvm::raw_ostream& out) const {

+ 16 - 23
explorer/ast/pattern.cpp

@@ -93,40 +93,33 @@ void Pattern::PrintID(llvm::raw_ostream& out) const {
   }
 }
 
-// Equivalent to `GetBindings`, but stores its output in `bindings` instead of
-// returning it.
-static void GetBindingsImpl(
-    const Pattern& pattern,
-    std::vector<Nonnull<const BindingPattern*>>& bindings) {
+auto VisitNestedPatterns(const Pattern& pattern,
+                         llvm::function_ref<bool(const Pattern&)> visitor)
+    -> bool {
+  if (!visitor(pattern)) {
+    return false;
+  }
   switch (pattern.kind()) {
-    case PatternKind::BindingPattern:
-      bindings.push_back(&cast<BindingPattern>(pattern));
-      return;
     case PatternKind::TuplePattern:
       for (const Pattern* field : cast<TuplePattern>(pattern).fields()) {
-        GetBindingsImpl(*field, bindings);
+        if (!VisitNestedPatterns(*field, visitor)) {
+          return false;
+        }
       }
-      return;
+      return true;
     case PatternKind::AlternativePattern:
-      GetBindingsImpl(cast<AlternativePattern>(pattern).arguments(), bindings);
-      return;
+      return VisitNestedPatterns(cast<AlternativePattern>(pattern).arguments(),
+                                 visitor);
+    case PatternKind::VarPattern:
+      return VisitNestedPatterns(cast<VarPattern>(pattern).pattern(), visitor);
+    case PatternKind::BindingPattern:
     case PatternKind::AutoPattern:
     case PatternKind::ExpressionPattern:
     case PatternKind::GenericBinding:
-      return;
-    case PatternKind::VarPattern:
-      GetBindingsImpl(cast<VarPattern>(pattern).pattern(), bindings);
-      return;
+      return true;
   }
 }
 
-auto GetBindings(const Pattern& pattern)
-    -> std::vector<Nonnull<const BindingPattern*>> {
-  std::vector<Nonnull<const BindingPattern*>> result;
-  GetBindingsImpl(pattern, result);
-  return result;
-}
-
 auto PatternFromParenContents(Nonnull<Arena*> arena, SourceLocation source_loc,
                               const ParenContents<Pattern>& paren_contents)
     -> Nonnull<Pattern*> {

+ 7 - 5
explorer/ast/pattern.h

@@ -17,6 +17,7 @@
 #include "explorer/ast/value_category.h"
 #include "explorer/common/source_location.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace Carbon {
 
@@ -88,11 +89,12 @@ class Pattern : public AstNode {
   std::optional<Nonnull<const Value*>> value_;
 };
 
-class BindingPattern;
-
-// Returns all `BindingPattern`s in the AST subtree rooted at `pattern`.
-auto GetBindings(const Pattern& pattern)
-    -> std::vector<Nonnull<const BindingPattern*>>;
+// Call the given `visitor` on all patterns nested within the given pattern,
+// including `pattern` itself. Aborts and returns `false` if `visitor` returns
+// `false`, otherwise returns `true`.
+auto VisitNestedPatterns(const Pattern& pattern,
+                         llvm::function_ref<bool(const Pattern&)> visitor)
+    -> bool;
 
 // A pattern consisting of the `auto` keyword.
 class AutoPattern : public Pattern {

+ 2 - 2
explorer/interpreter/interpreter.cpp

@@ -921,8 +921,8 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
         //    { { rt :: fn pt -> [] :: C, E, F} :: S, H}
         // -> { fn pt -> rt :: {C, E, F} :: S, H}
         return todo_.FinishAction(arena_->New<FunctionType>(
-            std::vector<Nonnull<const GenericBinding*>>(), act.results()[0],
-            act.results()[1], std::vector<Nonnull<const ImplBinding*>>()));
+            llvm::None, act.results()[0], act.results()[1], llvm::None,
+            llvm::None));
       }
     }
     case ExpressionKind::ContinuationTypeLiteral: {

+ 185 - 138
explorer/interpreter/type_checker.cpp

@@ -185,19 +185,31 @@ auto TypeChecker::ExpectIsConcreteType(SourceLocation source_loc,
   }
 }
 
+// Returns the named field, or None if not found.
+static auto FindField(llvm::ArrayRef<NamedValue> fields,
+                      const std::string& field_name)
+    -> std::optional<NamedValue> {
+  auto it = std::find_if(
+      fields.begin(), fields.end(),
+      [&](const NamedValue& field) { return field.name == field_name; });
+  if (it == fields.end()) {
+    return std::nullopt;
+  }
+  return *it;
+}
+
 auto TypeChecker::FieldTypesImplicitlyConvertible(
     llvm::ArrayRef<NamedValue> source_fields,
-    llvm::ArrayRef<NamedValue> destination_fields) const {
+    llvm::ArrayRef<NamedValue> destination_fields) const -> bool {
   if (source_fields.size() != destination_fields.size()) {
     return false;
   }
   for (const auto& source_field : source_fields) {
-    auto it = std::find_if(destination_fields.begin(), destination_fields.end(),
-                           [&](const NamedValue& field) {
-                             return field.name == source_field.name;
-                           });
-    if (it == destination_fields.end() ||
-        !IsImplicitlyConvertible(source_field.value, it->value)) {
+    std::optional<NamedValue> destination_field =
+        FindField(destination_fields, source_field.name);
+    if (!destination_field.has_value() ||
+        !IsImplicitlyConvertible(source_field.value,
+                                 destination_field.value().value)) {
       return false;
     }
   }
@@ -315,10 +327,35 @@ auto TypeChecker::ExpectType(SourceLocation source_loc,
 }
 
 auto TypeChecker::ArgumentDeduction(
-    SourceLocation source_loc,
+    SourceLocation source_loc, const std::string& context,
     llvm::ArrayRef<Nonnull<const GenericBinding*>> type_params,
     BindingMap& deduced, Nonnull<const Value*> param_type,
-    Nonnull<const Value*> arg_type) const -> ErrorOr<Success> {
+    Nonnull<const Value*> arg_type, bool allow_implicit_conversion) const
+    -> ErrorOr<Success> {
+  if (trace_stream_) {
+    **trace_stream_ << "deducing " << *param_type << " from " << *arg_type
+                    << "\n";
+  }
+  // Handle the case where we can't perform deduction, either because the
+  // parameter is a primitive type or because the parameter and argument have
+  // different forms. In this case, we require an implicit conversion to exist,
+  // or for an exact type match if implicit conversions are not permitted.
+  auto handle_non_deduced_type = [&]() -> ErrorOr<Success> {
+    if (!IsConcreteType(param_type)) {
+      // Parameter type contains a nested `auto` and argument type isn't the
+      // same kind of type.
+      // FIXME: This seems like something we should be able to accept.
+      return CompilationError(source_loc) << "type error in " << context << "\n"
+                                          << "expected: " << *param_type << "\n"
+                                          << "actual: " << *arg_type;
+    }
+    const Value* subst_param_type = Substitute(deduced, param_type);
+    return allow_implicit_conversion
+               ? ExpectType(source_loc, context, subst_param_type, arg_type)
+               : ExpectExactType(source_loc, context, subst_param_type,
+                                 arg_type);
+  };
+
   switch (param_type->kind()) {
     case Value::Kind::VariableType: {
       const auto& var_type = cast<VariableType>(*param_type);
@@ -326,23 +363,18 @@ auto TypeChecker::ArgumentDeduction(
                     &var_type.binding()) != type_params.end()) {
         auto [it, success] = deduced.insert({&var_type.binding(), arg_type});
         if (!success) {
-          // Variable already has a match.
-          // TODO: can we allow implicit conversions here?
+          // All deductions are required to produce the same value.
           CARBON_RETURN_IF_ERROR(ExpectExactType(
-              source_loc, "argument deduction", it->second, arg_type));
+              source_loc, "repeated argument deduction", it->second, arg_type));
         }
       } else {
-        CARBON_RETURN_IF_ERROR(ExpectExactType(source_loc, "argument deduction",
-                                               param_type, arg_type));
+        return handle_non_deduced_type();
       }
       return Success();
     }
     case Value::Kind::TupleValue: {
       if (arg_type->kind() != Value::Kind::TupleValue) {
-        return CompilationError(source_loc)
-               << "type error in argument deduction\n"
-               << "expected: " << *param_type << "\n"
-               << "actual: " << *arg_type;
+        return handle_non_deduced_type();
       }
       const auto& param_tup = cast<TupleValue>(*param_type);
       const auto& arg_tup = cast<TupleValue>(*arg_type);
@@ -353,67 +385,91 @@ auto TypeChecker::ArgumentDeduction(
                << arg_tup.elements().size();
       }
       for (size_t i = 0; i < param_tup.elements().size(); ++i) {
-        CARBON_RETURN_IF_ERROR(
-            ArgumentDeduction(source_loc, type_params, deduced,
-                              param_tup.elements()[i], arg_tup.elements()[i]));
+        CARBON_RETURN_IF_ERROR(ArgumentDeduction(
+            source_loc, context, type_params, deduced, param_tup.elements()[i],
+            arg_tup.elements()[i], allow_implicit_conversion));
       }
       return Success();
     }
     case Value::Kind::StructType: {
       if (arg_type->kind() != Value::Kind::StructType) {
-        return CompilationError(source_loc)
-               << "type error in argument deduction\n"
-               << "expected: " << *param_type << "\n"
-               << "actual: " << *arg_type;
+        return handle_non_deduced_type();
       }
       const auto& param_struct = cast<StructType>(*param_type);
       const auto& arg_struct = cast<StructType>(*arg_type);
-      if (param_struct.fields().size() != arg_struct.fields().size()) {
+      auto diagnose_missing_field = [&](const StructType& struct_type,
+                                        const NamedValue& field,
+                                        bool missing_from_source) -> Error {
+        static constexpr const char* SourceOrDestination[2] = {"source",
+                                                               "destination"};
         return CompilationError(source_loc)
-               << "mismatch in struct field counts, expected "
-               << param_struct.fields().size() << " but got "
-               << arg_struct.fields().size();
-      }
+               << "mismatch in field names, "
+               << SourceOrDestination[missing_from_source ? 1 : 0] << " field `"
+               << field.name << "` not in "
+               << SourceOrDestination[missing_from_source ? 0 : 1] << " type `"
+               << struct_type << "`";
+      };
       for (size_t i = 0; i < param_struct.fields().size(); ++i) {
-        if (param_struct.fields()[i].name != arg_struct.fields()[i].name) {
-          return CompilationError(source_loc)
-                 << "mismatch in field names, " << param_struct.fields()[i].name
-                 << " != " << arg_struct.fields()[i].name;
+        NamedValue param_field = param_struct.fields()[i];
+        NamedValue arg_field;
+        if (allow_implicit_conversion) {
+          if (std::optional<NamedValue> maybe_arg_field =
+                  FindField(arg_struct.fields(), param_field.name)) {
+            arg_field = *maybe_arg_field;
+          } else {
+            return diagnose_missing_field(arg_struct, param_field, true);
+          }
+        } else {
+          if (i >= arg_struct.fields().size()) {
+            return diagnose_missing_field(arg_struct, param_field, true);
+          }
+          arg_field = arg_struct.fields()[i];
+          if (param_field.name != arg_field.name) {
+            return CompilationError(source_loc)
+                   << "mismatch in field names, `" << param_field.name
+                   << "` != `" << arg_field.name << "`";
+          }
         }
         CARBON_RETURN_IF_ERROR(ArgumentDeduction(
-            source_loc, type_params, deduced, param_struct.fields()[i].value,
-            arg_struct.fields()[i].value));
+            source_loc, context, type_params, deduced, param_field.value,
+            arg_field.value, allow_implicit_conversion));
+      }
+      if (param_struct.fields().size() != arg_struct.fields().size()) {
+        CARBON_CHECK(allow_implicit_conversion)
+            << "should have caught this earlier";
+        for (const NamedValue& arg_field : arg_struct.fields()) {
+          if (!FindField(param_struct.fields(), arg_field.name).has_value()) {
+            return diagnose_missing_field(param_struct, arg_field, false);
+          }
+        }
+        CARBON_FATAL() << "field count mismatch but no missing field; "
+                       << "duplicate field name?";
       }
       return Success();
     }
     case Value::Kind::FunctionType: {
       if (arg_type->kind() != Value::Kind::FunctionType) {
-        return CompilationError(source_loc)
-               << "type error in argument deduction\n"
-               << "expected: " << *param_type << "\n"
-               << "actual: " << *arg_type;
+        return handle_non_deduced_type();
       }
       const auto& param_fn = cast<FunctionType>(*param_type);
       const auto& arg_fn = cast<FunctionType>(*arg_type);
       // TODO: handle situation when arg has deduced parameters.
-      CARBON_RETURN_IF_ERROR(ArgumentDeduction(source_loc, type_params, deduced,
-                                               &param_fn.parameters(),
-                                               &arg_fn.parameters()));
-      CARBON_RETURN_IF_ERROR(ArgumentDeduction(source_loc, type_params, deduced,
-                                               &param_fn.return_type(),
-                                               &arg_fn.return_type()));
+      CARBON_RETURN_IF_ERROR(ArgumentDeduction(
+          source_loc, context, type_params, deduced, &param_fn.parameters(),
+          &arg_fn.parameters(), /*allow_implicit_conversion=*/false));
+      CARBON_RETURN_IF_ERROR(ArgumentDeduction(
+          source_loc, context, type_params, deduced, &param_fn.return_type(),
+          &arg_fn.return_type(), /*allow_implicit_conversion=*/false));
       return Success();
     }
     case Value::Kind::PointerType: {
       if (arg_type->kind() != Value::Kind::PointerType) {
-        return CompilationError(source_loc)
-               << "type error in argument deduction\n"
-               << "expected: " << *param_type << "\n"
-               << "actual: " << *arg_type;
+        return handle_non_deduced_type();
       }
-      return ArgumentDeduction(source_loc, type_params, deduced,
+      return ArgumentDeduction(source_loc, context, type_params, deduced,
                                &cast<PointerType>(*param_type).type(),
-                               &cast<PointerType>(*arg_type).type());
+                               &cast<PointerType>(*arg_type).type(),
+                               /*allow_implicit_conversion=*/false);
     }
     // Nothing to do in the case for `auto`.
     case Value::Kind::AutoType: {
@@ -421,25 +477,27 @@ auto TypeChecker::ArgumentDeduction(
     }
     case Value::Kind::NominalClassType: {
       const auto& param_class_type = cast<NominalClassType>(*param_type);
-      if (arg_type->kind() == Value::Kind::NominalClassType) {
-        const auto& arg_class_type = cast<NominalClassType>(*arg_type);
-        if (param_class_type.declaration().name() ==
-            arg_class_type.declaration().name()) {
-          for (const auto& [ty, param_ty] : param_class_type.type_args()) {
-            CARBON_RETURN_IF_ERROR(
-                ArgumentDeduction(source_loc, type_params, deduced, param_ty,
-                                  arg_class_type.type_args().at(ty)));
-          }
-          return Success();
-        }
+      if (arg_type->kind() != Value::Kind::NominalClassType) {
+        // FIXME: We could determine the parameters of the class from field
+        // types in a struct argument.
+        return handle_non_deduced_type();
+      }
+      const auto& arg_class_type = cast<NominalClassType>(*arg_type);
+      if (param_class_type.declaration().name() !=
+          arg_class_type.declaration().name()) {
+        return handle_non_deduced_type();
       }
-      return CompilationError(source_loc)
-             << "type error in argument deduction\n"
-             << "expected: " << *param_type << "\n"
-             << "actual: " << *arg_type;
+      for (const auto& [ty, param_ty] : param_class_type.type_args()) {
+        CARBON_RETURN_IF_ERROR(
+            ArgumentDeduction(source_loc, context, type_params, deduced,
+                              param_ty, arg_class_type.type_args().at(ty),
+                              /*allow_implicit_conversion=*/false));
+      }
+      return Success();
     }
-    // For the following cases, we check for type convertability.
+    // For the following cases, we check the type matches.
     case Value::Kind::StaticArrayType:
+      // FIXME: We could deduce the array type from an array or tuple argument.
     case Value::Kind::ContinuationType:
     case Value::Kind::InterfaceType:
     case Value::Kind::ChoiceType:
@@ -451,7 +509,7 @@ auto TypeChecker::ArgumentDeduction(
     case Value::Kind::TypeOfInterfaceType:
     case Value::Kind::TypeOfChoiceType:
     case Value::Kind::TypeOfParameterizedEntityName:
-      return ExpectType(source_loc, "argument deduction", param_type, arg_type);
+      return handle_non_deduced_type();
     // The rest of these cases should never happen.
     case Value::Kind::Witness:
     case Value::Kind::ParameterizedEntityName:
@@ -504,9 +562,8 @@ auto TypeChecker::Substitute(
       const auto& fn_type = cast<FunctionType>(*type);
       auto param = Substitute(dict, &fn_type.parameters());
       auto ret = Substitute(dict, &fn_type.return_type());
-      return arena_->New<FunctionType>(
-          std::vector<Nonnull<const GenericBinding*>>(), param, ret,
-          std::vector<Nonnull<const ImplBinding*>>());
+      return arena_->New<FunctionType>(llvm::None, param, ret, llvm::None,
+                                       llvm::None);
     }
     case Value::Kind::PointerType: {
       return arena_->New<PointerType>(
@@ -602,7 +659,8 @@ auto TypeChecker::MatchImpl(const InterfaceType& iface,
     // case: impl is a generic impl.
     BindingMap deduced_type_args;
     ErrorOr<Success> e = ArgumentDeduction(
-        source_loc, impl.deduced, deduced_type_args, impl.type, impl_type);
+        source_loc, "match", impl.deduced, deduced_type_args, impl.type,
+        impl_type, /*allow_implicit_conversion=*/true);
     if (trace_stream_) {
       **trace_stream_ << "match results: {";
       llvm::ListSeparator sep;
@@ -797,8 +855,8 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                    << " does not have a field named " << access.field();
           }
           access.set_static_type(arena_->New<FunctionType>(
-              std::vector<Nonnull<const GenericBinding*>>(), *parameter_types,
-              &aggregate_type, std::vector<Nonnull<const ImplBinding*>>()));
+              llvm::None, *parameter_types, &aggregate_type, llvm::None,
+              llvm::None));
           access.set_value_category(ValueCategory::Let);
           return Success();
         }
@@ -1029,41 +1087,35 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
       switch (call.function().static_type().kind()) {
         case Value::Kind::FunctionType: {
           const auto& fun_t = cast<FunctionType>(call.function().static_type());
-          Nonnull<const Value*> parameters = &fun_t.parameters();
-          Nonnull<const Value*> return_type = &fun_t.return_type();
-          if (!fun_t.deduced().empty()) {
-            BindingMap deduced_type_args;
-            CARBON_RETURN_IF_ERROR(ArgumentDeduction(
-                e->source_loc(), fun_t.deduced(), deduced_type_args, parameters,
-                &call.argument().static_type()));
-            call.set_deduced_args(deduced_type_args);
-            for (Nonnull<const GenericBinding*> deduced_param :
-                 fun_t.deduced()) {
-              // TODO: change the following to a CHECK once the real checking
-              // has been added to the type checking of function signatures.
-              if (auto it = deduced_type_args.find(deduced_param);
-                  it == deduced_type_args.end()) {
-                return CompilationError(e->source_loc())
-                       << "could not deduce type argument for type parameter "
-                       << deduced_param->name() << "\n"
-                       << "in " << call;
-              }
+
+          BindingMap deduced_type_args;
+          CARBON_RETURN_IF_ERROR(ArgumentDeduction(
+              e->source_loc(), "call", fun_t.generic_bindings(),
+              deduced_type_args, &fun_t.parameters(),
+              &call.argument().static_type(),
+              /*allow_implicit_conversion=*/true));
+          call.set_deduced_args(deduced_type_args);
+          for (Nonnull<const GenericBinding*> deduced_param : fun_t.deduced()) {
+            // TODO: change the following to a CHECK once the real checking
+            // has been added to the type checking of function signatures.
+            if (auto it = deduced_type_args.find(deduced_param);
+                it == deduced_type_args.end()) {
+              return CompilationError(e->source_loc())
+                     << "could not deduce type argument for type parameter "
+                     << deduced_param->name() << "\n"
+                     << "in " << call;
             }
-            parameters = Substitute(deduced_type_args, parameters);
-            return_type = Substitute(deduced_type_args, return_type);
-            // Find impls for all the impl bindings of the function.
-            ImplExpMap impls;
-            CARBON_RETURN_IF_ERROR(SatisfyImpls(fun_t.impl_bindings(),
-                                                impl_scope, e->source_loc(),
-                                                deduced_type_args, impls));
-            call.set_impls(impls);
-          } else {
-            // No deduced parameters. Check that the argument types
-            // are convertible to the parameter types.
-            CARBON_RETURN_IF_ERROR(ExpectType(e->source_loc(), "call",
-                                              parameters,
-                                              &call.argument().static_type()));
           }
+
+          Nonnull<const Value*> return_type =
+              Substitute(deduced_type_args, &fun_t.return_type());
+
+          // Find impls for all the impl bindings of the function.
+          ImplExpMap impls;
+          CARBON_RETURN_IF_ERROR(SatisfyImpls(fun_t.impl_bindings(), impl_scope,
+                                              e->source_loc(),
+                                              deduced_type_args, impls));
+          call.set_impls(impls);
           call.set_static_type(return_type);
           call.set_value_category(ValueCategory::Let);
           return Success();
@@ -1237,39 +1289,28 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
   }
 }
 
+void TypeChecker::CollectGenericBindingsInPattern(
+    Nonnull<const Pattern*> p,
+    std::vector<Nonnull<const GenericBinding*>>& generic_bindings) {
+  VisitNestedPatterns(*p, [&](const Pattern& pattern) {
+    if (auto* binding = dyn_cast<GenericBinding>(&pattern)) {
+      generic_bindings.push_back(binding);
+    }
+    return true;
+  });
+}
+
 void TypeChecker::CollectImplBindingsInPattern(
     Nonnull<const Pattern*> p,
     std::vector<Nonnull<const ImplBinding*>>& impl_bindings) {
-  switch (p->kind()) {
-    case PatternKind::GenericBinding: {
-      auto& binding = cast<GenericBinding>(*p);
-      if (binding.impl_binding().has_value()) {
-        impl_bindings.push_back(*binding.impl_binding());
+  VisitNestedPatterns(*p, [&](const Pattern& pattern) {
+    if (auto* binding = dyn_cast<GenericBinding>(&pattern)) {
+      if (binding->impl_binding().has_value()) {
+        impl_bindings.push_back(binding->impl_binding().value());
       }
-      return;
     }
-    case PatternKind::TuplePattern: {
-      auto& tuple = cast<TuplePattern>(*p);
-      for (Nonnull<const Pattern*> field : tuple.fields()) {
-        CollectImplBindingsInPattern(field, impl_bindings);
-      }
-      return;
-    }
-    case PatternKind::AlternativePattern: {
-      auto& alternative = cast<AlternativePattern>(*p);
-      CollectImplBindingsInPattern(&alternative.arguments(), impl_bindings);
-      return;
-    }
-    case PatternKind::VarPattern: {
-      auto& var_pattern = cast<VarPattern>(*p);
-      CollectImplBindingsInPattern(&var_pattern.pattern(), impl_bindings);
-      return;
-    }
-    case PatternKind::ExpressionPattern:
-    case PatternKind::AutoPattern:
-    case PatternKind::BindingPattern:
-      return;
-  }
+    return true;
+  });
 }
 
 void TypeChecker::BringPatternImplsIntoScope(Nonnull<const Pattern*> p,
@@ -1317,7 +1358,9 @@ auto TypeChecker::TypeCheckPattern(
     }
     case PatternKind::BindingPattern: {
       auto& binding = cast<BindingPattern>(*p);
-      if (!GetBindings(binding.type()).empty()) {
+      if (!VisitNestedPatterns(binding.type(), [](const Pattern& pattern) {
+            return !isa<BindingPattern>(pattern);
+          })) {
         return CompilationError(binding.type().source_loc())
                << "The type of a binding pattern cannot contain bindings.";
       }
@@ -1679,22 +1722,26 @@ auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
   }
   ImplScope function_scope;
   function_scope.AddParent(&enclosing_scope);
+  std::vector<Nonnull<const GenericBinding*>> generic_bindings;
   std::vector<Nonnull<const ImplBinding*>> impl_bindings;
   // Bring the deduced parameters into scope.
   for (Nonnull<GenericBinding*> deduced : f->deduced_parameters()) {
     CARBON_RETURN_IF_ERROR(TypeCheckPattern(
         deduced, std::nullopt, function_scope, ValueCategory::Let));
+    CollectGenericBindingsInPattern(deduced, generic_bindings);
     CollectImplBindingsInPattern(deduced, impl_bindings);
   }
   // Type check the receiver pattern.
   if (f->is_method()) {
     CARBON_RETURN_IF_ERROR(TypeCheckPattern(
         &f->me_pattern(), std::nullopt, function_scope, ValueCategory::Let));
+    CollectGenericBindingsInPattern(&f->me_pattern(), generic_bindings);
     CollectImplBindingsInPattern(&f->me_pattern(), impl_bindings);
   }
   // Type check the parameter pattern.
   CARBON_RETURN_IF_ERROR(TypeCheckPattern(&f->param_pattern(), std::nullopt,
                                           function_scope, ValueCategory::Let));
+  CollectGenericBindingsInPattern(&f->param_pattern(), generic_bindings);
   CollectImplBindingsInPattern(&f->param_pattern(), impl_bindings);
 
   // Evaluate the return type, if we can do so without examining the body.
@@ -1730,7 +1777,7 @@ auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
       ExpectIsConcreteType(f->source_loc(), &f->return_term().static_type()));
   f->set_static_type(arena_->New<FunctionType>(
       f->deduced_parameters(), &f->param_pattern().static_type(),
-      &f->return_term().static_type(), impl_bindings));
+      &f->return_term().static_type(), generic_bindings, impl_bindings));
   SetConstantValue(f, arena_->New<FunctionValue>(f));
 
   if (f->name() == "Main") {

+ 9 - 3
explorer/interpreter/type_checker.h

@@ -37,10 +37,11 @@ class TypeChecker {
   // The `deduced` parameter is an accumulator, that is, it holds the
   // results so-far.
   auto ArgumentDeduction(
-      SourceLocation source_loc,
+      SourceLocation source_loc, const std::string& context,
       llvm::ArrayRef<Nonnull<const GenericBinding*>> type_params,
       BindingMap& deduced, Nonnull<const Value*> param_type,
-      Nonnull<const Value*> arg_type) const -> ErrorOr<Success>;
+      Nonnull<const Value*> arg_type, bool allow_implicit_conversion) const
+      -> ErrorOr<Success>;
 
   // If `impl` can be an implementation of interface `iface` for the
   // given `type`, then return an expression that will produce the witness
@@ -106,6 +107,11 @@ class TypeChecker {
                                 const ImplScope& enclosing_scope)
       -> ErrorOr<Success>;
 
+  // Find all of the GenericBindings in the given pattern.
+  void CollectGenericBindingsInPattern(
+      Nonnull<const Pattern*> p,
+      std::vector<Nonnull<const GenericBinding*>>& generic_bindings);
+
   // Find all of the ImplBindings in the given pattern. The pattern is required
   // to have already been type-checked.
   void CollectImplBindingsInPattern(
@@ -182,7 +188,7 @@ class TypeChecker {
   // must be types.
   auto FieldTypesImplicitlyConvertible(
       llvm::ArrayRef<NamedValue> source_fields,
-      llvm::ArrayRef<NamedValue> destination_fields) const;
+      llvm::ArrayRef<NamedValue> destination_fields) const -> bool;
 
   // Returns true if *source is implicitly convertible to *destination. *source
   // and *destination must be concrete types.

+ 8 - 0
explorer/interpreter/value.h

@@ -428,11 +428,13 @@ class FunctionType : public Value {
   FunctionType(llvm::ArrayRef<Nonnull<const GenericBinding*>> deduced,
                Nonnull<const Value*> parameters,
                Nonnull<const Value*> return_type,
+               llvm::ArrayRef<Nonnull<const GenericBinding*>> generic_bindings,
                llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings)
       : Value(Kind::FunctionType),
         deduced_(deduced),
         parameters_(parameters),
         return_type_(return_type),
+        generic_bindings_(generic_bindings),
         impl_bindings_(impl_bindings) {}
 
   static auto classof(const Value* value) -> bool {
@@ -444,6 +446,11 @@ class FunctionType : public Value {
   }
   auto parameters() const -> const Value& { return *parameters_; }
   auto return_type() const -> const Value& { return *return_type_; }
+  // All generic bindings in this function's signature.
+  auto generic_bindings() const
+      -> llvm::ArrayRef<Nonnull<const GenericBinding*>> {
+    return generic_bindings_;
+  }
   // The bindings for the witness tables (impls) required by the
   // bounds on the type parameters of the generic function.
   auto impl_bindings() const -> llvm::ArrayRef<Nonnull<const ImplBinding*>> {
@@ -454,6 +461,7 @@ class FunctionType : public Value {
   std::vector<Nonnull<const GenericBinding*>> deduced_;
   Nonnull<const Value*> parameters_;
   Nonnull<const Value*> return_type_;
+  std::vector<Nonnull<const GenericBinding*>> generic_bindings_;
   std::vector<Nonnull<const ImplBinding*>> impl_bindings_;
 };
 

+ 1 - 1
explorer/testdata/function/fail_call_with_tuple.carbon

@@ -15,6 +15,6 @@ fn f(x: i32, y: i32) -> i32 { return x + y; }
 fn Main() -> i32 {
   var xy: (i32, i32) = (1, 2);
   // should fail to type-check
-  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/function/fail_call_with_tuple.carbon:[[@LINE+1]]: type error in call: '((i32, i32))' is not implicitly convertible to '(i32, i32)'
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/function/fail_call_with_tuple.carbon:[[@LINE+1]]: mismatch in tuple sizes, expected 2 but got 1
   return f(xy);
 }

+ 1 - 1
explorer/testdata/generic_class/fail_argument_deduction.carbon

@@ -22,7 +22,7 @@ fn FirstOfTwoPoints[T:! Type](a: Point(T), b: Point(T)) -> Point(T) {
 fn Main() -> i32 {
   var p: Point(i32) = {.x = 0, .y = 1};
   var q: Point(Bool) = {.x = true, .y = false};
-  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/generic_class/fail_argument_deduction.carbon:[[@LINE+3]]: type error in argument deduction
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/generic_class/fail_argument_deduction.carbon:[[@LINE+3]]: type error in repeated argument deduction
   // CHECK: expected: i32
   // CHECK: actual: Bool
   return FirstOfTwoPoints(p, q).x;

+ 19 - 0
explorer/testdata/generic_function/fail_implicit_conversion_extra_field.carbon

@@ -0,0 +1,19 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{not} %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{not} %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+
+package ExplorerTest api;
+
+fn Bad[T:! Type](x: {.a: i32, .b: T}) {}
+
+fn Main() -> i32 {
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/generic_function/fail_implicit_conversion_extra_field.carbon:[[@LINE+1]]: mismatch in field names, source field `c` not in destination type `{.a: i32, .b: T:! Type}`
+  Bad({.b = 5, .a = 7, .c = 2});
+  return 0;
+}

+ 19 - 0
explorer/testdata/generic_function/fail_implicit_conversion_missing_field.carbon

@@ -0,0 +1,19 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{not} %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{not} %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+
+package ExplorerTest api;
+
+fn Bad[T:! Type](x: {.a: i32, .b: T}) {}
+
+fn Main() -> i32 {
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/generic_function/fail_implicit_conversion_missing_field.carbon:[[@LINE+1]]: mismatch in field names, destination field `b` not in source type `{.a: i32}`
+  Bad({.a = 5});
+  return 0;
+}

+ 1 - 1
explorer/testdata/generic_function/fail_type_deduction_mismatch.carbon

@@ -15,7 +15,7 @@ fn fst[T:! Type](x: T, y: T) -> T {
 }
 
 fn Main() -> i32 {
-  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/generic_function/fail_type_deduction_mismatch.carbon:[[@LINE+3]]: type error in argument deduction
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/generic_function/fail_type_deduction_mismatch.carbon:[[@LINE+3]]: type error in repeated argument deduction
   // CHECK: expected: i32
   // CHECK: actual: Bool
   return fst(0, true);

+ 27 - 0
explorer/testdata/generic_function/implicit_conversion.carbon

@@ -0,0 +1,27 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+// CHECK: result: 4213
+
+package ExplorerTest api;
+
+interface IntLike {
+  fn Convert[me: Self]() -> i32;
+}
+impl i32 as IntLike {
+  fn Convert[me: i32]() -> i32 { return me; }
+}
+
+fn add[T:! IntLike](x: {.a: T, .b: ({.m: i32, .n: T}, i32)}) -> i32 {
+  return 1000 * x.a.Convert() + 100 * x.b[0].m + 10 * x.b[0].n.Convert() + x.b[1];
+}
+
+fn Main() -> i32 {
+  return add({.b = ({.n = 1, .m = 2}, 3), .a = 4});
+}

+ 22 - 0
explorer/testdata/generic_function/nondeduced_generic_param.carbon

@@ -0,0 +1,22 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+// CHECK: result: 1
+
+package ExplorerTest api;
+class A(T:! Type) {
+  var v: T;
+}
+fn F(T:! Type, x: T) -> T {
+  var v: A(T) = {.v = x};
+  return v.v;
+}
+fn Main() -> i32 {
+  return F(i32, 1);
+}