Browse Source

Refactor Interpreter/TypeChecker to classes to remove interpreter globals (#790)

Along with #789 this addresses most of #769 although global_arena is still a TODO (that's widespread and overlaps with other changes so I wanted to do it after these are in).
Jon Meow 4 years ago
parent
commit
32f5845e7b

+ 3 - 3
executable_semantics/interpreter/BUILD

@@ -103,9 +103,9 @@ cc_library(
 )
 
 cc_library(
-    name = "typecheck",
-    srcs = ["typecheck.cpp"],
-    hdrs = ["typecheck.h"],
+    name = "type_checker",
+    srcs = ["type_checker.cpp"],
+    hdrs = ["type_checker.h"],
     deps = [
         ":dictionary",
         ":interpreter",

+ 2 - 0
executable_semantics/interpreter/dictionary.h

@@ -80,6 +80,8 @@ class Dictionary {
     head = global_arena->RawNew<Node>(std::make_pair(k, v), head);
   }
 
+  bool IsEmpty() { return head == nullptr; }
+
   // The position of the first element of the dictionary
   // or `end()` if the dictionary is empty.
   auto begin() const -> Iterator { return Iterator(head); }

+ 111 - 176
executable_semantics/interpreter/interpreter.cpp

@@ -30,18 +30,15 @@ using llvm::dyn_cast;
 
 namespace Carbon {
 
-State* state = nullptr;
-
-void Step();
 //
 // Auxiliary Functions
 //
 
-void PrintEnv(Env values, llvm::raw_ostream& out) {
+void Interpreter::PrintEnv(Env values, llvm::raw_ostream& out) {
   llvm::ListSeparator sep;
   for (const auto& [name, address] : values) {
     out << sep << name << ": ";
-    state->heap.PrintAddress(address, out);
+    heap.PrintAddress(address, out);
   }
 }
 
@@ -49,40 +46,37 @@ void PrintEnv(Env values, llvm::raw_ostream& out) {
 // State Operations
 //
 
-void PrintStack(const Stack<Ptr<Frame>>& ls, llvm::raw_ostream& out) {
-  llvm::ListSeparator sep(" :: ");
-  for (const auto& frame : ls) {
-    out << sep << *frame;
-  }
-}
-
-auto CurrentEnv(State* state) -> Env {
-  Ptr<Frame> frame = state->stack.Top();
+auto Interpreter::CurrentEnv() -> Env {
+  Ptr<Frame> frame = stack.Top();
   return frame->scopes.Top()->values;
 }
 
 // Returns the given name from the environment, printing an error if not found.
-static auto GetFromEnv(SourceLocation loc, const std::string& name) -> Address {
-  std::optional<Address> pointer = CurrentEnv(state).Get(name);
+auto Interpreter::GetFromEnv(SourceLocation loc, const std::string& name)
+    -> Address {
+  std::optional<Address> pointer = CurrentEnv().Get(name);
   if (!pointer) {
     FATAL_RUNTIME_ERROR(loc) << "could not find `" << name << "`";
   }
   return *pointer;
 }
 
-void PrintState(llvm::raw_ostream& out) {
+void Interpreter::PrintState(llvm::raw_ostream& out) {
   out << "{\nstack: ";
-  PrintStack(state->stack, out);
-  out << "\nheap: " << state->heap;
-  if (!state->stack.IsEmpty() && !state->stack.Top()->scopes.IsEmpty()) {
+  llvm::ListSeparator sep(" :: ");
+  for (const auto& frame : stack) {
+    out << sep << *frame;
+  }
+  out << "\nheap: " << heap;
+  if (!stack.IsEmpty() && !stack.Top()->scopes.IsEmpty()) {
     out << "\nvalues: ";
-    PrintEnv(CurrentEnv(state), out);
+    PrintEnv(CurrentEnv(), out);
   }
   out << "\n}\n";
 }
 
-auto EvalPrim(Operator op, const std::vector<const Value*>& args,
-              SourceLocation loc) -> const Value* {
+static auto EvalPrim(Operator op, const std::vector<const Value*>& args,
+                     SourceLocation loc) -> const Value* {
   switch (op) {
     case Operator::Neg:
       return global_arena->RawNew<IntValue>(-cast<IntValue>(*args[0]).Val());
@@ -112,10 +106,7 @@ auto EvalPrim(Operator op, const std::vector<const Value*>& args,
   }
 }
 
-// Globally-defined entities, such as functions, structs, choices.
-static Env globals;
-
-void InitEnv(const Declaration& d, Env* env) {
+void Interpreter::InitEnv(const Declaration& d, Env* env) {
   switch (d.Tag()) {
     case Declaration::Kind::FunctionDeclaration: {
       const FunctionDefinition& func_def =
@@ -123,14 +114,14 @@ void InitEnv(const Declaration& d, Env* env) {
       Env new_env = *env;
       // Bring the deduced parameters into scope.
       for (const auto& deduced : func_def.deduced_parameters) {
-        Address a = state->heap.AllocateValue(
+        Address a = heap.AllocateValue(
             global_arena->RawNew<VariableType>(deduced.name));
         new_env.Set(deduced.name, a);
       }
       auto pt = InterpPattern(new_env, func_def.param_pattern);
       auto f =
           global_arena->RawNew<FunctionValue>(func_def.name, pt, func_def.body);
-      Address a = state->heap.AllocateValue(f);
+      Address a = heap.AllocateValue(f);
       env->Set(func_def.name, a);
       break;
     }
@@ -153,7 +144,7 @@ void InitEnv(const Declaration& d, Env* env) {
       }
       auto st = global_arena->RawNew<ClassType>(
           class_def.name, std::move(fields), std::move(methods));
-      auto a = state->heap.AllocateValue(st);
+      auto a = heap.AllocateValue(st);
       env->Set(class_def.name, a);
       break;
     }
@@ -167,7 +158,7 @@ void InitEnv(const Declaration& d, Env* env) {
       }
       auto ct =
           global_arena->RawNew<ChoiceType>(choice.Name(), std::move(alts));
-      auto a = state->heap.AllocateValue(ct);
+      auto a = heap.AllocateValue(ct);
       env->Set(choice.Name(), a);
       break;
     }
@@ -177,35 +168,35 @@ void InitEnv(const Declaration& d, Env* env) {
       // Adds an entry in `globals` mapping the variable's name to the
       // result of evaluating the initializer.
       auto v = InterpExp(*env, var.Initializer());
-      Address a = state->heap.AllocateValue(v);
+      Address a = heap.AllocateValue(v);
       env->Set(*var.Binding()->Name(), a);
       break;
     }
   }
 }
 
-static void InitGlobals(const std::list<Ptr<const Declaration>>& fs) {
+void Interpreter::InitGlobals(const std::list<Ptr<const Declaration>>& fs) {
   for (const auto d : fs) {
     InitEnv(*d, &globals);
   }
 }
 
-void DeallocateScope(Ptr<Scope> scope) {
+void Interpreter::DeallocateScope(Ptr<Scope> scope) {
   for (const auto& l : scope->locals) {
     std::optional<Address> a = scope->values.Get(l);
     CHECK(a);
-    state->heap.Deallocate(*a);
+    heap.Deallocate(*a);
   }
 }
 
-void DeallocateLocals(Ptr<Frame> frame) {
+void Interpreter::DeallocateLocals(Ptr<Frame> frame) {
   while (!frame->scopes.IsEmpty()) {
     DeallocateScope(frame->scopes.Top());
     frame->scopes.Pop();
   }
 }
 
-const Value* CreateTuple(Ptr<Action> act, Ptr<const Expression> exp) {
+static const Value* CreateTuple(Ptr<Action> act, Ptr<const Expression> exp) {
   //    { { (v1,...,vn) :: C, E, F} :: S, H}
   // -> { { `(v1,...,vn) :: C, E, F} :: S, H}
   const auto& tup_lit = cast<TupleLiteral>(*exp);
@@ -219,14 +210,14 @@ const Value* CreateTuple(Ptr<Action> act, Ptr<const Expression> exp) {
   return global_arena->RawNew<TupleValue>(std::move(elements));
 }
 
-auto PatternMatch(const Value* p, const Value* v, SourceLocation loc)
-    -> std::optional<Env> {
+auto Interpreter::PatternMatch(const Value* p, const Value* v,
+                               SourceLocation loc) -> std::optional<Env> {
   switch (p->Tag()) {
     case Value::Kind::BindingPlaceholderValue: {
       const auto& placeholder = cast<BindingPlaceholderValue>(*p);
       Env values;
       if (placeholder.Name().has_value()) {
-        Address a = state->heap.AllocateValue(CopyVal(v, loc));
+        Address a = heap.AllocateValue(CopyVal(v, loc));
         values.Set(*placeholder.Name(), a);
       }
       return values;
@@ -314,10 +305,11 @@ auto PatternMatch(const Value* p, const Value* v, SourceLocation loc)
   }
 }
 
-void PatternAssignment(const Value* pat, const Value* val, SourceLocation loc) {
+void Interpreter::PatternAssignment(const Value* pat, const Value* val,
+                                    SourceLocation loc) {
   switch (pat->Tag()) {
     case Value::Kind::PointerValue:
-      state->heap.Write(cast<PointerValue>(*pat).Val(), CopyVal(val, loc), loc);
+      heap.Write(cast<PointerValue>(*pat).Val(), CopyVal(val, loc), loc);
       break;
     case Value::Kind::TupleValue: {
       switch (val->Tag()) {
@@ -366,71 +358,8 @@ void PatternAssignment(const Value* pat, const Value* val, SourceLocation loc) {
   }
 }
 
-// State transition functions
-//
-// The `Step*` family of functions implement state transitions in the
-// interpreter by executing a step of the Action at the top of the todo stack,
-// and then returning a Transition that specifies how `state.stack` should be
-// updated. `Transition` is a variant of several "transition types" representing
-// the different kinds of state transition.
-
-// Transition type which indicates that the current Action is now done.
-struct Done {
-  // The value computed by the Action. Should always be null for Statement
-  // Actions, and never null for any other kind of Action.
-  const Value* result = nullptr;
-};
-
-// Transition type which spawns a new Action on the todo stack above the current
-// Action, and increments the current Action's position counter.
-struct Spawn {
-  Ptr<Action> child;
-};
-
-// Transition type which spawns a new Action that replaces the current action
-// on the todo stack.
-struct Delegate {
-  Ptr<Action> delegate;
-};
-
-// Transition type which keeps the current Action at the top of the stack,
-// and increments its position counter.
-struct RunAgain {};
-
-// Transition type which unwinds the `todo` and `scopes` stacks until it
-// reaches a specified Action lower in the stack.
-struct UnwindTo {
-  const Ptr<Action> new_top;
-};
-
-// Transition type which unwinds the entire current stack frame, and returns
-// a specified value to the caller.
-struct UnwindFunctionCall {
-  const Value* return_val;
-};
-
-// Transition type which removes the current action from the top of the todo
-// stack, then creates a new stack frame which calls the specified function
-// with the specified arguments.
-struct CallFunction {
-  const FunctionValue* function;
-  const Value* args;
-  SourceLocation loc;
-};
-
-// Transition type which does nothing.
-//
-// TODO(geoffromer): This is a temporary placeholder during refactoring. All
-// uses of this type should be replaced with meaningful transitions.
-struct ManualTransition {};
-
-using Transition =
-    std::variant<Done, Spawn, Delegate, RunAgain, UnwindTo, UnwindFunctionCall,
-                 CallFunction, ManualTransition>;
-
-// State transitions for lvalues.
-Transition StepLvalue() {
-  Ptr<Action> act = state->stack.Top()->todo.Top();
+auto Interpreter::StepLvalue() -> Transition {
+  Ptr<Action> act = stack.Top()->todo.Top();
   Ptr<const Expression> exp = cast<LValAction>(*act).Exp();
   if (tracing_output) {
     llvm::outs() << "--- step lvalue " << *exp << " --->\n";
@@ -516,9 +445,8 @@ Transition StepLvalue() {
   }
 }
 
-// State transitions for expressions.
-Transition StepExp() {
-  Ptr<Action> act = state->stack.Top()->todo.Top();
+auto Interpreter::StepExp() -> Transition {
+  Ptr<Action> act = stack.Top()->todo.Top();
   Ptr<const Expression> exp = cast<ExpressionAction>(*act).Exp();
   if (tracing_output) {
     llvm::outs() << "--- step exp " << *exp << " --->\n";
@@ -593,7 +521,7 @@ Transition StepExp() {
       const auto& ident = cast<IdentifierExpression>(*exp);
       // { {x :: C, E, F} :: S, H} -> { {H(E(x)) :: C, E, F} :: S, H}
       Address pointer = GetFromEnv(exp->SourceLoc(), ident.Name());
-      return Done{state->heap.Read(pointer, exp->SourceLoc())};
+      return Done{heap.Read(pointer, exp->SourceLoc())};
     }
     case Expression::Kind::IntLiteral:
       CHECK(act->Pos() == 0);
@@ -662,7 +590,7 @@ Transition StepExp() {
       switch (cast<IntrinsicExpression>(*exp).Intrinsic()) {
         case IntrinsicExpression::IntrinsicKind::Print:
           Address pointer = GetFromEnv(exp->SourceLoc(), "format_str");
-          const Value* pointee = state->heap.Read(pointer, exp->SourceLoc());
+          const Value* pointee = heap.Read(pointer, exp->SourceLoc());
           CHECK(pointee->Tag() == Value::Kind::StringValue);
           // TODO: This could eventually use something like llvm::formatv.
           llvm::outs() << cast<StringValue>(*pointee).Val();
@@ -714,8 +642,8 @@ Transition StepExp() {
   }  // switch (exp->Tag)
 }
 
-Transition StepPattern() {
-  Ptr<Action> act = state->stack.Top()->todo.Top();
+auto Interpreter::StepPattern() -> Transition {
+  Ptr<Action> act = stack.Top()->todo.Top();
   Ptr<const Pattern> pattern = cast<PatternAction>(*act).Pat();
   if (tracing_output) {
     llvm::outs() << "--- step pattern " << *pattern << " --->\n";
@@ -780,7 +708,7 @@ Transition StepPattern() {
   }
 }
 
-auto IsWhileAct(Ptr<Action> act) -> bool {
+static auto IsWhileAct(Ptr<Action> act) -> bool {
   switch (act->Tag()) {
     case Action::Kind::StatementAction:
       switch (cast<StatementAction>(*act).Stmt()->Tag()) {
@@ -794,7 +722,7 @@ auto IsWhileAct(Ptr<Action> act) -> bool {
   }
 }
 
-auto IsBlockAct(Ptr<Action> act) -> bool {
+static auto IsBlockAct(Ptr<Action> act) -> bool {
   switch (act->Tag()) {
     case Action::Kind::StatementAction:
       switch (cast<StatementAction>(*act).Stmt()->Tag()) {
@@ -808,9 +736,8 @@ auto IsBlockAct(Ptr<Action> act) -> bool {
   }
 }
 
-// State transitions for statements.
-Transition StepStmt() {
-  Ptr<Frame> frame = state->stack.Top();
+auto Interpreter::StepStmt() -> Transition {
+  Ptr<Frame> frame = stack.Top();
   Ptr<Action> act = frame->todo.Top();
   Ptr<const Statement> stmt = cast<StatementAction>(*act).Stmt();
   if (tracing_output) {
@@ -853,7 +780,7 @@ Transition StepStmt() {
           auto pat = act->Results()[clause_num + 1];
           std::optional<Env> matches = PatternMatch(pat, v, stmt->SourceLoc());
           if (matches) {  // we have a match, start the body
-            Env values = CurrentEnv(state);
+            Env values = CurrentEnv();
             std::list<std::string> vars;
             for (const auto& [name, value] : *matches) {
               values.Set(name, value);
@@ -924,7 +851,7 @@ Transition StepStmt() {
       if (act->Pos() == 0) {
         const Block& block = cast<Block>(*stmt);
         if (block.Stmt()) {
-          frame->scopes.Push(global_arena->New<Scope>(CurrentEnv(state)));
+          frame->scopes.Push(global_arena->New<Scope>(CurrentEnv()));
           return Spawn{global_arena->New<StatementAction>(*block.Stmt())};
         } else {
           return Done{};
@@ -1040,8 +967,7 @@ Transition StepStmt() {
       CHECK(act->Pos() == 0);
       // Create a continuation object by creating a frame similar the
       // way one is created in a function call.
-      auto scopes =
-          Stack<Ptr<Scope>>(global_arena->New<Scope>(CurrentEnv(state)));
+      auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(CurrentEnv()));
       Stack<Ptr<Action>> todo;
       todo.Push(global_arena->New<StatementAction>(
           global_arena->New<Return>(stmt->SourceLoc())));
@@ -1050,7 +976,7 @@ Transition StepStmt() {
       auto continuation_frame =
           global_arena->New<Frame>("__continuation", scopes, todo);
       Address continuation_address =
-          state->heap.AllocateValue(global_arena->RawNew<ContinuationValue>(
+          heap.AllocateValue(global_arena->RawNew<ContinuationValue>(
               std::vector<Ptr<Frame>>({continuation_frame})));
       // Store the continuation's address in the frame.
       continuation_frame->continuation = continuation_address;
@@ -1081,7 +1007,7 @@ Transition StepStmt() {
             cast<ContinuationValue>(*act->Results()[0]).Stack();
         for (auto frame_iter = continuation_vector.rbegin();
              frame_iter != continuation_vector.rend(); ++frame_iter) {
-          state->stack.Push(*frame_iter);
+          stack.Push(*frame_iter);
         }
         return ManualTransition{};
       }
@@ -1091,25 +1017,28 @@ Transition StepStmt() {
       frame->todo.Pop();
       std::vector<Ptr<Frame>> paused;
       do {
-        paused.push_back(state->stack.Pop());
+        paused.push_back(stack.Pop());
       } while (paused.back()->continuation == std::nullopt);
       // Update the continuation with the paused stack.
-      state->heap.Write(*paused.back()->continuation,
-                        global_arena->RawNew<ContinuationValue>(paused),
-                        stmt->SourceLoc());
+      heap.Write(*paused.back()->continuation,
+                 global_arena->RawNew<ContinuationValue>(paused),
+                 stmt->SourceLoc());
       return ManualTransition{};
   }
 }
 
-// Visitor which implements the behavior associated with each transition type.
-struct DoTransition {
+class Interpreter::DoTransition {
+ public:
+  // Does not take ownership of interpreter.
+  DoTransition(Interpreter* interpreter) : interpreter(interpreter) {}
+
   void operator()(const Done& done) {
-    Ptr<Frame> frame = state->stack.Top();
+    Ptr<Frame> frame = interpreter->stack.Top();
     if (frame->todo.Top()->Tag() != Action::Kind::StatementAction) {
       CHECK(done.result != nullptr);
       frame->todo.Pop();
       if (frame->todo.IsEmpty()) {
-        state->program_value = done.result;
+        interpreter->program_value = done.result;
       } else {
         frame->todo.Top()->AddResult(done.result);
       }
@@ -1120,26 +1049,26 @@ struct DoTransition {
   }
 
   void operator()(const Spawn& spawn) {
-    Ptr<Frame> frame = state->stack.Top();
+    Ptr<Frame> frame = interpreter->stack.Top();
     frame->todo.Top()->IncrementPos();
     frame->todo.Push(spawn.child);
   }
 
   void operator()(const Delegate& delegate) {
-    Ptr<Frame> frame = state->stack.Top();
+    Ptr<Frame> frame = interpreter->stack.Top();
     frame->todo.Pop();
     frame->todo.Push(delegate.delegate);
   }
 
   void operator()(const RunAgain&) {
-    state->stack.Top()->todo.Top()->IncrementPos();
+    interpreter->stack.Top()->todo.Top()->IncrementPos();
   }
 
   void operator()(const UnwindTo& unwind_to) {
-    Ptr<Frame> frame = state->stack.Top();
+    Ptr<Frame> frame = interpreter->stack.Top();
     while (frame->todo.Top() != unwind_to.new_top) {
       if (IsBlockAct(frame->todo.Top())) {
-        DeallocateScope(frame->scopes.Top());
+        interpreter->DeallocateScope(frame->scopes.Top());
         frame->scopes.Pop();
       }
       frame->todo.Pop();
@@ -1147,23 +1076,23 @@ struct DoTransition {
   }
 
   void operator()(const UnwindFunctionCall& unwind) {
-    DeallocateLocals(state->stack.Top());
-    state->stack.Pop();
-    if (state->stack.Top()->todo.IsEmpty()) {
-      state->program_value = unwind.return_val;
+    interpreter->DeallocateLocals(interpreter->stack.Top());
+    interpreter->stack.Pop();
+    if (interpreter->stack.Top()->todo.IsEmpty()) {
+      interpreter->program_value = unwind.return_val;
     } else {
-      state->stack.Top()->todo.Top()->AddResult(unwind.return_val);
+      interpreter->stack.Top()->todo.Top()->AddResult(unwind.return_val);
     }
   }
 
   void operator()(const CallFunction& call) {
-    state->stack.Top()->todo.Pop();
+    interpreter->stack.Top()->todo.Pop();
     std::optional<Env> matches =
-        PatternMatch(call.function->Param(), call.args, call.loc);
+        interpreter->PatternMatch(call.function->Param(), call.args, call.loc);
     CHECK(matches.has_value())
         << "internal error in call_function, pattern match failed";
     // Create the new frame and push it on the stack
-    Env values = globals;
+    Env values = interpreter->globals;
     std::list<std::string> params;
     for (const auto& [name, value] : *matches) {
       values.Set(name, value);
@@ -1174,15 +1103,18 @@ struct DoTransition {
     auto todo = Stack<Ptr<Action>>(
         global_arena->New<StatementAction>(*call.function->Body()));
     auto frame = global_arena->New<Frame>(call.function->Name(), scopes, todo);
-    state->stack.Push(frame);
+    interpreter->stack.Push(frame);
   }
 
   void operator()(const ManualTransition&) {}
+
+ private:
+  Ptr<Interpreter> interpreter;
 };
 
 // State transition.
-void Step() {
-  Ptr<Frame> frame = state->stack.Top();
+void Interpreter::Step() {
+  Ptr<Frame> frame = stack.Top();
   if (frame->todo.IsEmpty()) {
     FATAL_RUNTIME_ERROR_NO_LINE()
         << "fell off end of function " << frame->name << " without `return`";
@@ -1191,23 +1123,27 @@ void Step() {
   Ptr<Action> act = frame->todo.Top();
   switch (act->Tag()) {
     case Action::Kind::LValAction:
-      std::visit(DoTransition(), StepLvalue());
+      std::visit(DoTransition(this), StepLvalue());
       break;
     case Action::Kind::ExpressionAction:
-      std::visit(DoTransition(), StepExp());
+      std::visit(DoTransition(this), StepExp());
       break;
     case Action::Kind::PatternAction:
-      std::visit(DoTransition(), StepPattern());
+      std::visit(DoTransition(this), StepPattern());
       break;
     case Action::Kind::StatementAction:
-      std::visit(DoTransition(), StepStmt());
+      std::visit(DoTransition(this), StepStmt());
       break;
   }  // switch
 }
 
-// Interpret the whole porogram.
-auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int {
-  state = global_arena->RawNew<State>();  // Runtime state.
+auto Interpreter::InterpProgram(const std::list<Ptr<const Declaration>>& fs)
+    -> int {
+  // Check that the interpreter is in a clean state.
+  CHECK(globals.IsEmpty());
+  CHECK(stack.IsEmpty());
+  CHECK(program_value == std::nullopt);
+
   if (tracing_output) {
     llvm::outs() << "********** initializing globals **********\n";
   }
@@ -1221,55 +1157,54 @@ auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int {
   auto todo =
       Stack<Ptr<Action>>(global_arena->New<ExpressionAction>(call_main));
   auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(globals));
-  state->stack =
-      Stack<Ptr<Frame>>(global_arena->New<Frame>("top", scopes, todo));
+  stack = Stack<Ptr<Frame>>(global_arena->New<Frame>("top", scopes, todo));
 
   if (tracing_output) {
     llvm::outs() << "********** calling main function **********\n";
     PrintState(llvm::outs());
   }
 
-  while (state->stack.Count() > 1 || !state->stack.Top()->todo.IsEmpty()) {
+  while (stack.Count() > 1 || !stack.Top()->todo.IsEmpty()) {
     Step();
     if (tracing_output) {
       PrintState(llvm::outs());
     }
   }
-  return cast<IntValue>(**state->program_value).Val();
+  return cast<IntValue>(**program_value).Val();
 }
 
-// Interpret an expression at compile-time.
-auto InterpExp(Env values, Ptr<const Expression> e) -> const Value* {
-  CHECK(state->program_value == std::nullopt);
+auto Interpreter::InterpExp(Env values, Ptr<const Expression> e)
+    -> const Value* {
+  CHECK(program_value == std::nullopt);
   auto program_value_guard =
-      llvm::make_scope_exit([] { state->program_value = std::nullopt; });
+      llvm::make_scope_exit([&] { program_value = std::nullopt; });
   auto todo = Stack<Ptr<Action>>(global_arena->New<ExpressionAction>(e));
   auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(values));
-  state->stack =
+  stack =
       Stack<Ptr<Frame>>(global_arena->New<Frame>("InterpExp", scopes, todo));
 
-  while (state->stack.Count() > 1 || !state->stack.Top()->todo.IsEmpty()) {
+  while (stack.Count() > 1 || !stack.Top()->todo.IsEmpty()) {
     Step();
   }
-  CHECK(state->program_value != std::nullopt);
-  return *state->program_value;
+  CHECK(program_value != std::nullopt);
+  return *program_value;
 }
 
-// Interpret a pattern at compile-time.
-auto InterpPattern(Env values, Ptr<const Pattern> p) -> const Value* {
-  CHECK(state->program_value == std::nullopt);
+auto Interpreter::InterpPattern(Env values, Ptr<const Pattern> p)
+    -> const Value* {
+  CHECK(program_value == std::nullopt);
   auto program_value_guard =
-      llvm::make_scope_exit([] { state->program_value = std::nullopt; });
+      llvm::make_scope_exit([&] { program_value = std::nullopt; });
   auto todo = Stack<Ptr<Action>>(global_arena->New<PatternAction>(p));
   auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(values));
-  state->stack = Stack<Ptr<Frame>>(
+  stack = Stack<Ptr<Frame>>(
       global_arena->New<Frame>("InterpPattern", scopes, todo));
 
-  while (state->stack.Count() > 1 || !state->stack.Top()->todo.IsEmpty()) {
+  while (stack.Count() > 1 || !stack.Top()->todo.IsEmpty()) {
     Step();
   }
-  CHECK(state->program_value != std::nullopt);
-  return *state->program_value;
+  CHECK(program_value != std::nullopt);
+  return *program_value;
 }
 
 }  // namespace Carbon

+ 115 - 17
executable_semantics/interpreter/interpreter.h

@@ -23,28 +23,126 @@ namespace Carbon {
 
 using Env = Dictionary<std::string, Address>;
 
-struct State {
-  Stack<Ptr<Frame>> stack;
-  Heap heap;
-  std::optional<const Value*> program_value;
-};
+class Interpreter {
+ public:
+  // Interpret the whole program.
+  auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int;
+
+  // Interpret an expression at compile-time.
+  auto InterpExp(Env values, Ptr<const Expression> e) -> const Value*;
+
+  // Interpret a pattern at compile-time.
+  auto InterpPattern(Env values, Ptr<const Pattern> p) -> const Value*;
+
+  // Attempts to match `v` against the pattern `p`. If matching succeeds,
+  // returns the bindings of pattern variables to their matched values.
+  auto PatternMatch(const Value* p, const Value* v, SourceLocation loc)
+      -> std::optional<Env>;
+
+  // Support TypeChecker allocating values on the heap.
+  auto AllocateValue(const Value* v) -> Address {
+    return heap.AllocateValue(v);
+  }
+
+  void InitEnv(const Declaration& d, Env* env);
+  void PrintEnv(Env values, llvm::raw_ostream& out);
+
+ private:
+  // State transition functions
+  //
+  // The `Step*` family of functions implement state transitions in the
+  // interpreter by executing a step of the Action at the top of the todo stack,
+  // and then returning a Transition that specifies how `state.stack` should be
+  // updated. `Transition` is a variant of several "transition types"
+  // representing the different kinds of state transition.
+
+  // Transition type which indicates that the current Action is now done.
+  struct Done {
+    // The value computed by the Action. Should always be null for Statement
+    // Actions, and never null for any other kind of Action.
+    const Value* result = nullptr;
+  };
+
+  // Transition type which spawns a new Action on the todo stack above the
+  // current Action, and increments the current Action's position counter.
+  struct Spawn {
+    Ptr<Action> child;
+  };
+
+  // Transition type which spawns a new Action that replaces the current action
+  // on the todo stack.
+  struct Delegate {
+    Ptr<Action> delegate;
+  };
+
+  // Transition type which keeps the current Action at the top of the stack,
+  // and increments its position counter.
+  struct RunAgain {};
 
-extern State* state;
+  // Transition type which unwinds the `todo` and `scopes` stacks until it
+  // reaches a specified Action lower in the stack.
+  struct UnwindTo {
+    const Ptr<Action> new_top;
+  };
 
-void InitEnv(const Declaration& d, Env* env);
-void PrintStack(const Stack<Frame*>& ls, llvm::raw_ostream& out);
-void PrintEnv(Env values, llvm::raw_ostream& out);
+  // Transition type which unwinds the entire current stack frame, and returns
+  // a specified value to the caller.
+  struct UnwindFunctionCall {
+    const Value* return_val;
+  };
 
-/***** Interpreters *****/
+  // Transition type which removes the current action from the top of the todo
+  // stack, then creates a new stack frame which calls the specified function
+  // with the specified arguments.
+  struct CallFunction {
+    const FunctionValue* function;
+    const Value* args;
+    SourceLocation loc;
+  };
 
-// Attempts to match `v` against the pattern `p`. If matching succeeds, returns
-// the bindings of pattern variables to their matched values.
-auto PatternMatch(const Value* p, const Value* v, SourceLocation loc)
-    -> std::optional<Env>;
+  // Transition type which does nothing.
+  //
+  // TODO(geoffromer): This is a temporary placeholder during refactoring. All
+  // uses of this type should be replaced with meaningful transitions.
+  struct ManualTransition {};
 
-auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int;
-auto InterpExp(Env values, Ptr<const Expression> e) -> const Value*;
-auto InterpPattern(Env values, Ptr<const Pattern> p) -> const Value*;
+  using Transition =
+      std::variant<Done, Spawn, Delegate, RunAgain, UnwindTo,
+                   UnwindFunctionCall, CallFunction, ManualTransition>;
+
+  // Visitor which implements the behavior associated with each transition type.
+  class DoTransition;
+
+  void Step();
+
+  // State transitions for expressions.
+  auto StepExp() -> Transition;
+  // State transitions for lvalues.
+  auto StepLvalue() -> Transition;
+  // State transitions for patterns.
+  auto StepPattern() -> Transition;
+  // State transition for statements.
+  auto StepStmt() -> Transition;
+
+  void InitGlobals(const std::list<Ptr<const Declaration>>& fs);
+  auto CurrentEnv() -> Env;
+  auto GetFromEnv(SourceLocation loc, const std::string& name) -> Address;
+
+  void DeallocateScope(Ptr<Scope> scope);
+  void DeallocateLocals(Ptr<Frame> frame);
+
+  void PatternAssignment(const Value* pat, const Value* val,
+                         SourceLocation loc);
+
+  void PrintState(llvm::raw_ostream& out);
+
+  // Globally-defined entities, such as functions, structs, or choices.
+  Env globals;
+
+  Stack<Ptr<Frame>> stack;
+  Heap heap;
+  std::optional<const Value*> program_value;
+};
 
 }  // namespace Carbon
 

+ 54 - 65
executable_semantics/interpreter/typecheck.cpp → executable_semantics/interpreter/type_checker.cpp

@@ -2,7 +2,7 @@
 // Exceptions. See /LICENSE for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "executable_semantics/interpreter/typecheck.h"
+#include "executable_semantics/interpreter/type_checker.h"
 
 #include <algorithm>
 #include <iterator>
@@ -253,26 +253,13 @@ static auto Substitute(TypeEnv dict, const Value* type) -> const Value* {
   }
 }
 
-// The TypeCheckExp function performs semantic analysis on an expression.
-// It returns a new version of the expression, its type, and an
-// updated environment which are bundled into a TCResult object.
-// The purpose of the updated environment is
-// to bring pattern variables into scope, for example, in a match case.
-// The new version of the expression may include more information,
-// for example, the type arguments deduced for the type parameters of a
-// generic.
-//
-// e is the expression to be analyzed.
-// types maps variable names to the type of their run-time value.
-// values maps variable names to their compile-time values. It is not
-//    directly used in this function but is passed to InterExp.
-auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
-    -> TCExpression {
+auto TypeChecker::TypeCheckExp(Ptr<const Expression> e, TypeEnv types,
+                               Env values) -> TCExpression {
   if (tracing_output) {
     llvm::outs() << "checking expression " << *e << "\ntypes: ";
     PrintTypeEnv(types, llvm::outs());
     llvm::outs() << "\nvalues: ";
-    PrintEnv(values, llvm::outs());
+    interpreter.PrintEnv(values, llvm::outs());
     llvm::outs() << "\n";
   }
   switch (e->Tag()) {
@@ -282,7 +269,9 @@ auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
       auto t = res.type;
       switch (t->Tag()) {
         case Value::Kind::TupleValue: {
-          auto i = cast<IntValue>(*InterpExp(values, index.Offset())).Val();
+          auto i =
+              cast<IntValue>(*interpreter.InterpExp(values, index.Offset()))
+                  .Val();
           std::string f = std::to_string(i);
           const Value* field_t = cast<TupleValue>(*t).FindField(f);
           if (field_t == nullptr) {
@@ -505,8 +494,8 @@ auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
     }
     case Expression::Kind::FunctionTypeLiteral: {
       const auto& fn = cast<FunctionTypeLiteral>(*e);
-      auto pt = InterpExp(values, fn.Parameter());
-      auto rt = InterpExp(values, fn.ReturnType());
+      auto pt = interpreter.InterpExp(values, fn.Parameter());
+      auto rt = interpreter.InterpExp(values, fn.ReturnType());
       auto new_e = global_arena->New<FunctionTypeLiteral>(
           e->SourceLoc(), ReifyType(pt, e->SourceLoc()),
           ReifyType(rt, e->SourceLoc()),
@@ -532,8 +521,9 @@ auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
 // Equivalent to TypeCheckExp, but operates on Patterns instead of Expressions.
 // `expected` is the type that this pattern is expected to have, if the
 // surrounding context gives us that information. Otherwise, it is null.
-auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
-                      const Value* expected) -> TCPattern {
+auto TypeChecker::TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types,
+                                   Env values, const Value* expected)
+    -> TCPattern {
   if (tracing_output) {
     llvm::outs() << "checking pattern " << *p;
     if (expected) {
@@ -542,7 +532,7 @@ auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
     llvm::outs() << "\ntypes: ";
     PrintTypeEnv(types, llvm::outs());
     llvm::outs() << "\nvalues: ";
-    PrintEnv(values, llvm::outs());
+    interpreter.PrintEnv(values, llvm::outs());
     llvm::outs() << "\n";
   }
   switch (p->Tag()) {
@@ -555,10 +545,11 @@ auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
       const auto& binding = cast<BindingPattern>(*p);
       TCPattern binding_type_result =
           TypeCheckPattern(binding.Type(), types, values, nullptr);
-      const Value* type = InterpPattern(values, binding_type_result.pattern);
+      const Value* type =
+          interpreter.InterpPattern(values, binding_type_result.pattern);
       if (expected != nullptr) {
-        std::optional<Env> values =
-            PatternMatch(type, expected, binding.Type()->SourceLoc());
+        std::optional<Env> values = interpreter.PatternMatch(
+            type, expected, binding.Type()->SourceLoc());
         if (values == std::nullopt) {
           FATAL_COMPILATION_ERROR(binding.Type()->SourceLoc())
               << "Type pattern '" << *type << "' does not match actual type '"
@@ -617,7 +608,8 @@ auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
     }
     case Pattern::Kind::AlternativePattern: {
       const auto& alternative = cast<AlternativePattern>(*p);
-      const Value* choice_type = InterpExp(values, alternative.ChoiceType());
+      const Value* choice_type =
+          interpreter.InterpExp(values, alternative.ChoiceType());
       if (choice_type->Tag() != Value::Kind::ChoiceType) {
         FATAL_COMPILATION_ERROR(alternative.SourceLoc())
             << "alternative pattern does not name a choice type.";
@@ -656,9 +648,10 @@ auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
   }
 }
 
-static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
-                          Ptr<const Statement> body, TypeEnv types, Env values,
-                          const Value*& ret_type, bool is_omitted_ret_type)
+auto TypeChecker::TypeCheckCase(const Value* expected, Ptr<const Pattern> pat,
+                                Ptr<const Statement> body, TypeEnv types,
+                                Env values, const Value*& ret_type,
+                                bool is_omitted_ret_type)
     -> std::pair<Ptr<const Pattern>, Ptr<const Statement>> {
   auto pat_res = TypeCheckPattern(pat, types, values, expected);
   auto res =
@@ -666,16 +659,9 @@ static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
   return std::make_pair(pat, res.stmt);
 }
 
-// The TypeCheckStmt function performs semantic analysis on a statement.
-// It returns a new version of the statement and a new type environment.
-//
-// The ret_type parameter is used for analyzing return statements.
-// It is the declared return type of the enclosing function definition.
-// If the return type is "auto", then the return type is inferred from
-// the first return statement.
-auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
-                   const Value*& ret_type, bool is_omitted_ret_type)
-    -> TCStatement {
+auto TypeChecker::TypeCheckStmt(Ptr<const Statement> s, TypeEnv types,
+                                Env values, const Value*& ret_type,
+                                bool is_omitted_ret_type) -> TCStatement {
   switch (s->Tag()) {
     case Statement::Kind::Match: {
       const auto& match = cast<Match>(*s);
@@ -684,7 +670,7 @@ auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
       auto new_clauses = global_arena->RawNew<
           std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>>();
       for (auto& clause : *match.Clauses()) {
-        new_clauses->push_back(TypecheckCase(res_type, clause.first,
+        new_clauses->push_back(TypeCheckCase(res_type, clause.first,
                                              clause.second, types, values,
                                              ret_type, is_omitted_ret_type));
       }
@@ -903,19 +889,19 @@ static auto CheckOrEnsureReturn(std::optional<Ptr<const Statement>> opt_stmt,
 // a function.
 // TODO: Add checking to function definitions to ensure that
 //   all deduced type parameters will be deduced.
-static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
-                            Env values) -> Ptr<const FunctionDefinition> {
+auto TypeChecker::TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
+                                  Env values) -> Ptr<const FunctionDefinition> {
   // Bring the deduced parameters into scope
   for (const auto& deduced : f->deduced_parameters) {
-    // auto t = InterpExp(values, deduced.type);
+    // auto t = interpreter.InterpExp(values, deduced.type);
     types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
-    Address a = state->heap.AllocateValue(*types.Get(deduced.name));
+    Address a = interpreter.AllocateValue(*types.Get(deduced.name));
     values.Set(deduced.name, a);
   }
   // Type check the parameter pattern
   auto param_res = TypeCheckPattern(f->param_pattern, types, values, nullptr);
   // Evaluate the return type expression
-  auto return_type = InterpPattern(values, f->return_type);
+  auto return_type = interpreter.InterpPattern(values, f->return_type);
   if (f->name == "main") {
     ExpectType(f->source_location, "return type of `main`",
                global_arena->RawNew<IntType>(), return_type);
@@ -936,30 +922,31 @@ static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
       /*is_omitted_return_type=*/false, body);
 }
 
-static auto TypeOfFunDef(TypeEnv types, Env values,
-                         const FunctionDefinition* fun_def) -> const Value* {
+auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
+                               const FunctionDefinition* fun_def)
+    -> const Value* {
   // Bring the deduced parameters into scope
   for (const auto& deduced : fun_def->deduced_parameters) {
-    // auto t = InterpExp(values, deduced.type);
+    // auto t = interpreter.InterpExp(values, deduced.type);
     types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
-    Address a = state->heap.AllocateValue(*types.Get(deduced.name));
+    Address a = interpreter.AllocateValue(*types.Get(deduced.name));
     values.Set(deduced.name, a);
   }
   // Type check the parameter pattern
   auto param_res =
       TypeCheckPattern(fun_def->param_pattern, types, values, nullptr);
   // Evaluate the return type expression
-  auto ret = InterpPattern(values, fun_def->return_type);
+  auto ret = interpreter.InterpPattern(values, fun_def->return_type);
   if (ret->Tag() == Value::Kind::AutoType) {
     auto f = TypeCheckFunDef(fun_def, types, values);
-    ret = InterpPattern(values, f->return_type);
+    ret = interpreter.InterpPattern(values, f->return_type);
   }
   return global_arena->RawNew<FunctionType>(fun_def->deduced_parameters,
                                             param_res.type, ret);
 }
 
-static auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
-                           Env ct_top) -> const Value* {
+auto TypeChecker::TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
+                                 Env ct_top) -> const Value* {
   VarValues fields;
   VarValues methods;
   for (Ptr<const Member> m : sd->members) {
@@ -976,7 +963,7 @@ static auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
           FATAL_COMPILATION_ERROR(binding->SourceLoc())
               << "Struct members must have explicit types";
         }
-        auto type = InterpExp(ct_top, binding_type->Expression());
+        auto type = interpreter.InterpExp(ct_top, binding_type->Expression());
         fields.push_back(std::make_pair(*binding->Name(), type));
         break;
       }
@@ -1006,8 +993,9 @@ static auto GetName(const Declaration& d) -> const std::string& {
   }
 }
 
-auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
-                     const Env& values) -> Ptr<const Declaration> {
+auto TypeChecker::MakeTypeChecked(const Ptr<const Declaration> d,
+                                  const TypeEnv& types, const Env& values)
+    -> Ptr<const Declaration> {
   switch (d->Tag()) {
     case Declaration::Kind::FunctionDeclaration:
       return global_arena->New<FunctionDeclaration>(TypeCheckFunDef(
@@ -1048,7 +1036,7 @@ auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
             << "Type of a top-level variable must be an expression.";
       }
       const Value* declared_type =
-          InterpExp(values, binding_type->Expression());
+          interpreter.InterpExp(values, binding_type->Expression());
       ExpectType(var.SourceLoc(), "initializer of variable", declared_type,
                  type_checked_initializer.type);
       return d;
@@ -1056,21 +1044,21 @@ auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
   }
 }
 
-static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
+void TypeChecker::TopLevel(const Declaration& d, TypeCheckContext* tops) {
   switch (d.Tag()) {
     case Declaration::Kind::FunctionDeclaration: {
       const FunctionDefinition& func_def =
           cast<FunctionDeclaration>(d).Definition();
       auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
       tops->types.Set(func_def.name, t);
-      InitEnv(d, &tops->values);
+      interpreter.InitEnv(d, &tops->values);
       break;
     }
 
     case Declaration::Kind::ClassDeclaration: {
       const ClassDefinition& class_def = cast<ClassDeclaration>(d).Definition();
       auto st = TypeOfClassDef(&class_def, tops->types, tops->values);
-      Address a = state->heap.AllocateValue(st);
+      Address a = interpreter.AllocateValue(st);
       tops->values.Set(class_def.name, a);  // Is this obsolete?
       std::vector<TupleElement> field_types;
       for (const auto& [field_name, field_value] :
@@ -1088,12 +1076,12 @@ static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
       const auto& choice = cast<ChoiceDeclaration>(d);
       VarValues alts;
       for (const auto& [name, signature] : choice.Alternatives()) {
-        auto t = InterpExp(tops->values, signature);
+        auto t = interpreter.InterpExp(tops->values, signature);
         alts.push_back(std::make_pair(name, t));
       }
       auto ct =
           global_arena->RawNew<ChoiceType>(choice.Name(), std::move(alts));
-      Address a = state->heap.AllocateValue(ct);
+      Address a = interpreter.AllocateValue(ct);
       tops->values.Set(choice.Name(), a);  // Is this obsolete?
       tops->types.Set(choice.Name(), ct);
       break;
@@ -1105,14 +1093,15 @@ static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
       // compile-time symbol table.
       Ptr<const Expression> type =
           cast<ExpressionPattern>(*var.Binding()->Type()).Expression();
-      const Value* declared_type = InterpExp(tops->values, type);
+      const Value* declared_type = interpreter.InterpExp(tops->values, type);
       tops->types.Set(*var.Binding()->Name(), declared_type);
       break;
     }
   }
 }
 
-auto TopLevel(const std::list<Ptr<const Declaration>>& fs) -> TypeCheckContext {
+auto TypeChecker::TopLevel(const std::list<Ptr<const Declaration>>& fs)
+    -> TypeCheckContext {
   TypeCheckContext tops;
   bool found_main = false;
 

+ 108 - 0
executable_semantics/interpreter/type_checker.h

@@ -0,0 +1,108 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef EXECUTABLE_SEMANTICS_INTERPRETER_TYPE_CHECKER_H_
+#define EXECUTABLE_SEMANTICS_INTERPRETER_TYPE_CHECKER_H_
+
+#include <set>
+
+#include "common/ostream.h"
+#include "executable_semantics/ast/expression.h"
+#include "executable_semantics/ast/statement.h"
+#include "executable_semantics/common/ptr.h"
+#include "executable_semantics/interpreter/dictionary.h"
+#include "executable_semantics/interpreter/interpreter.h"
+
+namespace Carbon {
+
+using TypeEnv = Dictionary<std::string, const Value*>;
+
+class TypeChecker {
+ public:
+  struct TypeCheckContext {
+    // Symbol table mapping names of runtime entities to their type.
+    TypeEnv types;
+    // Symbol table mapping names of compile time entities to their value.
+    Env values;
+  };
+
+  auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
+                       const Env& values) -> Ptr<const Declaration>;
+
+  auto TopLevel(const std::list<Ptr<const Declaration>>& fs)
+      -> TypeCheckContext;
+
+ private:
+  struct TCExpression {
+    TCExpression(Ptr<const Expression> e, const Value* t, TypeEnv types)
+        : exp(e), type(t), types(types) {}
+
+    Ptr<const Expression> exp;
+    const Value* type;
+    TypeEnv types;
+  };
+
+  struct TCPattern {
+    Ptr<const Pattern> pattern;
+    const Value* type;
+    TypeEnv types;
+  };
+
+  struct TCStatement {
+    TCStatement(Ptr<const Statement> s, TypeEnv types)
+        : stmt(s), types(types) {}
+
+    Ptr<const Statement> stmt;
+    TypeEnv types;
+  };
+
+  // TypeCheckExp performs semantic analysis on an expression.  It returns a new
+  // version of the expression, its type, and an updated environment which are
+  // bundled into a TCResult object.  The purpose of the updated environment is
+  // to bring pattern variables into scope, for example, in a match case.  The
+  // new version of the expression may include more information, for example,
+  // the type arguments deduced for the type parameters of a generic.
+  //
+  // e is the expression to be analyzed.
+  // types maps variable names to the type of their run-time value.
+  // values maps variable names to their compile-time values. It is not
+  //    directly used in this function but is passed to InterExp.
+  auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
+      -> TCExpression;
+
+  auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
+                        const Value* expected) -> TCPattern;
+
+  // TypeCheckStmt performs semantic analysis on a statement.  It returns a new
+  // version of the statement and a new type environment.
+  //
+  // The ret_type parameter is used for analyzing return statements.  It is the
+  // declared return type of the enclosing function definition.  If the return
+  // type is "auto", then the return type is inferred from the first return
+  // statement.
+  auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
+                     const Value*& ret_type, bool is_omitted_ret_type)
+      -> TCStatement;
+
+  auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types, Env values)
+      -> Ptr<const FunctionDefinition>;
+
+  auto TypeCheckCase(const Value* expected, Ptr<const Pattern> pat,
+                     Ptr<const Statement> body, TypeEnv types, Env values,
+                     const Value*& ret_type, bool is_omitted_ret_type)
+      -> std::pair<Ptr<const Pattern>, Ptr<const Statement>>;
+
+  auto TypeOfFunDef(TypeEnv types, Env values,
+                    const FunctionDefinition* fun_def) -> const Value*;
+  auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/, Env ct_top)
+      -> const Value*;
+
+  void TopLevel(const Declaration& d, TypeCheckContext* tops);
+
+  Interpreter interpreter;
+};
+
+}  // namespace Carbon
+
+#endif  // EXECUTABLE_SEMANTICS_INTERPRETER_TYPE_CHECKER_H_

+ 0 - 65
executable_semantics/interpreter/typecheck.h

@@ -1,65 +0,0 @@
-// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
-// Exceptions. See /LICENSE for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef EXECUTABLE_SEMANTICS_INTERPRETER_TYPECHECK_H_
-#define EXECUTABLE_SEMANTICS_INTERPRETER_TYPECHECK_H_
-
-#include <set>
-
-#include "common/ostream.h"
-#include "executable_semantics/ast/expression.h"
-#include "executable_semantics/ast/statement.h"
-#include "executable_semantics/common/ptr.h"
-#include "executable_semantics/interpreter/dictionary.h"
-#include "executable_semantics/interpreter/interpreter.h"
-
-namespace Carbon {
-
-using TypeEnv = Dictionary<std::string, const Value*>;
-
-struct TCExpression {
-  TCExpression(Ptr<const Expression> e, const Value* t, TypeEnv types)
-      : exp(e), type(t), types(types) {}
-
-  Ptr<const Expression> exp;
-  const Value* type;
-  TypeEnv types;
-};
-
-struct TCPattern {
-  Ptr<const Pattern> pattern;
-  const Value* type;
-  TypeEnv types;
-};
-
-struct TCStatement {
-  TCStatement(Ptr<const Statement> s, TypeEnv types) : stmt(s), types(types) {}
-
-  Ptr<const Statement> stmt;
-  TypeEnv types;
-};
-
-struct TypeCheckContext {
-  // Symbol table mapping names of runtime entities to their type.
-  TypeEnv types;
-  // Symbol table mapping names of compile time entities to their value.
-  Env values;
-};
-
-auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
-    -> TCExpression;
-auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
-                      const Value* expected) -> TCPattern;
-
-auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
-                   const Value*& ret_type, bool is_omitted_ret_type)
-    -> TCStatement;
-
-auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
-                     const Env& values) -> Ptr<const Declaration>;
-auto TopLevel(const std::list<Ptr<const Declaration>>& fs) -> TypeCheckContext;
-
-}  // namespace Carbon
-
-#endif  // EXECUTABLE_SEMANTICS_INTERPRETER_TYPECHECK_H_

+ 1 - 1
executable_semantics/syntax/BUILD

@@ -44,7 +44,7 @@ cc_library(
         "//executable_semantics/common:error",
         "//executable_semantics/common:tracing_flag",
         "//executable_semantics/interpreter",
-        "//executable_semantics/interpreter:typecheck",
+        "//executable_semantics/interpreter:type_checker",
     ],
 )
 

+ 5 - 5
executable_semantics/syntax/syntax_helpers.cpp

@@ -9,7 +9,7 @@
 #include "executable_semantics/common/arena.h"
 #include "executable_semantics/common/tracing_flag.h"
 #include "executable_semantics/interpreter/interpreter.h"
-#include "executable_semantics/interpreter/typecheck.h"
+#include "executable_semantics/interpreter/type_checker.h"
 
 namespace Carbon {
 
@@ -46,13 +46,13 @@ void ExecProgram(std::list<Ptr<const Declaration>> fs) {
     }
     llvm::outs() << "********** type checking **********\n";
   }
-  state = global_arena->RawNew<State>();  // Compile-time state.
-  TypeCheckContext p = TopLevel(fs);
+  TypeChecker type_checker;
+  TypeChecker::TypeCheckContext p = type_checker.TopLevel(fs);
   TypeEnv top = p.types;
   Env ct_top = p.values;
   std::list<Ptr<const Declaration>> new_decls;
   for (const auto decl : fs) {
-    new_decls.push_back(MakeTypeChecked(decl, top, ct_top));
+    new_decls.push_back(type_checker.MakeTypeChecked(decl, top, ct_top));
   }
   if (tracing_output) {
     llvm::outs() << "\n";
@@ -62,7 +62,7 @@ void ExecProgram(std::list<Ptr<const Declaration>> fs) {
     }
     llvm::outs() << "********** starting execution **********\n";
   }
-  int result = InterpProgram(new_decls);
+  int result = Interpreter().InterpProgram(new_decls);
   llvm::outs() << "result: " << result << "\n";
 }