Quellcode durchsuchen

Support explicit generic parameters in function parameter lists (#1259)

Richard Smith vor 4 Jahren
Ursprung
Commit
27c8d1fc12

+ 15 - 10
explorer/interpreter/interpreter.cpp

@@ -580,26 +580,31 @@ auto Interpreter::CallFunction(const CallExpression& call,
     case Value::Kind::FunctionValue: {
       const FunctionValue& fun_val = cast<FunctionValue>(*fun);
       const FunctionDeclaration& function = fun_val.declaration();
-      CARBON_ASSIGN_OR_RETURN(
-          Nonnull<const Value*> converted_args,
-          Convert(arg, &function.param_pattern().static_type(),
-                  call.source_loc()));
-      RuntimeScope function_scope(&heap_);
+      RuntimeScope binding_scope(&heap_);
       // Bring the class type arguments into scope.
       for (const auto& [bind, val] : fun_val.type_args()) {
-        function_scope.Initialize(bind, val);
+        binding_scope.Initialize(bind, val);
       }
       // Bring the deduced type arguments into scope.
       for (const auto& [bind, val] : call.deduced_args()) {
-        function_scope.Initialize(bind, val);
+        binding_scope.Initialize(bind, val);
       }
       // Bring the impl witness tables into scope.
       for (const auto& [impl_bind, witness] : witnesses) {
-        function_scope.Initialize(impl_bind, witness);
+        binding_scope.Initialize(impl_bind, witness);
       }
       for (const auto& [impl_bind, witness] : fun_val.witnesses()) {
-        function_scope.Initialize(impl_bind, witness);
+        binding_scope.Initialize(impl_bind, witness);
       }
+      // Enter the binding scope to make any deduced arguments visible before
+      // we resolve the parameter type.
+      todo_.CurrentAction().StartScope(std::move(binding_scope));
+      CARBON_ASSIGN_OR_RETURN(
+          Nonnull<const Value*> converted_args,
+          Convert(arg, &function.param_pattern().static_type(),
+                  call.source_loc()));
+
+      RuntimeScope function_scope(&heap_);
       BindingMap generic_args;
       CARBON_CHECK(PatternMatch(&function.param_pattern().value(),
                                 converted_args, call.source_loc(),
@@ -921,7 +926,7 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
         //    { { rt :: fn pt -> [] :: C, E, F} :: S, H}
         // -> { fn pt -> rt :: {C, E, F} :: S, H}
         return todo_.FinishAction(arena_->New<FunctionType>(
-            llvm::None, act.results()[0], act.results()[1], llvm::None,
+            act.results()[0], llvm::None, act.results()[1], llvm::None,
             llvm::None));
       }
     }

+ 71 - 21
explorer/interpreter/type_checker.cpp

@@ -562,7 +562,7 @@ auto TypeChecker::Substitute(
       const auto& fn_type = cast<FunctionType>(*type);
       auto param = Substitute(dict, &fn_type.parameters());
       auto ret = Substitute(dict, &fn_type.return_type());
-      return arena_->New<FunctionType>(llvm::None, param, ret, llvm::None,
+      return arena_->New<FunctionType>(param, llvm::None, ret, llvm::None,
                                        llvm::None);
     }
     case Value::Kind::PointerType: {
@@ -855,7 +855,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                    << " does not have a field named " << access.field();
           }
           access.set_static_type(arena_->New<FunctionType>(
-              llvm::None, *parameter_types, &aggregate_type, llvm::None,
+              *parameter_types, llvm::None, &aggregate_type, llvm::None,
               llvm::None));
           access.set_value_category(ValueCategory::Let);
           return Success();
@@ -1088,18 +1088,57 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
         case Value::Kind::FunctionType: {
           const auto& fun_t = cast<FunctionType>(call.function().static_type());
 
-          BindingMap deduced_type_args;
-          CARBON_RETURN_IF_ERROR(ArgumentDeduction(
-              e->source_loc(), "call", fun_t.generic_bindings(),
-              deduced_type_args, &fun_t.parameters(),
-              &call.argument().static_type(),
-              /*allow_implicit_conversion=*/true));
-          call.set_deduced_args(deduced_type_args);
-          for (Nonnull<const GenericBinding*> deduced_param : fun_t.deduced()) {
+          const auto& param_tuple = cast<TupleValue>(fun_t.parameters());
+          const auto& arg_tuple = cast<TupleLiteral>(call.argument());
+          llvm::ArrayRef<FunctionType::GenericParameter> generic_params =
+              fun_t.generic_parameters();
+          if (param_tuple.elements().size() != arg_tuple.fields().size()) {
+            return CompilationError(call.source_loc())
+                   << "wrong number of arguments in function call, expected "
+                   << param_tuple.elements().size() << " but got "
+                   << arg_tuple.fields().size();
+          }
+
+          // Bindings for deduced parameters and generic parameters.
+          BindingMap generic_bindings;
+
+          // Deduce and/or convert each argument to the corresponding
+          // parameter.
+          for (size_t i = 0; i < param_tuple.elements().size(); ++i) {
+            const Value* param = param_tuple.elements()[i];
+            const Expression* arg = arg_tuple.fields()[i];
+            CARBON_RETURN_IF_ERROR(ArgumentDeduction(
+                arg->source_loc(), "call", fun_t.deduced_bindings(),
+                generic_bindings, param, &arg->static_type(),
+                /*allow_implicit_conversion=*/true));
+            // If the parameter is a `:!` binding, evaluate and collect its
+            // value for use in later parameters and in the function body.
+            if (!generic_params.empty() && generic_params.front().index == i) {
+              CARBON_ASSIGN_OR_RETURN(Nonnull<const Value*> arg_value,
+                                      InterpExp(arg, arena_, trace_stream_));
+              if (trace_stream_) {
+                **trace_stream_ << "evaluated generic parameter "
+                                << *generic_params.front().binding << " as "
+                                << *arg_value << "\n";
+              }
+              bool newly_added =
+                  generic_bindings
+                      .insert({generic_params.front().binding, arg_value})
+                      .second;
+              CARBON_CHECK(newly_added)
+                  << "generic parameter should not be deduced";
+              generic_params = generic_params.drop_front();
+            }
+          }
+          CARBON_CHECK(generic_params.empty())
+              << "did not find all generic parameters in parameter list";
+          call.set_deduced_args(generic_bindings);
+          for (Nonnull<const GenericBinding*> deduced_param :
+               fun_t.deduced_bindings()) {
             // TODO: change the following to a CHECK once the real checking
             // has been added to the type checking of function signatures.
-            if (auto it = deduced_type_args.find(deduced_param);
-                it == deduced_type_args.end()) {
+            if (auto it = generic_bindings.find(deduced_param);
+                it == generic_bindings.end()) {
               return CompilationError(e->source_loc())
                      << "could not deduce type argument for type parameter "
                      << deduced_param->name() << "\n"
@@ -1108,13 +1147,13 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
           }
 
           Nonnull<const Value*> return_type =
-              Substitute(deduced_type_args, &fun_t.return_type());
+              Substitute(generic_bindings, &fun_t.return_type());
 
           // Find impls for all the impl bindings of the function.
           ImplExpMap impls;
           CARBON_RETURN_IF_ERROR(SatisfyImpls(fun_t.impl_bindings(), impl_scope,
-                                              e->source_loc(),
-                                              deduced_type_args, impls));
+                                              e->source_loc(), generic_bindings,
+                                              impls));
           call.set_impls(impls);
           call.set_static_type(return_type);
           call.set_value_category(ValueCategory::Let);
@@ -1722,28 +1761,39 @@ auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
   }
   ImplScope function_scope;
   function_scope.AddParent(&enclosing_scope);
-  std::vector<Nonnull<const GenericBinding*>> generic_bindings;
+  std::vector<Nonnull<const GenericBinding*>> deduced_bindings;
   std::vector<Nonnull<const ImplBinding*>> impl_bindings;
   // Bring the deduced parameters into scope.
   for (Nonnull<GenericBinding*> deduced : f->deduced_parameters()) {
     CARBON_RETURN_IF_ERROR(TypeCheckPattern(
         deduced, std::nullopt, function_scope, ValueCategory::Let));
-    CollectGenericBindingsInPattern(deduced, generic_bindings);
+    CollectGenericBindingsInPattern(deduced, deduced_bindings);
     CollectImplBindingsInPattern(deduced, impl_bindings);
   }
   // Type check the receiver pattern.
   if (f->is_method()) {
     CARBON_RETURN_IF_ERROR(TypeCheckPattern(
         &f->me_pattern(), std::nullopt, function_scope, ValueCategory::Let));
-    CollectGenericBindingsInPattern(&f->me_pattern(), generic_bindings);
+    CollectGenericBindingsInPattern(&f->me_pattern(), deduced_bindings);
     CollectImplBindingsInPattern(&f->me_pattern(), impl_bindings);
   }
   // Type check the parameter pattern.
   CARBON_RETURN_IF_ERROR(TypeCheckPattern(&f->param_pattern(), std::nullopt,
                                           function_scope, ValueCategory::Let));
-  CollectGenericBindingsInPattern(&f->param_pattern(), generic_bindings);
   CollectImplBindingsInPattern(&f->param_pattern(), impl_bindings);
 
+  // Keep track of any generic parameters and nested generic bindings in the
+  // parameter pattern.
+  std::vector<FunctionType::GenericParameter> generic_parameters;
+  for (size_t i = 0; i != f->param_pattern().fields().size(); ++i) {
+    const Pattern* param_pattern = f->param_pattern().fields()[i];
+    if (auto* binding = dyn_cast<GenericBinding>(param_pattern)) {
+      generic_parameters.push_back({i, binding});
+    } else {
+      CollectGenericBindingsInPattern(param_pattern, deduced_bindings);
+    }
+  }
+
   // Evaluate the return type, if we can do so without examining the body.
   if (std::optional<Nonnull<Expression*>> return_expression =
           f->return_term().type_expression();
@@ -1776,8 +1826,8 @@ auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
   CARBON_RETURN_IF_ERROR(
       ExpectIsConcreteType(f->source_loc(), &f->return_term().static_type()));
   f->set_static_type(arena_->New<FunctionType>(
-      f->deduced_parameters(), &f->param_pattern().static_type(),
-      &f->return_term().static_type(), generic_bindings, impl_bindings));
+      &f->param_pattern().static_type(), generic_parameters,
+      &f->return_term().static_type(), deduced_bindings, impl_bindings));
   SetConstantValue(f, arena_->New<FunctionValue>(f));
 
   if (f->name() == "Main") {

+ 3 - 2
explorer/interpreter/value.cpp

@@ -280,10 +280,11 @@ void Value::Print(llvm::raw_ostream& out) const {
     case Value::Kind::FunctionType: {
       const auto& fn_type = cast<FunctionType>(*this);
       out << "fn ";
-      if (!fn_type.deduced().empty()) {
+      if (!fn_type.deduced_bindings().empty()) {
         out << "[";
         unsigned int i = 0;
-        for (Nonnull<const GenericBinding*> deduced : fn_type.deduced()) {
+        for (Nonnull<const GenericBinding*> deduced :
+             fn_type.deduced_bindings()) {
           if (i != 0) {
             out << ", ";
           }

+ 25 - 13
explorer/interpreter/value.h

@@ -425,31 +425,43 @@ class TypeType : public Value {
 // A function type.
 class FunctionType : public Value {
  public:
-  FunctionType(llvm::ArrayRef<Nonnull<const GenericBinding*>> deduced,
-               Nonnull<const Value*> parameters,
+  // An explicit function parameter that is a `:!` binding:
+  //
+  //     fn MakeEmptyVector(T:! Type) -> Vector(T);
+  struct GenericParameter {
+    size_t index;
+    Nonnull<const GenericBinding*> binding;
+  };
+
+  FunctionType(Nonnull<const Value*> parameters,
+               llvm::ArrayRef<GenericParameter> generic_parameters,
                Nonnull<const Value*> return_type,
-               llvm::ArrayRef<Nonnull<const GenericBinding*>> generic_bindings,
+               llvm::ArrayRef<Nonnull<const GenericBinding*>> deduced_bindings,
                llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings)
       : Value(Kind::FunctionType),
-        deduced_(deduced),
         parameters_(parameters),
+        generic_parameters_(generic_parameters),
         return_type_(return_type),
-        generic_bindings_(generic_bindings),
+        deduced_bindings_(deduced_bindings),
         impl_bindings_(impl_bindings) {}
 
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::FunctionType;
   }
 
-  auto deduced() const -> llvm::ArrayRef<Nonnull<const GenericBinding*>> {
-    return deduced_;
-  }
+  // The type of the function parameter tuple.
   auto parameters() const -> const Value& { return *parameters_; }
+  // Parameters that use a generic `:!` binding at the top level.
+  auto generic_parameters() const -> llvm::ArrayRef<GenericParameter> {
+    return generic_parameters_;
+  }
+  // The function return type.
   auto return_type() const -> const Value& { return *return_type_; }
-  // All generic bindings in this function's signature.
-  auto generic_bindings() const
+  // All generic bindings in this function's signature that should be deduced
+  // in a call. This excludes any generic parameters.
+  auto deduced_bindings() const
       -> llvm::ArrayRef<Nonnull<const GenericBinding*>> {
-    return generic_bindings_;
+    return deduced_bindings_;
   }
   // The bindings for the witness tables (impls) required by the
   // bounds on the type parameters of the generic function.
@@ -458,10 +470,10 @@ class FunctionType : public Value {
   }
 
  private:
-  std::vector<Nonnull<const GenericBinding*>> deduced_;
   Nonnull<const Value*> parameters_;
+  std::vector<GenericParameter> generic_parameters_;
   Nonnull<const Value*> return_type_;
-  std::vector<Nonnull<const GenericBinding*>> generic_bindings_;
+  std::vector<Nonnull<const GenericBinding*>> deduced_bindings_;
   std::vector<Nonnull<const ImplBinding*>> impl_bindings_;
 };
 

+ 1 - 1
explorer/testdata/function/fail_call_with_tuple.carbon

@@ -15,6 +15,6 @@ fn f(x: i32, y: i32) -> i32 { return x + y; }
 fn Main() -> i32 {
   var xy: (i32, i32) = (1, 2);
   // should fail to type-check
-  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/function/fail_call_with_tuple.carbon:[[@LINE+1]]: mismatch in tuple sizes, expected 2 but got 1
+  // CHECK: COMPILATION ERROR: {{.*}}/explorer/testdata/function/fail_call_with_tuple.carbon:[[@LINE+1]]: wrong number of arguments in function call, expected 2 but got 1
   return f(xy);
 }

+ 30 - 0
explorer/testdata/generic_class/convert_from_struct.carbon

@@ -0,0 +1,30 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{explorer} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{explorer} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{explorer} %s
+// CHECK: result: 5
+
+package ExplorerTest api;
+
+class Point(T:! Type) {
+  var x: T;
+  var y: T;
+}
+
+fn GetX[T:! Type](pt: Point(T)) -> T {
+  return pt.x;
+}
+fn GetY(T:! Type, pt: Point(T)) -> T {
+  return pt.y;
+}
+
+fn Main() -> i32 {
+  var p: Point(i32) = {.x = 1, .y = 2};
+  // FIXME: Should `GetX({.x = 1, .y = 2})` work? See #1251.
+  return GetX(p) + GetY(i32, {.x = 3, .y = 4});
+}