Просмотр исходного кода

Use AST nodes as generic arg deduction keys (#982)

Geoff Romer 4 лет назад
Родитель
Сommit
e65e85b15d

+ 1 - 1
executable_semantics/interpreter/interpreter.cpp

@@ -105,7 +105,7 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
       for (Nonnull<const GenericBinding*> deduced :
            func_def.deduced_parameters()) {
         AllocationId a =
-            heap_.AllocateValue(arena_->New<VariableType>(deduced->name()));
+            heap_.AllocateValue(arena_->New<VariableType>(deduced));
         new_env.Set(deduced->name(), a);
       }
       Nonnull<const FunctionValue*> f = arena_->New<FunctionValue>(&func_def);

+ 37 - 38
executable_semantics/interpreter/type_checker.cpp

@@ -210,20 +210,19 @@ static void ExpectType(SourceLocation source_loc, const std::string& context,
   }
 }
 
-auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
-                                    Nonnull<const Value*> param,
-                                    Nonnull<const Value*> arg) -> TypeEnv {
+void TypeChecker::ArgumentDeduction(
+    SourceLocation source_loc,
+    std::map<Nonnull<const GenericBinding*>, Nonnull<const Value*>>& deduced,
+    Nonnull<const Value*> param, Nonnull<const Value*> arg) {
   switch (param->kind()) {
     case Value::Kind::VariableType: {
       const auto& var_type = cast<VariableType>(*param);
-      std::optional<Nonnull<const Value*>> d = deduced.Get(var_type.name());
-      if (!d) {
-        deduced.Set(var_type.name(), arg);
-      } else {
+      auto [it, success] = deduced.insert({&var_type.binding(), arg});
+      if (!success) {
         // TODO: can we allow implicit conversions here?
-        ExpectExactType(source_loc, "argument deduction", *d, arg);
+        ExpectExactType(source_loc, "argument deduction", it->second, arg);
       }
-      return deduced;
+      return;
     }
     case Value::Kind::TupleValue: {
       if (arg->kind() != Value::Kind::TupleValue) {
@@ -241,11 +240,10 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
             << arg_tup.elements().size();
       }
       for (size_t i = 0; i < param_tup.elements().size(); ++i) {
-        deduced =
-            ArgumentDeduction(source_loc, deduced, param_tup.elements()[i],
-                              arg_tup.elements()[i]);
+        ArgumentDeduction(source_loc, deduced, param_tup.elements()[i],
+                          arg_tup.elements()[i]);
       }
-      return deduced;
+      return;
     }
     case Value::Kind::StructType: {
       if (arg->kind() != Value::Kind::StructType) {
@@ -268,11 +266,10 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
               << "mismatch in field names, " << param_struct.fields()[i].name
               << " != " << arg_struct.fields()[i].name;
         }
-        deduced = ArgumentDeduction(source_loc, deduced,
-                                    param_struct.fields()[i].value,
-                                    arg_struct.fields()[i].value);
+        ArgumentDeduction(source_loc, deduced, param_struct.fields()[i].value,
+                          arg_struct.fields()[i].value);
       }
-      return deduced;
+      return;
     }
     case Value::Kind::FunctionType: {
       if (arg->kind() != Value::Kind::FunctionType) {
@@ -284,11 +281,11 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
       const auto& param_fn = cast<FunctionType>(*param);
       const auto& arg_fn = cast<FunctionType>(*arg);
       // TODO: handle situation when arg has deduced parameters.
-      deduced = ArgumentDeduction(source_loc, deduced, &param_fn.parameters(),
-                                  &arg_fn.parameters());
-      deduced = ArgumentDeduction(source_loc, deduced, &param_fn.return_type(),
-                                  &arg_fn.return_type());
-      return deduced;
+      ArgumentDeduction(source_loc, deduced, &param_fn.parameters(),
+                        &arg_fn.parameters());
+      ArgumentDeduction(source_loc, deduced, &param_fn.return_type(),
+                        &arg_fn.return_type());
+      return;
     }
     case Value::Kind::PointerType: {
       if (arg->kind() != Value::Kind::PointerType) {
@@ -297,13 +294,13 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
             << "expected: " << *param << "\n"
             << "actual: " << *arg;
       }
-      return ArgumentDeduction(source_loc, deduced,
-                               &cast<PointerType>(*param).type(),
-                               &cast<PointerType>(*arg).type());
+      ArgumentDeduction(source_loc, deduced, &cast<PointerType>(*param).type(),
+                        &cast<PointerType>(*arg).type());
+      return;
     }
     // Nothing to do in the case for `auto`.
     case Value::Kind::AutoType: {
-      return deduced;
+      return;
     }
     // For the following cases, we check for type convertability.
     case Value::Kind::ContinuationType:
@@ -314,7 +311,7 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
     case Value::Kind::TypeType:
     case Value::Kind::StringType:
       ExpectType(source_loc, "argument deduction", param, arg);
-      return deduced;
+      return;
     // The rest of these cases should never happen.
     case Value::Kind::IntValue:
     case Value::Kind::BoolValue:
@@ -331,16 +328,16 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
   }
 }
 
-auto TypeChecker::Substitute(TypeEnv dict, Nonnull<const Value*> type)
-    -> Nonnull<const Value*> {
+auto TypeChecker::Substitute(
+    const std::map<Nonnull<const GenericBinding*>, Nonnull<const Value*>>& dict,
+    Nonnull<const Value*> type) -> Nonnull<const Value*> {
   switch (type->kind()) {
     case Value::Kind::VariableType: {
-      std::optional<Nonnull<const Value*>> t =
-          dict.Get(cast<VariableType>(*type).name());
-      if (!t) {
+      auto it = dict.find(&cast<VariableType>(*type).binding());
+      if (it == dict.end()) {
         return type;
       } else {
-        return *t;
+        return it->second;
       }
     }
     case Value::Kind::TupleValue: {
@@ -652,14 +649,16 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
           Nonnull<const Value*> parameters = &fun_t.parameters();
           Nonnull<const Value*> return_type = &fun_t.return_type();
           if (!fun_t.deduced().empty()) {
-            auto deduced_args =
-                ArgumentDeduction(e->source_loc(), TypeEnv(arena_), parameters,
-                                  &call.argument().static_type());
+            std::map<Nonnull<const GenericBinding*>, Nonnull<const Value*>>
+                deduced_args;
+            ArgumentDeduction(e->source_loc(), deduced_args, parameters,
+                              &call.argument().static_type());
             for (Nonnull<const GenericBinding*> deduced_param :
                  fun_t.deduced()) {
               // TODO: change the following to a CHECK once the real checking
               // has been added to the type checking of function signatures.
-              if (!deduced_args.Get(deduced_param->name())) {
+              if (auto it = deduced_args.find(deduced_param);
+                  it == deduced_args.end()) {
                 FATAL_COMPILATION_ERROR(e->source_loc())
                     << "could not deduce type argument for type parameter "
                     << deduced_param->name();
@@ -1029,7 +1028,7 @@ auto TypeChecker::TypeCheckFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
   for (Nonnull<GenericBinding*> deduced : f->deduced_parameters()) {
     TypeCheckExp(&deduced->type(), types, values);
     // auto t = interpreter_.InterpExp(values, deduced.type);
-    SetStaticType(deduced, arena_->New<VariableType>(deduced->name()));
+    SetStaticType(deduced, arena_->New<VariableType>(deduced));
     types.Set(deduced->name(), &deduced->static_type());
     AllocationId a = interpreter_.AllocateValue(*types.Get(deduced->name()));
     values.Set(deduced->name(), a);

+ 8 - 5
executable_semantics/interpreter/type_checker.h

@@ -5,6 +5,7 @@
 #ifndef EXECUTABLE_SEMANTICS_INTERPRETER_TYPE_CHECKER_H_
 #define EXECUTABLE_SEMANTICS_INTERPRETER_TYPE_CHECKER_H_
 
+#include <map>
 #include <set>
 
 #include "common/ostream.h"
@@ -51,9 +52,10 @@ class TypeChecker {
   // inside the argument type.
   // The `deduced` parameter is an accumulator, that is, it holds the
   // results so-far.
-  static auto ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
-                                Nonnull<const Value*> param,
-                                Nonnull<const Value*> arg) -> TypeEnv;
+  static void ArgumentDeduction(
+      SourceLocation source_loc,
+      std::map<Nonnull<const GenericBinding*>, Nonnull<const Value*>>& deduced,
+      Nonnull<const Value*> param, Nonnull<const Value*> arg);
 
   // Traverses the AST rooted at `e`, populating the static_type() of all nodes
   // and ensuring they follow Carbon's typing rules.
@@ -110,8 +112,9 @@ class TypeChecker {
   void ExpectIsConcreteType(SourceLocation source_loc,
                             Nonnull<const Value*> value);
 
-  auto Substitute(TypeEnv dict, Nonnull<const Value*> type)
-      -> Nonnull<const Value*>;
+  auto Substitute(const std::map<Nonnull<const GenericBinding*>,
+                                 Nonnull<const Value*>>& dict,
+                  Nonnull<const Value*> type) -> Nonnull<const Value*>;
 
   Nonnull<Arena*> arena_;
   Interpreter interpreter_;

+ 3 - 2
executable_semantics/interpreter/value.cpp

@@ -241,7 +241,7 @@ void Value::Print(llvm::raw_ostream& out) const {
       out << "choice " << cast<ChoiceType>(*this).name();
       break;
     case Value::Kind::VariableType:
-      out << cast<VariableType>(*this).name();
+      out << cast<VariableType>(*this).binding().name();
       break;
     case Value::Kind::ContinuationValue: {
       out << cast<ContinuationValue>(*this).stack();
@@ -348,7 +348,8 @@ auto TypeEqual(Nonnull<const Value*> t1, Nonnull<const Value*> t2) -> bool {
     case Value::Kind::StringType:
       return true;
     case Value::Kind::VariableType:
-      return cast<VariableType>(*t1).name() == cast<VariableType>(*t2).name();
+      return &cast<VariableType>(*t1).binding() ==
+             &cast<VariableType>(*t2).binding();
     default:
       FATAL() << "TypeEqual used to compare non-type values\n"
               << *t1 << "\n"

+ 4 - 4
executable_semantics/interpreter/value.h

@@ -467,17 +467,17 @@ class ContinuationType : public Value {
 // A variable type.
 class VariableType : public Value {
  public:
-  explicit VariableType(std::string name)
-      : Value(Kind::VariableType), name_(std::move(name)) {}
+  explicit VariableType(Nonnull<const GenericBinding*> binding)
+      : Value(Kind::VariableType), binding_(binding) {}
 
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::VariableType;
   }
 
-  auto name() const -> const std::string& { return name_; }
+  auto binding() const -> const GenericBinding& { return *binding_; }
 
  private:
-  std::string name_;
+  Nonnull<const GenericBinding*> binding_;
 };
 
 // A first-class continuation representation of a fragment of the stack.