浏览代码

Move pattern interpretation to compile time (#904)

Co-authored-by: Jon Meow <46229924+jonmeow@users.noreply.github.com>
Geoff Romer 4 年之前
父节点
当前提交
79e3d284b4

+ 13 - 0
executable_semantics/ast/pattern.h

@@ -60,6 +60,18 @@ class Pattern {
   // and after typechecking it's guaranteed to be true.
   auto has_static_type() const -> bool { return static_type_.has_value(); }
 
+  // The value of this pattern. Cannot be called before typechecking.
+  auto value() const -> const Value& { return **value_; }
+
+  // Sets the value of this pattern. Can only be called once, during
+  // typechecking.
+  void set_value(Nonnull<const Value*> value) { value_ = value; }
+
+  // Returns whether the value has been set. Should only be called
+  // during typechecking: before typechecking it's guaranteed to be false,
+  // and after typechecking it's guaranteed to be true.
+  auto has_value() const -> bool { return value_.has_value(); }
+
  protected:
   // Constructs a Pattern representing syntax at the given line number.
   // `kind` must be the enumerator corresponding to the most-derived type being
@@ -72,6 +84,7 @@ class Pattern {
   SourceLocation source_loc_;
 
   std::optional<Nonnull<const Value*>> static_type_;
+  std::optional<Nonnull<const Value*>> value_;
 };
 
 // A pattern consisting of the `auto` keyword.

+ 22 - 41
executable_semantics/interpreter/interpreter.cpp

@@ -116,8 +116,7 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
             heap_.AllocateValue(arena_->New<VariableType>(deduced.name));
         new_env.Set(deduced.name, a);
       }
-      auto pt = InterpPattern(new_env, &func_def.param_pattern());
-      auto f = arena_->New<FunctionValue>(func_def.name(), pt, func_def.body());
+      Nonnull<const FunctionValue*> f = arena_->New<FunctionValue>(&func_def);
       Address a = heap_.AllocateValue(f);
       env->Set(func_def.name(), a);
       break;
@@ -585,10 +584,8 @@ auto Interpreter::StepExp() -> Transition {
           }
           case Value::Kind::FunctionValue:
             return CallFunction{
-                // TODO: Think about a cleaner way to cast between Ptr types.
-                // (multiple TODOs)
-                .function = Nonnull<const FunctionValue*>(
-                    cast<FunctionValue>(act->results()[0])),
+                .function =
+                    &cast<FunctionValue>(*act->results()[0]).declaration(),
                 .args = act->results()[1],
                 .source_loc = exp.source_loc()};
           default:
@@ -755,44 +752,26 @@ auto Interpreter::StepStmt() -> Transition {
         frame->scopes.Push(arena_->New<Scope>(CurrentEnv()));
         return Spawn{arena_->New<ExpressionAction>(&match_stmt.expression())};
       } else {
-        // Regarding act->pos():
-        // * odd: start interpreting the pattern of a clause
-        // * even: finished interpreting the pattern, now try to match
-        //
-        // Regarding act->results():
-        // * 0: the value that we're matching
-        // * 1: the pattern for clause 0
-        // * 2: the pattern for clause 1
-        // * ...
-        auto clause_num = (act->pos() - 1) / 2;
+        int clause_num = act->pos() - 1;
         if (clause_num >= static_cast<int>(match_stmt.clauses().size())) {
           DeallocateScope(frame->scopes.Top());
           frame->scopes.Pop();
           return Done{};
         }
         auto c = match_stmt.clauses()[clause_num];
+        std::optional<Env> matches = PatternMatch(
+            &c.pattern().value(), act->results()[0], stmt.source_loc());
+        if (matches) {  // We have a match, start the body.
+          // Ensure we don't process any more clauses.
+          act->set_pos(match_stmt.clauses().size() + 1);
 
-        if (act->pos() % 2 == 1) {
-          // start interpreting the pattern of the clause
-          //    { {v :: (match ([]) ...) :: C, E, F} :: S, H}
-          // -> { {pi :: (match ([]) ...) :: C, E, F} :: S, H}
-          return Spawn{arena_->New<PatternAction>(&c.pattern())};
-        } else {  // try to match
-          auto v = act->results()[0];
-          auto pat = act->results()[clause_num + 1];
-          std::optional<Env> matches = PatternMatch(pat, v, stmt.source_loc());
-          if (matches) {  // we have a match, start the body
-            // Ensure we don't process any more clauses.
-            act->set_pos(2 * match_stmt.clauses().size() + 1);
-
-            for (const auto& [name, value] : *matches) {
-              frame->scopes.Top()->values.Set(name, value);
-              frame->scopes.Top()->locals.push_back(name);
-            }
-            return Spawn{arena_->New<StatementAction>(&c.statement())};
-          } else {
-            return RunAgain{};
+          for (const auto& [name, value] : *matches) {
+            frame->scopes.Top()->values.Set(name, value);
+            frame->scopes.Top()->locals.push_back(name);
           }
+          return Spawn{arena_->New<StatementAction>(&c.statement())};
+        } else {
+          return RunAgain{};
         }
       }
     }
@@ -859,14 +838,12 @@ auto Interpreter::StepStmt() -> Transition {
         // -> { {e :: (var x = []) :: C, E, F} :: S, H}
         return Spawn{arena_->New<ExpressionAction>(
             &cast<VariableDefinition>(stmt).init())};
-      } else if (act->pos() == 1) {
-        return Spawn{arena_->New<PatternAction>(
-            &cast<VariableDefinition>(stmt).pattern())};
       } else {
         //    { { v :: (x = []) :: C, E, F} :: S, H}
         // -> { { C, E(x := a), F} :: S, H(a := copy(v))}
         Nonnull<const Value*> v = act->results()[0];
-        Nonnull<const Value*> p = act->results()[1];
+        Nonnull<const Value*> p =
+            &cast<VariableDefinition>(stmt).pattern().value();
 
         std::optional<Env> matches = PatternMatch(p, v, stmt.source_loc());
         CHECK(matches)
@@ -1079,7 +1056,7 @@ class Interpreter::DoTransition {
   void operator()(const CallFunction& call) {
     interpreter->stack_.Top()->todo.Pop();
     std::optional<Env> matches = interpreter->PatternMatch(
-        &call.function->parameters(), call.args, call.source_loc);
+        &call.function->param_pattern().value(), call.args, call.source_loc);
     CHECK(matches.has_value())
         << "internal error in call_function, pattern match failed";
     // Create the new frame and push it on the stack
@@ -1153,6 +1130,10 @@ auto Interpreter::InterpProgram(llvm::ArrayRef<Nonnull<Declaration*>> fs,
   }
 
   while (stack_.Count() > 1 || !stack_.Top()->todo.IsEmpty()) {
+    if (!stack_.Top()->todo.IsEmpty()) {
+      CHECK(stack_.Top()->todo.Top()->kind() != Action::Kind::PatternAction)
+          << "Pattern evaluation must happen before run-time.";
+    }
     Step();
     if (trace_) {
       PrintState(llvm::outs());

+ 1 - 1
executable_semantics/interpreter/interpreter.h

@@ -101,7 +101,7 @@ class Interpreter {
   // stack, then creates a new stack frame which calls the specified function
   // with the specified arguments.
   struct CallFunction {
-    Nonnull<const FunctionValue*> function;
+    Nonnull<const FunctionDeclaration*> function;
     Nonnull<const Value*> args;
     SourceLocation source_loc;
   };

+ 14 - 0
executable_semantics/interpreter/type_checker.cpp

@@ -58,6 +58,16 @@ static void SetStaticType(Nonnull<FunctionDeclaration*> definition,
   }
 }
 
+static void SetValue(Nonnull<Pattern*> pattern, Nonnull<const Value*> value) {
+  // TODO: find some way to CHECK that `value` is identical to pattern->value(),
+  // if it's already set. Unclear if `ValueEqual` is suitable, because it
+  // currently focuses more on "real" values, and disallows the pseudo-values
+  // like `BindingPlaceholderValue` that we get in pattern evaluation.
+  if (!pattern->has_value()) {
+    pattern->set_value(value);
+  }
+}
+
 TypeChecker::ReturnTypeContext::ReturnTypeContext(
     Nonnull<const Value*> orig_return_type, bool is_omitted)
     : is_auto_(isa<AutoType>(orig_return_type)),
@@ -749,6 +759,7 @@ auto TypeChecker::TypeCheckPattern(
         types.Set(*binding.name(), type);
       }
       SetStaticType(&binding, type);
+      SetValue(&binding, interpreter_.InterpPattern(values, &binding));
       return TCResult(types);
     }
     case Pattern::Kind::TuplePattern: {
@@ -775,6 +786,7 @@ auto TypeChecker::TypeCheckPattern(
         field_types.push_back(&field->static_type());
       }
       SetStaticType(&tuple, arena_->New<TupleValue>(std::move(field_types)));
+      SetValue(&tuple, interpreter_.InterpPattern(values, &tuple));
       return TCResult(new_types);
     }
     case Pattern::Kind::AlternativePattern: {
@@ -800,12 +812,14 @@ auto TypeChecker::TypeCheckPattern(
       TCResult arg_results = TypeCheckPattern(&alternative.arguments(), types,
                                               values, *parameter_types);
       SetStaticType(&alternative, choice_type);
+      SetValue(&alternative, interpreter_.InterpPattern(values, &alternative));
       return TCResult(arg_results.types);
     }
     case Pattern::Kind::ExpressionPattern: {
       auto& expression = cast<ExpressionPattern>(*p).expression();
       TCResult result = TypeCheckExp(&expression, types, values);
       SetStaticType(p, &expression.static_type());
+      SetValue(p, interpreter_.InterpPattern(values, p));
       return TCResult(result.types);
     }
   }

+ 3 - 3
executable_semantics/interpreter/value.cpp

@@ -212,7 +212,7 @@ void Value::Print(llvm::raw_ostream& out) const {
       out << (cast<BoolValue>(*this).value() ? "true" : "false");
       break;
     case Value::Kind::FunctionValue:
-      out << "fun<" << cast<FunctionValue>(*this).name() << ">";
+      out << "fun<" << cast<FunctionValue>(*this).declaration().name() << ">";
       break;
     case Value::Kind::PointerValue:
       out << "ptr<" << cast<PointerValue>(*this).value() << ">";
@@ -392,9 +392,9 @@ auto ValueEqual(Nonnull<const Value*> v1, Nonnull<const Value*> v2,
       return cast<PointerValue>(*v1).value() == cast<PointerValue>(*v2).value();
     case Value::Kind::FunctionValue: {
       std::optional<Nonnull<const Statement*>> body1 =
-          cast<FunctionValue>(*v1).body();
+          cast<FunctionValue>(*v1).declaration().body();
       std::optional<Nonnull<const Statement*>> body2 =
-          cast<FunctionValue>(*v2).body();
+          cast<FunctionValue>(*v2).declaration().body();
       return body1.has_value() == body2.has_value() &&
              (!body1.has_value() || *body1 == *body2);
     }

+ 5 - 13
executable_semantics/interpreter/value.h

@@ -126,27 +126,19 @@ class IntValue : public Value {
 // A function value.
 class FunctionValue : public Value {
  public:
-  FunctionValue(std::string name, Nonnull<const Value*> parameters,
-                std::optional<Nonnull<const Statement*>> body)
-      : Value(Kind::FunctionValue),
-        name_(std::move(name)),
-        parameters_(parameters),
-        body_(body) {}
+  FunctionValue(Nonnull<const FunctionDeclaration*> declaration)
+      : Value(Kind::FunctionValue), declaration_(declaration) {}
 
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::FunctionValue;
   }
 
-  auto name() const -> const std::string& { return name_; }
-  auto parameters() const -> const Value& { return *parameters_; }
-  auto body() const -> std::optional<Nonnull<const Statement*>> {
-    return body_;
+  auto declaration() const -> const FunctionDeclaration& {
+    return *declaration_;
   }
 
  private:
-  std::string name_;
-  Nonnull<const Value*> parameters_;
-  std::optional<Nonnull<const Statement*>> body_;
+  Nonnull<const FunctionDeclaration*> declaration_;
 };
 
 // A pointer value.