Selaa lähdekoodia

Convert Statement to use Ptr (#788)

Note this makes a few cases where the Statement was optional explicit (Block, If, Sequence). I do add a few CHECKs around where statements were optional and assumed but unverified.

I switch TypeCheckStmt to not take an optional Statement because I think it makes the call sites clearer in behavior. It's also a smaller change than the converse, because taking an optional Statement means the returned statement would also need to be optional. Arguably a wrapper for optional statements could be added, but this still seems cleaner to me, and there aren't that many cases of an optional statement.

Co-authored-by: Geoff Romer <gromer@google.com>
Jon Meow 4 vuotta sitten
vanhempi
sitoutus
36ed79dc25

+ 1 - 1
executable_semantics/ast/function_definition.cpp

@@ -27,7 +27,7 @@ void FunctionDefinition::PrintDepth(int depth, llvm::raw_ostream& out) const {
   }
   }
   if (body) {
   if (body) {
     out << " {\n";
     out << " {\n";
-    body->PrintDepth(depth, out);
+    (*body)->PrintDepth(depth, out);
     out << "\n}\n";
     out << "\n}\n";
   } else {
   } else {
     out << ";\n";
     out << ";\n";

+ 3 - 2
executable_semantics/ast/function_definition.h

@@ -26,7 +26,8 @@ struct FunctionDefinition {
                      std::vector<GenericBinding> deduced_params,
                      std::vector<GenericBinding> deduced_params,
                      Ptr<const TuplePattern> param_pattern,
                      Ptr<const TuplePattern> param_pattern,
                      Ptr<const Pattern> return_type,
                      Ptr<const Pattern> return_type,
-                     bool is_omitted_return_type, const Statement* body)
+                     bool is_omitted_return_type,
+                     std::optional<Ptr<const Statement>> body)
       : source_location(source_location),
       : source_location(source_location),
         name(std::move(name)),
         name(std::move(name)),
         deduced_parameters(deduced_params),
         deduced_parameters(deduced_params),
@@ -45,7 +46,7 @@ struct FunctionDefinition {
   Ptr<const TuplePattern> param_pattern;
   Ptr<const TuplePattern> param_pattern;
   Ptr<const Pattern> return_type;
   Ptr<const Pattern> return_type;
   bool is_omitted_return_type;
   bool is_omitted_return_type;
-  const Statement* body;
+  std::optional<Ptr<const Statement>> body;
 };
 };
 
 
 }  // namespace Carbon
 }  // namespace Carbon

+ 3 - 3
executable_semantics/ast/statement.cpp

@@ -65,7 +65,7 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       if_stmt.ThenStmt()->PrintDepth(depth - 1, out);
       if_stmt.ThenStmt()->PrintDepth(depth - 1, out);
       if (if_stmt.ElseStmt()) {
       if (if_stmt.ElseStmt()) {
         out << "\nelse\n";
         out << "\nelse\n";
-        if_stmt.ElseStmt()->PrintDepth(depth - 1, out);
+        (*if_stmt.ElseStmt())->PrintDepth(depth - 1, out);
       }
       }
       break;
       break;
     }
     }
@@ -87,7 +87,7 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
         out << " ";
         out << " ";
       }
       }
       if (seq.Next()) {
       if (seq.Next()) {
-        seq.Next()->PrintDepth(depth - 1, out);
+        (*seq.Next())->PrintDepth(depth - 1, out);
       }
       }
       break;
       break;
     }
     }
@@ -98,7 +98,7 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
         out << "\n";
         out << "\n";
       }
       }
       if (block.Stmt()) {
       if (block.Stmt()) {
-        block.Stmt()->PrintDepth(depth, out);
+        (*block.Stmt())->PrintDepth(depth, out);
         if (depth < 0 || depth > 1) {
         if (depth < 0 || depth > 1) {
           out << "\n";
           out << "\n";
         }
         }

+ 28 - 23
executable_semantics/ast/statement.h

@@ -109,8 +109,9 @@ class VariableDefinition : public Statement {
 
 
 class If : public Statement {
 class If : public Statement {
  public:
  public:
-  If(SourceLocation loc, Ptr<const Expression> cond, const Statement* then_stmt,
-     const Statement* else_stmt)
+  If(SourceLocation loc, Ptr<const Expression> cond,
+     Ptr<const Statement> then_stmt,
+     std::optional<Ptr<const Statement>> else_stmt)
       : Statement(Kind::If, loc),
       : Statement(Kind::If, loc),
         cond(cond),
         cond(cond),
         then_stmt(then_stmt),
         then_stmt(then_stmt),
@@ -121,13 +122,15 @@ class If : public Statement {
   }
   }
 
 
   auto Cond() const -> Ptr<const Expression> { return cond; }
   auto Cond() const -> Ptr<const Expression> { return cond; }
-  auto ThenStmt() const -> const Statement* { return then_stmt; }
-  auto ElseStmt() const -> const Statement* { return else_stmt; }
+  auto ThenStmt() const -> Ptr<const Statement> { return then_stmt; }
+  auto ElseStmt() const -> std::optional<Ptr<const Statement>> {
+    return else_stmt;
+  }
 
 
  private:
  private:
   Ptr<const Expression> cond;
   Ptr<const Expression> cond;
-  const Statement* then_stmt;
-  const Statement* else_stmt;
+  Ptr<const Statement> then_stmt;
+  std::optional<Ptr<const Statement>> else_stmt;
 };
 };
 
 
 class Return : public Statement {
 class Return : public Statement {
@@ -153,39 +156,41 @@ class Return : public Statement {
 
 
 class Sequence : public Statement {
 class Sequence : public Statement {
  public:
  public:
-  Sequence(SourceLocation loc, const Statement* stmt, const Statement* next)
+  Sequence(SourceLocation loc, Ptr<const Statement> stmt,
+           std::optional<Ptr<const Statement>> next)
       : Statement(Kind::Sequence, loc), stmt(stmt), next(next) {}
       : Statement(Kind::Sequence, loc), stmt(stmt), next(next) {}
 
 
   static auto classof(const Statement* stmt) -> bool {
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::Sequence;
     return stmt->Tag() == Kind::Sequence;
   }
   }
 
 
-  auto Stmt() const -> const Statement* { return stmt; }
-  auto Next() const -> const Statement* { return next; }
+  auto Stmt() const -> Ptr<const Statement> { return stmt; }
+  auto Next() const -> std::optional<Ptr<const Statement>> { return next; }
 
 
  private:
  private:
-  const Statement* stmt;
-  const Statement* next;
+  Ptr<const Statement> stmt;
+  std::optional<Ptr<const Statement>> next;
 };
 };
 
 
 class Block : public Statement {
 class Block : public Statement {
  public:
  public:
-  Block(SourceLocation loc, const Statement* stmt)
+  Block(SourceLocation loc, std::optional<Ptr<const Statement>> stmt)
       : Statement(Kind::Block, loc), stmt(stmt) {}
       : Statement(Kind::Block, loc), stmt(stmt) {}
 
 
   static auto classof(const Statement* stmt) -> bool {
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::Block;
     return stmt->Tag() == Kind::Block;
   }
   }
 
 
-  auto Stmt() const -> const Statement* { return stmt; }
+  auto Stmt() const -> std::optional<Ptr<const Statement>> { return stmt; }
 
 
  private:
  private:
-  const Statement* stmt;
+  std::optional<Ptr<const Statement>> stmt;
 };
 };
 
 
 class While : public Statement {
 class While : public Statement {
  public:
  public:
-  While(SourceLocation loc, Ptr<const Expression> cond, const Statement* body)
+  While(SourceLocation loc, Ptr<const Expression> cond,
+        Ptr<const Statement> body)
       : Statement(Kind::While, loc), cond(cond), body(body) {}
       : Statement(Kind::While, loc), cond(cond), body(body) {}
 
 
   static auto classof(const Statement* stmt) -> bool {
   static auto classof(const Statement* stmt) -> bool {
@@ -193,11 +198,11 @@ class While : public Statement {
   }
   }
 
 
   auto Cond() const -> Ptr<const Expression> { return cond; }
   auto Cond() const -> Ptr<const Expression> { return cond; }
-  auto Body() const -> const Statement* { return body; }
+  auto Body() const -> Ptr<const Statement> { return body; }
 
 
  private:
  private:
   Ptr<const Expression> cond;
   Ptr<const Expression> cond;
-  const Statement* body;
+  Ptr<const Statement> body;
 };
 };
 
 
 class Break : public Statement {
 class Break : public Statement {
@@ -221,7 +226,7 @@ class Continue : public Statement {
 class Match : public Statement {
 class Match : public Statement {
  public:
  public:
   Match(SourceLocation loc, Ptr<const Expression> exp,
   Match(SourceLocation loc, Ptr<const Expression> exp,
-        std::list<std::pair<Ptr<const Pattern>, const Statement*>>* clauses)
+        std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>* clauses)
       : Statement(Kind::Match, loc), exp(exp), clauses(clauses) {}
       : Statement(Kind::Match, loc), exp(exp), clauses(clauses) {}
 
 
   static auto classof(const Statement* stmt) -> bool {
   static auto classof(const Statement* stmt) -> bool {
@@ -230,13 +235,13 @@ class Match : public Statement {
 
 
   auto Exp() const -> Ptr<const Expression> { return exp; }
   auto Exp() const -> Ptr<const Expression> { return exp; }
   auto Clauses() const
   auto Clauses() const
-      -> const std::list<std::pair<Ptr<const Pattern>, const Statement*>>* {
+      -> const std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>* {
     return clauses;
     return clauses;
   }
   }
 
 
  private:
  private:
   Ptr<const Expression> exp;
   Ptr<const Expression> exp;
-  std::list<std::pair<Ptr<const Pattern>, const Statement*>>* clauses;
+  std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>* clauses;
 };
 };
 
 
 // A continuation statement.
 // A continuation statement.
@@ -247,7 +252,7 @@ class Match : public Statement {
 class Continuation : public Statement {
 class Continuation : public Statement {
  public:
  public:
   Continuation(SourceLocation loc, std::string continuation_variable,
   Continuation(SourceLocation loc, std::string continuation_variable,
-               const Statement* body)
+               Ptr<const Statement> body)
       : Statement(Kind::Continuation, loc),
       : Statement(Kind::Continuation, loc),
         continuation_variable(std::move(continuation_variable)),
         continuation_variable(std::move(continuation_variable)),
         body(body) {}
         body(body) {}
@@ -259,11 +264,11 @@ class Continuation : public Statement {
   auto ContinuationVariable() const -> const std::string& {
   auto ContinuationVariable() const -> const std::string& {
     return continuation_variable;
     return continuation_variable;
   }
   }
-  auto Body() const -> const Statement* { return body; }
+  auto Body() const -> Ptr<const Statement> { return body; }
 
 
  private:
  private:
   std::string continuation_variable;
   std::string continuation_variable;
-  const Statement* body;
+  Ptr<const Statement> body;
 };
 };
 
 
 // A run statement.
 // A run statement.

+ 3 - 3
executable_semantics/interpreter/action.h

@@ -117,17 +117,17 @@ class PatternAction : public Action {
 
 
 class StatementAction : public Action {
 class StatementAction : public Action {
  public:
  public:
-  explicit StatementAction(const Statement* stmt)
+  explicit StatementAction(Ptr<const Statement> stmt)
       : Action(Kind::StatementAction), stmt(stmt) {}
       : Action(Kind::StatementAction), stmt(stmt) {}
 
 
   static auto classof(const Action* action) -> bool {
   static auto classof(const Action* action) -> bool {
     return action->Tag() == Kind::StatementAction;
     return action->Tag() == Kind::StatementAction;
   }
   }
 
 
-  auto Stmt() const -> const Statement* { return stmt; }
+  auto Stmt() const -> Ptr<const Statement> { return stmt; }
 
 
  private:
  private:
-  const Statement* stmt;
+  Ptr<const Statement> stmt;
 };
 };
 
 
 }  // namespace Carbon
 }  // namespace Carbon

+ 13 - 14
executable_semantics/interpreter/interpreter.cpp

@@ -812,8 +812,7 @@ auto IsBlockAct(Ptr<Action> act) -> bool {
 Transition StepStmt() {
 Transition StepStmt() {
   Ptr<Frame> frame = state->stack.Top();
   Ptr<Frame> frame = state->stack.Top();
   Ptr<Action> act = frame->todo.Top();
   Ptr<Action> act = frame->todo.Top();
-  const Statement* stmt = cast<StatementAction>(*act).Stmt();
-  CHECK(stmt != nullptr) << "null statement!";
+  Ptr<const Statement> stmt = cast<StatementAction>(*act).Stmt();
   if (tracing_output) {
   if (tracing_output) {
     llvm::outs() << "--- step stmt ";
     llvm::outs() << "--- step stmt ";
     stmt->PrintDepth(1, llvm::outs());
     stmt->PrintDepth(1, llvm::outs());
@@ -861,9 +860,8 @@ Transition StepStmt() {
               vars.push_back(name);
               vars.push_back(name);
             }
             }
             frame->scopes.Push(global_arena->New<Scope>(values, vars));
             frame->scopes.Push(global_arena->New<Scope>(values, vars));
-            const Statement* body_block =
-                global_arena->RawNew<Block>(stmt->SourceLoc(), c->second);
-            auto body_act = global_arena->New<StatementAction>(body_block);
+            auto body_act = global_arena->New<StatementAction>(
+                global_arena->New<Block>(stmt->SourceLoc(), c->second));
             body_act->IncrementPos();
             body_act->IncrementPos();
             frame->todo.Pop(1);
             frame->todo.Pop(1);
             frame->todo.Push(body_act);
             frame->todo.Push(body_act);
@@ -925,9 +923,9 @@ Transition StepStmt() {
     case Statement::Kind::Block: {
     case Statement::Kind::Block: {
       if (act->Pos() == 0) {
       if (act->Pos() == 0) {
         const Block& block = cast<Block>(*stmt);
         const Block& block = cast<Block>(*stmt);
-        if (block.Stmt() != nullptr) {
+        if (block.Stmt()) {
           frame->scopes.Push(global_arena->New<Scope>(CurrentEnv(state)));
           frame->scopes.Push(global_arena->New<Scope>(CurrentEnv(state)));
-          return Spawn{global_arena->New<StatementAction>(block.Stmt())};
+          return Spawn{global_arena->New<StatementAction>(*block.Stmt())};
         } else {
         } else {
           return Done{};
           return Done{};
         }
         }
@@ -1007,7 +1005,7 @@ Transition StepStmt() {
         //      S, H}
         //      S, H}
         // -> { { else_stmt :: C, E, F } :: S, H}
         // -> { { else_stmt :: C, E, F } :: S, H}
         return Delegate{
         return Delegate{
-            global_arena->New<StatementAction>(cast<If>(*stmt).ElseStmt())};
+            global_arena->New<StatementAction>(*cast<If>(*stmt).ElseStmt())};
       } else {
       } else {
         return Done{};
         return Done{};
       }
       }
@@ -1030,9 +1028,9 @@ Transition StepStmt() {
       if (act->Pos() == 0) {
       if (act->Pos() == 0) {
         return Spawn{global_arena->New<StatementAction>(seq.Stmt())};
         return Spawn{global_arena->New<StatementAction>(seq.Stmt())};
       } else {
       } else {
-        if (seq.Next() != nullptr) {
-          return Delegate{
-              global_arena->New<StatementAction>(cast<Sequence>(*stmt).Next())};
+        if (seq.Next()) {
+          return Delegate{global_arena->New<StatementAction>(
+              *cast<Sequence>(*stmt).Next())};
         } else {
         } else {
           return Done{};
           return Done{};
         }
         }
@@ -1046,7 +1044,7 @@ Transition StepStmt() {
           Stack<Ptr<Scope>>(global_arena->New<Scope>(CurrentEnv(state)));
           Stack<Ptr<Scope>>(global_arena->New<Scope>(CurrentEnv(state)));
       Stack<Ptr<Action>> todo;
       Stack<Ptr<Action>> todo;
       todo.Push(global_arena->New<StatementAction>(
       todo.Push(global_arena->New<StatementAction>(
-          global_arena->RawNew<Return>(stmt->SourceLoc())));
+          global_arena->New<Return>(stmt->SourceLoc())));
       todo.Push(
       todo.Push(
           global_arena->New<StatementAction>(cast<Continuation>(*stmt).Body()));
           global_arena->New<StatementAction>(cast<Continuation>(*stmt).Body()));
       auto continuation_frame =
       auto continuation_frame =
@@ -1074,7 +1072,7 @@ Transition StepStmt() {
         // Push an expression statement action to ignore the result
         // Push an expression statement action to ignore the result
         // value from the continuation.
         // value from the continuation.
         auto ignore_result = global_arena->New<StatementAction>(
         auto ignore_result = global_arena->New<StatementAction>(
-            global_arena->RawNew<ExpressionStatement>(
+            global_arena->New<ExpressionStatement>(
                 stmt->SourceLoc(),
                 stmt->SourceLoc(),
                 global_arena->New<TupleLiteral>(stmt->SourceLoc())));
                 global_arena->New<TupleLiteral>(stmt->SourceLoc())));
         frame->todo.Push(ignore_result);
         frame->todo.Push(ignore_result);
@@ -1172,8 +1170,9 @@ struct DoTransition {
       params.push_back(name);
       params.push_back(name);
     }
     }
     auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(values, params));
     auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(values, params));
+    CHECK(call.function->Body()) << "Calling a function that's missing a body";
     auto todo = Stack<Ptr<Action>>(
     auto todo = Stack<Ptr<Action>>(
-        global_arena->New<StatementAction>(call.function->Body()));
+        global_arena->New<StatementAction>(*call.function->Body()));
     auto frame = global_arena->New<Frame>(call.function->Name(), scopes, todo);
     auto frame = global_arena->New<Frame>(call.function->Name(), scopes, todo);
     state->stack.Push(frame);
     state->stack.Push(frame);
   }
   }

+ 67 - 51
executable_semantics/interpreter/typecheck.cpp

@@ -657,9 +657,9 @@ auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
 }
 }
 
 
 static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
 static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
-                          const Statement* body, TypeEnv types, Env values,
+                          Ptr<const Statement> body, TypeEnv types, Env values,
                           const Value*& ret_type, bool is_omitted_ret_type)
                           const Value*& ret_type, bool is_omitted_ret_type)
-    -> std::pair<Ptr<const Pattern>, const Statement*> {
+    -> std::pair<Ptr<const Pattern>, Ptr<const Statement>> {
   auto pat_res = TypeCheckPattern(pat, types, values, expected);
   auto pat_res = TypeCheckPattern(pat, types, values, expected);
   auto res =
   auto res =
       TypeCheckStmt(body, pat_res.types, values, ret_type, is_omitted_ret_type);
       TypeCheckStmt(body, pat_res.types, values, ret_type, is_omitted_ret_type);
@@ -673,26 +673,23 @@ static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
 // It is the declared return type of the enclosing function definition.
 // It is the declared return type of the enclosing function definition.
 // If the return type is "auto", then the return type is inferred from
 // If the return type is "auto", then the return type is inferred from
 // the first return statement.
 // the first return statement.
-auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
+auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
                    const Value*& ret_type, bool is_omitted_ret_type)
                    const Value*& ret_type, bool is_omitted_ret_type)
     -> TCStatement {
     -> TCStatement {
-  if (!s) {
-    return TCStatement(s, types);
-  }
   switch (s->Tag()) {
   switch (s->Tag()) {
     case Statement::Kind::Match: {
     case Statement::Kind::Match: {
       const auto& match = cast<Match>(*s);
       const auto& match = cast<Match>(*s);
       auto res = TypeCheckExp(match.Exp(), types, values);
       auto res = TypeCheckExp(match.Exp(), types, values);
       auto res_type = res.type;
       auto res_type = res.type;
       auto new_clauses = global_arena->RawNew<
       auto new_clauses = global_arena->RawNew<
-          std::list<std::pair<Ptr<const Pattern>, const Statement*>>>();
+          std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>>();
       for (auto& clause : *match.Clauses()) {
       for (auto& clause : *match.Clauses()) {
         new_clauses->push_back(TypecheckCase(res_type, clause.first,
         new_clauses->push_back(TypecheckCase(res_type, clause.first,
                                              clause.second, types, values,
                                              clause.second, types, values,
                                              ret_type, is_omitted_ret_type));
                                              ret_type, is_omitted_ret_type));
       }
       }
-      const Statement* new_s =
-          global_arena->RawNew<Match>(s->SourceLoc(), res.exp, new_clauses);
+      auto new_s =
+          global_arena->New<Match>(s->SourceLoc(), res.exp, new_clauses);
       return TCStatement(new_s, types);
       return TCStatement(new_s, types);
     }
     }
     case Statement::Kind::While: {
     case Statement::Kind::While: {
@@ -702,39 +699,48 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
                  global_arena->RawNew<BoolType>(), cnd_res.type);
                  global_arena->RawNew<BoolType>(), cnd_res.type);
       auto body_res = TypeCheckStmt(while_stmt.Body(), types, values, ret_type,
       auto body_res = TypeCheckStmt(while_stmt.Body(), types, values, ret_type,
                                     is_omitted_ret_type);
                                     is_omitted_ret_type);
-      auto new_s = global_arena->RawNew<While>(s->SourceLoc(), cnd_res.exp,
-                                               body_res.stmt);
+      auto new_s =
+          global_arena->New<While>(s->SourceLoc(), cnd_res.exp, body_res.stmt);
       return TCStatement(new_s, types);
       return TCStatement(new_s, types);
     }
     }
     case Statement::Kind::Break:
     case Statement::Kind::Break:
     case Statement::Kind::Continue:
     case Statement::Kind::Continue:
       return TCStatement(s, types);
       return TCStatement(s, types);
     case Statement::Kind::Block: {
     case Statement::Kind::Block: {
-      auto stmt_res = TypeCheckStmt(cast<Block>(*s).Stmt(), types, values,
-                                    ret_type, is_omitted_ret_type);
-      return TCStatement(
-          global_arena->RawNew<Block>(s->SourceLoc(), stmt_res.stmt), types);
+      const auto& block = cast<Block>(*s);
+      if (block.Stmt()) {
+        auto stmt_res = TypeCheckStmt(*block.Stmt(), types, values, ret_type,
+                                      is_omitted_ret_type);
+        return TCStatement(
+            global_arena->New<Block>(s->SourceLoc(), stmt_res.stmt), types);
+      } else {
+        return TCStatement(s, types);
+      }
     }
     }
     case Statement::Kind::VariableDefinition: {
     case Statement::Kind::VariableDefinition: {
       const auto& var = cast<VariableDefinition>(*s);
       const auto& var = cast<VariableDefinition>(*s);
       auto res = TypeCheckExp(var.Init(), types, values);
       auto res = TypeCheckExp(var.Init(), types, values);
       const Value* rhs_ty = res.type;
       const Value* rhs_ty = res.type;
       auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
       auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
-      const Statement* new_s = global_arena->RawNew<VariableDefinition>(
-          s->SourceLoc(), var.Pat(), res.exp);
+      auto new_s = global_arena->New<VariableDefinition>(s->SourceLoc(),
+                                                         var.Pat(), res.exp);
       return TCStatement(new_s, lhs_res.types);
       return TCStatement(new_s, lhs_res.types);
     }
     }
     case Statement::Kind::Sequence: {
     case Statement::Kind::Sequence: {
       const auto& seq = cast<Sequence>(*s);
       const auto& seq = cast<Sequence>(*s);
       auto stmt_res = TypeCheckStmt(seq.Stmt(), types, values, ret_type,
       auto stmt_res = TypeCheckStmt(seq.Stmt(), types, values, ret_type,
                                     is_omitted_ret_type);
                                     is_omitted_ret_type);
-      auto types2 = stmt_res.types;
-      auto next_res = TypeCheckStmt(seq.Next(), types2, values, ret_type,
-                                    is_omitted_ret_type);
-      auto types3 = next_res.types;
-      return TCStatement(global_arena->RawNew<Sequence>(
-                             s->SourceLoc(), stmt_res.stmt, next_res.stmt),
-                         types3);
+      auto checked_types = stmt_res.types;
+      std::optional<Ptr<const Statement>> next_stmt;
+      if (seq.Next()) {
+        auto next_res = TypeCheckStmt(*seq.Next(), checked_types, values,
+                                      ret_type, is_omitted_ret_type);
+        next_stmt = next_res.stmt;
+        checked_types = next_res.types;
+      }
+      return TCStatement(
+          global_arena->New<Sequence>(s->SourceLoc(), stmt_res.stmt, next_stmt),
+          checked_types);
     }
     }
     case Statement::Kind::Assign: {
     case Statement::Kind::Assign: {
       const auto& assign = cast<Assign>(*s);
       const auto& assign = cast<Assign>(*s);
@@ -743,15 +749,15 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
       auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
       auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
       auto lhs_t = lhs_res.type;
       auto lhs_t = lhs_res.type;
       ExpectType(s->SourceLoc(), "assign", lhs_t, rhs_t);
       ExpectType(s->SourceLoc(), "assign", lhs_t, rhs_t);
-      auto new_s = global_arena->RawNew<Assign>(s->SourceLoc(), lhs_res.exp,
-                                                rhs_res.exp);
+      auto new_s =
+          global_arena->New<Assign>(s->SourceLoc(), lhs_res.exp, rhs_res.exp);
       return TCStatement(new_s, lhs_res.types);
       return TCStatement(new_s, lhs_res.types);
     }
     }
     case Statement::Kind::ExpressionStatement: {
     case Statement::Kind::ExpressionStatement: {
       auto res =
       auto res =
           TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
           TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
       auto new_s =
       auto new_s =
-          global_arena->RawNew<ExpressionStatement>(s->SourceLoc(), res.exp);
+          global_arena->New<ExpressionStatement>(s->SourceLoc(), res.exp);
       return TCStatement(new_s, types);
       return TCStatement(new_s, types);
     }
     }
     case Statement::Kind::If: {
     case Statement::Kind::If: {
@@ -761,10 +767,14 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
                  global_arena->RawNew<BoolType>(), cnd_res.type);
                  global_arena->RawNew<BoolType>(), cnd_res.type);
       auto then_res = TypeCheckStmt(if_stmt.ThenStmt(), types, values, ret_type,
       auto then_res = TypeCheckStmt(if_stmt.ThenStmt(), types, values, ret_type,
                                     is_omitted_ret_type);
                                     is_omitted_ret_type);
-      auto else_res = TypeCheckStmt(if_stmt.ElseStmt(), types, values, ret_type,
-                                    is_omitted_ret_type);
-      auto new_s = global_arena->RawNew<If>(s->SourceLoc(), cnd_res.exp,
-                                            then_res.stmt, else_res.stmt);
+      std::optional<Ptr<const Statement>> else_stmt;
+      if (if_stmt.ElseStmt()) {
+        auto else_res = TypeCheckStmt(*if_stmt.ElseStmt(), types, values,
+                                      ret_type, is_omitted_ret_type);
+        else_stmt = else_res.stmt;
+      }
+      auto new_s = global_arena->New<If>(s->SourceLoc(), cnd_res.exp,
+                                         then_res.stmt, else_stmt);
       return TCStatement(new_s, types);
       return TCStatement(new_s, types);
     }
     }
     case Statement::Kind::Return: {
     case Statement::Kind::Return: {
@@ -783,15 +793,15 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
             << *s << " should" << (is_omitted_ret_type ? " not" : "")
             << *s << " should" << (is_omitted_ret_type ? " not" : "")
             << " provide a return value, to match the function's signature.";
             << " provide a return value, to match the function's signature.";
       }
       }
-      return TCStatement(global_arena->RawNew<Return>(s->SourceLoc(), res.exp,
-                                                      ret.IsOmittedExp()),
+      return TCStatement(global_arena->New<Return>(s->SourceLoc(), res.exp,
+                                                   ret.IsOmittedExp()),
                          types);
                          types);
     }
     }
     case Statement::Kind::Continuation: {
     case Statement::Kind::Continuation: {
       const auto& cont = cast<Continuation>(*s);
       const auto& cont = cast<Continuation>(*s);
       TCStatement body_result = TypeCheckStmt(cont.Body(), types, values,
       TCStatement body_result = TypeCheckStmt(cont.Body(), types, values,
                                               ret_type, is_omitted_ret_type);
                                               ret_type, is_omitted_ret_type);
-      const Statement* new_continuation = global_arena->RawNew<Continuation>(
+      auto new_continuation = global_arena->New<Continuation>(
           s->SourceLoc(), cont.ContinuationVariable(), body_result.stmt);
           s->SourceLoc(), cont.ContinuationVariable(), body_result.stmt);
       types.Set(cont.ContinuationVariable(),
       types.Set(cont.ContinuationVariable(),
                 global_arena->RawNew<ContinuationType>());
                 global_arena->RawNew<ContinuationType>());
@@ -803,8 +813,8 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
       ExpectType(s->SourceLoc(), "argument of `run`",
       ExpectType(s->SourceLoc(), "argument of `run`",
                  global_arena->RawNew<ContinuationType>(),
                  global_arena->RawNew<ContinuationType>(),
                  argument_result.type);
                  argument_result.type);
-      const Statement* new_run =
-          global_arena->RawNew<Run>(s->SourceLoc(), argument_result.exp);
+      auto new_run =
+          global_arena->New<Run>(s->SourceLoc(), argument_result.exp);
       return TCStatement(new_run, types);
       return TCStatement(new_run, types);
     }
     }
     case Statement::Kind::Await: {
     case Statement::Kind::Await: {
@@ -814,38 +824,40 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
   }  // switch
   }  // switch
 }
 }
 
 
-static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
-                                SourceLocation loc) -> const Statement* {
-  if (!stmt) {
+static auto CheckOrEnsureReturn(std::optional<Ptr<const Statement>> opt_stmt,
+                                bool omitted_ret_type, SourceLocation loc)
+    -> Ptr<const Statement> {
+  if (!opt_stmt) {
     if (omitted_ret_type) {
     if (omitted_ret_type) {
-      return global_arena->RawNew<Return>(loc);
+      return global_arena->New<Return>(loc);
     } else {
     } else {
       FATAL_COMPILATION_ERROR(loc)
       FATAL_COMPILATION_ERROR(loc)
           << "control-flow reaches end of function that provides a `->` return "
           << "control-flow reaches end of function that provides a `->` return "
              "type without reaching a return statement";
              "type without reaching a return statement";
     }
     }
   }
   }
+  Ptr<const Statement> stmt = *opt_stmt;
   switch (stmt->Tag()) {
   switch (stmt->Tag()) {
     case Statement::Kind::Match: {
     case Statement::Kind::Match: {
       const auto& match = cast<Match>(*stmt);
       const auto& match = cast<Match>(*stmt);
       auto new_clauses = global_arena->RawNew<
       auto new_clauses = global_arena->RawNew<
-          std::list<std::pair<Ptr<const Pattern>, const Statement*>>>();
+          std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>>();
       for (const auto& clause : *match.Clauses()) {
       for (const auto& clause : *match.Clauses()) {
         auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
         auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
                                      stmt->SourceLoc());
                                      stmt->SourceLoc());
         new_clauses->push_back(std::make_pair(clause.first, s));
         new_clauses->push_back(std::make_pair(clause.first, s));
       }
       }
-      return global_arena->RawNew<Match>(stmt->SourceLoc(), match.Exp(),
-                                         new_clauses);
+      return global_arena->New<Match>(stmt->SourceLoc(), match.Exp(),
+                                      new_clauses);
     }
     }
     case Statement::Kind::Block:
     case Statement::Kind::Block:
-      return global_arena->RawNew<Block>(
+      return global_arena->New<Block>(
           stmt->SourceLoc(),
           stmt->SourceLoc(),
           CheckOrEnsureReturn(cast<Block>(*stmt).Stmt(), omitted_ret_type,
           CheckOrEnsureReturn(cast<Block>(*stmt).Stmt(), omitted_ret_type,
                               stmt->SourceLoc()));
                               stmt->SourceLoc()));
     case Statement::Kind::If: {
     case Statement::Kind::If: {
       const auto& if_stmt = cast<If>(*stmt);
       const auto& if_stmt = cast<If>(*stmt);
-      return global_arena->RawNew<If>(
+      return global_arena->New<If>(
           stmt->SourceLoc(), if_stmt.Cond(),
           stmt->SourceLoc(), if_stmt.Cond(),
           CheckOrEnsureReturn(if_stmt.ThenStmt(), omitted_ret_type,
           CheckOrEnsureReturn(if_stmt.ThenStmt(), omitted_ret_type,
                               stmt->SourceLoc()),
                               stmt->SourceLoc()),
@@ -857,7 +869,7 @@ static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
     case Statement::Kind::Sequence: {
     case Statement::Kind::Sequence: {
       const auto& seq = cast<Sequence>(*stmt);
       const auto& seq = cast<Sequence>(*stmt);
       if (seq.Next()) {
       if (seq.Next()) {
-        return global_arena->RawNew<Sequence>(
+        return global_arena->New<Sequence>(
             stmt->SourceLoc(), seq.Stmt(),
             stmt->SourceLoc(), seq.Stmt(),
             CheckOrEnsureReturn(seq.Next(), omitted_ret_type,
             CheckOrEnsureReturn(seq.Next(), omitted_ret_type,
                                 stmt->SourceLoc()));
                                 stmt->SourceLoc()));
@@ -877,8 +889,8 @@ static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
     case Statement::Kind::Continue:
     case Statement::Kind::Continue:
     case Statement::Kind::VariableDefinition:
     case Statement::Kind::VariableDefinition:
       if (omitted_ret_type) {
       if (omitted_ret_type) {
-        return global_arena->RawNew<Sequence>(
-            stmt->SourceLoc(), stmt, global_arena->RawNew<Return>(loc));
+        return global_arena->New<Sequence>(stmt->SourceLoc(), stmt,
+                                           global_arena->New<Return>(loc));
       } else {
       } else {
         FATAL_COMPILATION_ERROR(stmt->SourceLoc())
         FATAL_COMPILATION_ERROR(stmt->SourceLoc())
             << "control-flow reaches end of function that provides a `->` "
             << "control-flow reaches end of function that provides a `->` "
@@ -909,9 +921,13 @@ static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
                global_arena->RawNew<IntType>(), return_type);
                global_arena->RawNew<IntType>(), return_type);
     // TODO: Check that main doesn't have any parameters.
     // TODO: Check that main doesn't have any parameters.
   }
   }
-  auto res = TypeCheckStmt(f->body, param_res.types, values, return_type,
-                           f->is_omitted_return_type);
-  auto body = CheckOrEnsureReturn(res.stmt, f->is_omitted_return_type,
+  std::optional<Ptr<const Statement>> body_stmt;
+  if (f->body) {
+    auto res = TypeCheckStmt(*f->body, param_res.types, values, return_type,
+                             f->is_omitted_return_type);
+    body_stmt = res.stmt;
+  }
+  auto body = CheckOrEnsureReturn(body_stmt, f->is_omitted_return_type,
                                   f->source_location);
                                   f->source_location);
   return global_arena->New<FunctionDefinition>(
   return global_arena->New<FunctionDefinition>(
       f->source_location, f->name, f->deduced_parameters, f->param_pattern,
       f->source_location, f->name, f->deduced_parameters, f->param_pattern,

+ 3 - 3
executable_semantics/interpreter/typecheck.h

@@ -34,9 +34,9 @@ struct TCPattern {
 };
 };
 
 
 struct TCStatement {
 struct TCStatement {
-  TCStatement(const Statement* s, TypeEnv types) : stmt(s), types(types) {}
+  TCStatement(Ptr<const Statement> s, TypeEnv types) : stmt(s), types(types) {}
 
 
-  const Statement* stmt;
+  Ptr<const Statement> stmt;
   TypeEnv types;
   TypeEnv types;
 };
 };
 
 
@@ -52,7 +52,7 @@ auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
 auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
 auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
                       const Value* expected) -> TCPattern;
                       const Value* expected) -> TCPattern;
 
 
-auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
+auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
                    const Value*& ret_type, bool is_omitted_ret_type)
                    const Value*& ret_type, bool is_omitted_ret_type)
     -> TCStatement;
     -> TCStatement;
 
 

+ 8 - 2
executable_semantics/interpreter/value.cpp

@@ -399,8 +399,14 @@ auto ValueEqual(const Value* v1, const Value* v2, SourceLocation loc) -> bool {
       return cast<BoolValue>(*v1).Val() == cast<BoolValue>(*v2).Val();
       return cast<BoolValue>(*v1).Val() == cast<BoolValue>(*v2).Val();
     case Value::Kind::PointerValue:
     case Value::Kind::PointerValue:
       return cast<PointerValue>(*v1).Val() == cast<PointerValue>(*v2).Val();
       return cast<PointerValue>(*v1).Val() == cast<PointerValue>(*v2).Val();
-    case Value::Kind::FunctionValue:
-      return cast<FunctionValue>(*v1).Body() == cast<FunctionValue>(*v2).Body();
+    case Value::Kind::FunctionValue: {
+      std::optional<Ptr<const Statement>> body1 =
+          cast<FunctionValue>(*v1).Body();
+      std::optional<Ptr<const Statement>> body2 =
+          cast<FunctionValue>(*v2).Body();
+      return body1.has_value() == body2.has_value() &&
+             (!body1.has_value() || *body1 == *body2);
+    }
     case Value::Kind::TupleValue:
     case Value::Kind::TupleValue:
       return FieldsValueEqual(cast<TupleValue>(*v1).Elements(),
       return FieldsValueEqual(cast<TupleValue>(*v1).Elements(),
                               cast<TupleValue>(*v2).Elements(), loc);
                               cast<TupleValue>(*v2).Elements(), loc);

+ 4 - 3
executable_semantics/interpreter/value.h

@@ -121,7 +121,8 @@ class IntValue : public Value {
 // A function value.
 // A function value.
 class FunctionValue : public Value {
 class FunctionValue : public Value {
  public:
  public:
-  FunctionValue(std::string name, const Value* param, const Statement* body)
+  FunctionValue(std::string name, const Value* param,
+                std::optional<Ptr<const Statement>> body)
       : Value(Kind::FunctionValue),
       : Value(Kind::FunctionValue),
         name(std::move(name)),
         name(std::move(name)),
         param(param),
         param(param),
@@ -133,12 +134,12 @@ class FunctionValue : public Value {
 
 
   auto Name() const -> const std::string& { return name; }
   auto Name() const -> const std::string& { return name; }
   auto Param() const -> const Value* { return param; }
   auto Param() const -> const Value* { return param; }
-  auto Body() const -> const Statement* { return body; }
+  auto Body() const -> std::optional<Ptr<const Statement>> { return body; }
 
 
  private:
  private:
   std::string name;
   std::string name;
   const Value* param;
   const Value* param;
-  const Statement* body;
+  std::optional<Ptr<const Statement>> body;
 };
 };
 
 
 // A pointer value.
 // A pointer value.

+ 28 - 28
executable_semantics/syntax/parser.ypp

@@ -101,12 +101,12 @@ void Carbon::Parser::error(const location_type&, const std::string& message) {
 %type <BisonWrap<Ptr<const FunctionDefinition>>> function_declaration
 %type <BisonWrap<Ptr<const FunctionDefinition>>> function_declaration
 %type <BisonWrap<Ptr<const FunctionDefinition>>> function_definition
 %type <BisonWrap<Ptr<const FunctionDefinition>>> function_definition
 %type <std::list<Ptr<const Declaration>>> declaration_list
 %type <std::list<Ptr<const Declaration>>> declaration_list
-%type <const Statement*> statement
-%type <const Statement*> if_statement
-%type <const Statement*> optional_else
+%type <BisonWrap<Ptr<const Statement>>> statement
+%type <BisonWrap<Ptr<const Statement>>> if_statement
+%type <std::optional<Ptr<const Statement>>> optional_else
 %type <BisonWrap<std::pair<Ptr<const Expression>, bool>>> return_expression
 %type <BisonWrap<std::pair<Ptr<const Expression>, bool>>> return_expression
-%type <const Statement*> block
-%type <const Statement*> statement_list
+%type <BisonWrap<Ptr<const Statement>>> block
+%type <std::optional<Ptr<const Statement>>> statement_list
 %type <BisonWrap<Ptr<const Expression>>> expression
 %type <BisonWrap<Ptr<const Expression>>> expression
 %type <BisonWrap<GenericBinding>> generic_binding
 %type <BisonWrap<GenericBinding>> generic_binding
 %type <std::vector<GenericBinding>> deduced_params
 %type <std::vector<GenericBinding>> deduced_params
@@ -131,8 +131,8 @@ void Carbon::Parser::error(const location_type&, const std::string& message) {
 %type <ParenContents<Pattern>> paren_pattern_contents
 %type <ParenContents<Pattern>> paren_pattern_contents
 %type <BisonWrap<std::pair<std::string, Ptr<const Expression>>>> alternative
 %type <BisonWrap<std::pair<std::string, Ptr<const Expression>>>> alternative
 %type <std::list<std::pair<std::string, Ptr<const Expression>>>> alternative_list
 %type <std::list<std::pair<std::string, Ptr<const Expression>>>> alternative_list
-%type <std::pair<Ptr<const Pattern>, const Statement*>*> clause
-%type <std::list<std::pair<Ptr<const Pattern>, const Statement*>>*> clause_list
+%type <std::pair<Ptr<const Pattern>, Ptr<const Statement>>*> clause
+%type <std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>*> clause_list
 %token END_OF_FILE 0
 %token END_OF_FILE 0
 %token AND
 %token AND
 %token OR
 %token OR
@@ -408,61 +408,61 @@ maybe_empty_tuple_pattern:
 ;
 ;
 clause:
 clause:
   CASE pattern DBLARROW statement
   CASE pattern DBLARROW statement
-    { $$ = global_arena->RawNew<std::pair<Ptr<const Pattern>, const Statement*>>($2, $4); }
+    { $$ = global_arena->RawNew<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>($2, $4); }
 | DEFAULT DBLARROW statement
 | DEFAULT DBLARROW statement
     {
     {
       auto vp = global_arena->New<BindingPattern>(
       auto vp = global_arena->New<BindingPattern>(
           context.SourceLoc(), std::nullopt, global_arena->New<AutoPattern>(context.SourceLoc()));
           context.SourceLoc(), std::nullopt, global_arena->New<AutoPattern>(context.SourceLoc()));
-      $$ = global_arena->RawNew<std::pair<Ptr<const Pattern>, const Statement*>>(vp, $3);
+      $$ = global_arena->RawNew<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>(vp, $3);
     }
     }
 ;
 ;
 clause_list:
 clause_list:
   // Empty
   // Empty
     {
     {
       $$ = global_arena->RawNew<std::list<
       $$ = global_arena->RawNew<std::list<
-          std::pair<Ptr<const Pattern>, const Statement*>>>();
+          std::pair<Ptr<const Pattern>, Ptr<const Statement>>>>();
     }
     }
 | clause clause_list
 | clause clause_list
     { $$ = $2; $$->push_front(*$1); }
     { $$ = $2; $$->push_front(*$1); }
 ;
 ;
 statement:
 statement:
   expression "=" expression ";"
   expression "=" expression ";"
-    { $$ = global_arena->RawNew<Assign>(context.SourceLoc(), $1, $3); }
+    { $$ = global_arena->New<Assign>(context.SourceLoc(), $1, $3); }
 | VAR pattern "=" expression ";"
 | VAR pattern "=" expression ";"
-    { $$ = global_arena->RawNew<VariableDefinition>(context.SourceLoc(), $2, $4); }
+    { $$ = global_arena->New<VariableDefinition>(context.SourceLoc(), $2, $4); }
 | expression ";"
 | expression ";"
-    { $$ = global_arena->RawNew<ExpressionStatement>(context.SourceLoc(), $1); }
+    { $$ = global_arena->New<ExpressionStatement>(context.SourceLoc(), $1); }
 | if_statement
 | if_statement
     { $$ = $1; }
     { $$ = $1; }
 | WHILE "(" expression ")" block
 | WHILE "(" expression ")" block
-    { $$ = global_arena->RawNew<While>(context.SourceLoc(), $3, $5); }
+    { $$ = global_arena->New<While>(context.SourceLoc(), $3, $5); }
 | BREAK ";"
 | BREAK ";"
-    { $$ = global_arena->RawNew<Break>(context.SourceLoc()); }
+    { $$ = global_arena->New<Break>(context.SourceLoc()); }
 | CONTINUE ";"
 | CONTINUE ";"
-    { $$ = global_arena->RawNew<Continue>(context.SourceLoc()); }
+    { $$ = global_arena->New<Continue>(context.SourceLoc()); }
 | RETURN return_expression ";"
 | RETURN return_expression ";"
     {
     {
       auto [return_exp, is_omitted_exp] = $2.Release();
       auto [return_exp, is_omitted_exp] = $2.Release();
-      $$ = global_arena->RawNew<Return>(context.SourceLoc(), return_exp, is_omitted_exp);
+      $$ = global_arena->New<Return>(context.SourceLoc(), return_exp, is_omitted_exp);
     }
     }
 | block
 | block
     { $$ = $1; }
     { $$ = $1; }
 | MATCH "(" expression ")" "{" clause_list "}"
 | MATCH "(" expression ")" "{" clause_list "}"
-    { $$ = global_arena->RawNew<Match>(context.SourceLoc(), $3, $6); }
+    { $$ = global_arena->New<Match>(context.SourceLoc(), $3, $6); }
 | CONTINUATION identifier statement
 | CONTINUATION identifier statement
-    { $$ = global_arena->RawNew<Continuation>(context.SourceLoc(), $2, $3); }
+    { $$ = global_arena->New<Continuation>(context.SourceLoc(), $2, $3); }
 | RUN expression ";"
 | RUN expression ";"
-    { $$ = global_arena->RawNew<Run>(context.SourceLoc(), $2); }
+    { $$ = global_arena->New<Run>(context.SourceLoc(), $2); }
 | AWAIT ";"
 | AWAIT ";"
-    { $$ = global_arena->RawNew<Await>(context.SourceLoc()); }
+    { $$ = global_arena->New<Await>(context.SourceLoc()); }
 ;
 ;
 if_statement:
 if_statement:
   IF "(" expression ")" block optional_else
   IF "(" expression ")" block optional_else
-    { $$ = global_arena->RawNew<If>(context.SourceLoc(), $3, $5, $6); }
+    { $$ = global_arena->New<If>(context.SourceLoc(), $3, $5, $6); }
 ;
 ;
 optional_else:
 optional_else:
   // Empty
   // Empty
-    { $$ = 0; }
+    { $$ = std::nullopt; }
 | ELSE if_statement
 | ELSE if_statement
     { $$ = $2; }
     { $$ = $2; }
 | ELSE block
 | ELSE block
@@ -476,13 +476,13 @@ return_expression:
 ;
 ;
 statement_list:
 statement_list:
   // Empty
   // Empty
-    { $$ = 0; }
+    { $$ = std::nullopt; }
 | statement statement_list
 | statement statement_list
-    { $$ = global_arena->RawNew<Sequence>(context.SourceLoc(), $1, $2); }
+    { $$ = global_arena->New<Sequence>(context.SourceLoc(), $1, $2); }
 ;
 ;
 block:
 block:
   "{" statement_list "}"
   "{" statement_list "}"
-    { $$ = global_arena->RawNew<Block>(context.SourceLoc(), $2); }
+    { $$ = global_arena->New<Block>(context.SourceLoc(), $2); }
 ;
 ;
 return_type:
 return_type:
   // Empty
   // Empty
@@ -532,7 +532,7 @@ function_definition:
       $$ = global_arena->New<FunctionDefinition>(
       $$ = global_arena->New<FunctionDefinition>(
           context.SourceLoc(), $2, $3, $4,
           context.SourceLoc(), $2, $3, $4,
           global_arena->New<AutoPattern>(context.SourceLoc()), true,
           global_arena->New<AutoPattern>(context.SourceLoc()), true,
-          global_arena->RawNew<Return>(context.SourceLoc(), $6, true));
+          global_arena->New<Return>(context.SourceLoc(), $6, true));
     }
     }
 ;
 ;
 function_declaration:
 function_declaration:
@@ -542,7 +542,7 @@ function_declaration:
       $$ = global_arena->New<FunctionDefinition>(
       $$ = global_arena->New<FunctionDefinition>(
           context.SourceLoc(), $2, $3, $4,
           context.SourceLoc(), $2, $3, $4,
           global_arena->New<ExpressionPattern>(return_exp),
           global_arena->New<ExpressionPattern>(return_exp),
-          is_omitted_exp, nullptr);
+          is_omitted_exp, std::nullopt);
     }
     }
 ;
 ;
 variable_declaration: identifier ":" pattern
 variable_declaration: identifier ":" pattern

+ 5 - 5
executable_semantics/syntax/syntax_helpers.cpp

@@ -22,11 +22,11 @@ static void AddIntrinsics(std::list<Ptr<const Declaration>>* fs) {
                loc, "format_str",
                loc, "format_str",
                global_arena->New<ExpressionPattern>(
                global_arena->New<ExpressionPattern>(
                    global_arena->New<StringTypeLiteral>(loc))))};
                    global_arena->New<StringTypeLiteral>(loc))))};
-  auto* print_return = global_arena->RawNew<Return>(
-      loc,
-      global_arena->New<IntrinsicExpression>(
-          IntrinsicExpression::IntrinsicKind::Print),
-      false);
+  auto print_return =
+      global_arena->New<Return>(loc,
+                                global_arena->New<IntrinsicExpression>(
+                                    IntrinsicExpression::IntrinsicKind::Print),
+                                false);
   auto print = global_arena->New<FunctionDeclaration>(
   auto print = global_arena->New<FunctionDeclaration>(
       global_arena->New<FunctionDefinition>(
       global_arena->New<FunctionDefinition>(
           loc, "Print", std::vector<GenericBinding>(),
           loc, "Print", std::vector<GenericBinding>(),