Просмотр исходного кода

Unify function declarations and definitions. (#896)

Geoff Romer 4 лет назад
Родитель
Сommit
a3eac75a5b

+ 1 - 13
executable_semantics/ast/BUILD

@@ -32,10 +32,10 @@ cc_library(
     ],
     deps = [
         ":class_definition",
-        ":function_definition",
         ":member",
         ":pattern",
         ":source_location",
+        ":statement",
         "//common:ostream",
         "//executable_semantics/common:nonnull",
         "@llvm-project//llvm:Support",
@@ -67,18 +67,6 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "function_definition",
-    srcs = ["function_definition.cpp"],
-    hdrs = ["function_definition.h"],
-    deps = [
-        ":expression",
-        ":source_location",
-        ":statement",
-        "@llvm-project//llvm:Support",
-    ],
-)
-
 cc_library(
     name = "member",
     srcs = ["member.cpp"],

+ 29 - 1
executable_semantics/ast/declaration.cpp

@@ -13,7 +13,7 @@ using llvm::cast;
 void Declaration::Print(llvm::raw_ostream& out) const {
   switch (kind()) {
     case Kind::FunctionDeclaration:
-      out << cast<FunctionDeclaration>(*this).definition();
+      cast<FunctionDeclaration>(*this).PrintDepth(-1, out);
       break;
 
     case Kind::ClassDeclaration: {
@@ -45,4 +45,32 @@ void Declaration::Print(llvm::raw_ostream& out) const {
   }
 }
 
+void FunctionDeclaration::PrintDepth(int depth, llvm::raw_ostream& out) const {
+  out << "fn " << name_ << " ";
+  if (!deduced_parameters_.empty()) {
+    out << "[";
+    unsigned int i = 0;
+    for (const auto& deduced : deduced_parameters_) {
+      if (i != 0) {
+        out << ", ";
+      }
+      out << deduced.name << ":! ";
+      deduced.type->Print(out);
+      ++i;
+    }
+    out << "]";
+  }
+  out << *param_pattern_;
+  if (!is_omitted_return_type_) {
+    out << " -> " << *return_type_;
+  }
+  if (body_) {
+    out << " {\n";
+    (*body_)->PrintDepth(depth, out);
+    out << "\n}\n";
+  } else {
+    out << ";\n";
+  }
+}
+
 }  // namespace Carbon

+ 58 - 7
executable_semantics/ast/declaration.h

@@ -10,10 +10,10 @@
 
 #include "common/ostream.h"
 #include "executable_semantics/ast/class_definition.h"
-#include "executable_semantics/ast/function_definition.h"
 #include "executable_semantics/ast/member.h"
 #include "executable_semantics/ast/pattern.h"
 #include "executable_semantics/ast/source_location.h"
+#include "executable_semantics/ast/statement.h"
 #include "executable_semantics/common/nonnull.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Compiler.h"
@@ -61,21 +61,72 @@ class Declaration {
   SourceLocation source_loc_;
 };
 
+// TODO: expand the kinds of things that can be deduced parameters.
+//   For now, only generic parameters are supported.
+struct GenericBinding {
+  std::string name;
+  Nonnull<const Expression*> type;
+};
+
 class FunctionDeclaration : public Declaration {
  public:
-  FunctionDeclaration(Nonnull<FunctionDefinition*> definition)
-      : Declaration(Kind::FunctionDeclaration, definition->source_loc()),
-        definition_(definition) {}
+  FunctionDeclaration(SourceLocation source_loc, std::string name,
+                      std::vector<GenericBinding> deduced_params,
+                      Nonnull<TuplePattern*> param_pattern,
+                      Nonnull<Pattern*> return_type,
+                      bool is_omitted_return_type,
+                      std::optional<Nonnull<Statement*>> body)
+      : Declaration(Kind::FunctionDeclaration, source_loc),
+        name_(std::move(name)),
+        deduced_parameters_(std::move(deduced_params)),
+        param_pattern_(param_pattern),
+        return_type_(return_type),
+        is_omitted_return_type_(is_omitted_return_type),
+        body_(body) {}
 
   static auto classof(const Declaration* decl) -> bool {
     return decl->kind() == Kind::FunctionDeclaration;
   }
 
-  auto definition() const -> const FunctionDefinition& { return *definition_; }
-  auto definition() -> FunctionDefinition& { return *definition_; }
+  void PrintDepth(int depth, llvm::raw_ostream& out) const;
+
+  auto name() const -> const std::string& { return name_; }
+  auto deduced_parameters() const -> llvm::ArrayRef<GenericBinding> {
+    return deduced_parameters_;
+  }
+  auto param_pattern() const -> const TuplePattern& { return *param_pattern_; }
+  auto param_pattern() -> TuplePattern& { return *param_pattern_; }
+  auto return_type() const -> const Pattern& { return *return_type_; }
+  auto return_type() -> Pattern& { return *return_type_; }
+  auto is_omitted_return_type() const -> bool {
+    return is_omitted_return_type_;
+  }
+  auto body() const -> std::optional<Nonnull<const Statement*>> {
+    return body_;
+  }
+  auto body() -> std::optional<Nonnull<Statement*>> { return body_; }
+
+  // The static type of this function. Cannot be called before typechecking.
+  auto static_type() const -> const Value& { return **static_type_; }
+
+  // Sets the static type of this expression. Can only be called once, during
+  // typechecking.
+  void set_static_type(Nonnull<const Value*> type) { static_type_ = type; }
+
+  // Returns whether the static type has been set. Should only be called
+  // during typechecking: before typechecking it's guaranteed to be false,
+  // and after typechecking it's guaranteed to be true.
+  auto has_static_type() const -> bool { return static_type_.has_value(); }
 
  private:
-  Nonnull<FunctionDefinition*> definition_;
+  std::string name_;
+  std::vector<GenericBinding> deduced_parameters_;
+  Nonnull<TuplePattern*> param_pattern_;
+  Nonnull<Pattern*> return_type_;
+  bool is_omitted_return_type_;
+  std::optional<Nonnull<Statement*>> body_;
+
+  std::optional<Nonnull<const Value*>> static_type_;
 };
 
 class ClassDeclaration : public Declaration {

+ 0 - 37
executable_semantics/ast/function_definition.cpp

@@ -1,37 +0,0 @@
-// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
-// Exceptions. See /LICENSE for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "executable_semantics/ast/function_definition.h"
-
-namespace Carbon {
-
-void FunctionDefinition::PrintDepth(int depth, llvm::raw_ostream& out) const {
-  out << "fn " << name_ << " ";
-  if (!deduced_parameters_.empty()) {
-    out << "[";
-    unsigned int i = 0;
-    for (const auto& deduced : deduced_parameters_) {
-      if (i != 0) {
-        out << ", ";
-      }
-      out << deduced.name << ":! ";
-      deduced.type->Print(out);
-      ++i;
-    }
-    out << "]";
-  }
-  out << *param_pattern_;
-  if (!is_omitted_return_type_) {
-    out << " -> " << *return_type_;
-  }
-  if (body_) {
-    out << " {\n";
-    (*body_)->PrintDepth(depth, out);
-    out << "\n}\n";
-  } else {
-    out << ";\n";
-  }
-}
-
-}  // namespace Carbon

+ 0 - 89
executable_semantics/ast/function_definition.h

@@ -1,89 +0,0 @@
-// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
-// Exceptions. See /LICENSE for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef EXECUTABLE_SEMANTICS_AST_FUNCTION_DEFINITION_H_
-#define EXECUTABLE_SEMANTICS_AST_FUNCTION_DEFINITION_H_
-
-#include "common/ostream.h"
-#include "executable_semantics/ast/expression.h"
-#include "executable_semantics/ast/pattern.h"
-#include "executable_semantics/ast/source_location.h"
-#include "executable_semantics/ast/statement.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/Compiler.h"
-
-namespace Carbon {
-
-class Value;
-
-// TODO: expand the kinds of things that can be deduced parameters.
-//   For now, only generic parameters are supported.
-struct GenericBinding {
-  std::string name;
-  Nonnull<const Expression*> type;
-};
-
-class FunctionDefinition {
- public:
-  FunctionDefinition(SourceLocation source_loc, std::string name,
-                     std::vector<GenericBinding> deduced_params,
-                     Nonnull<TuplePattern*> param_pattern,
-                     Nonnull<Pattern*> return_type, bool is_omitted_return_type,
-                     std::optional<Nonnull<Statement*>> body)
-      : source_loc_(source_loc),
-        name_(std::move(name)),
-        deduced_parameters_(std::move(deduced_params)),
-        param_pattern_(param_pattern),
-        return_type_(return_type),
-        is_omitted_return_type_(is_omitted_return_type),
-        body_(body) {}
-
-  void Print(llvm::raw_ostream& out) const { PrintDepth(-1, out); }
-  void PrintDepth(int depth, llvm::raw_ostream& out) const;
-  LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
-
-  auto source_loc() const -> SourceLocation { return source_loc_; }
-  auto name() const -> const std::string& { return name_; }
-  auto deduced_parameters() const -> llvm::ArrayRef<GenericBinding> {
-    return deduced_parameters_;
-  }
-  auto param_pattern() const -> const TuplePattern& { return *param_pattern_; }
-  auto param_pattern() -> TuplePattern& { return *param_pattern_; }
-  auto return_type() const -> const Pattern& { return *return_type_; }
-  auto return_type() -> Pattern& { return *return_type_; }
-  auto is_omitted_return_type() const -> bool {
-    return is_omitted_return_type_;
-  }
-  auto body() const -> std::optional<Nonnull<const Statement*>> {
-    return body_;
-  }
-  auto body() -> std::optional<Nonnull<Statement*>> { return body_; }
-
-  // The static type of this function. Cannot be called before typechecking.
-  auto static_type() const -> const Value& { return **static_type_; }
-
-  // Sets the static type of this expression. Can only be called once, during
-  // typechecking.
-  void set_static_type(Nonnull<const Value*> type) { static_type_ = type; }
-
-  // Returns whether the static type has been set. Should only be called
-  // during typechecking: before typechecking it's guaranteed to be false,
-  // and after typechecking it's guaranteed to be true.
-  auto has_static_type() const -> bool { return static_type_.has_value(); }
-
- private:
-  SourceLocation source_loc_;
-  std::string name_;
-  std::vector<GenericBinding> deduced_parameters_;
-  Nonnull<TuplePattern*> param_pattern_;
-  Nonnull<Pattern*> return_type_;
-  bool is_omitted_return_type_;
-  std::optional<Nonnull<Statement*>> body_;
-
-  std::optional<Nonnull<const Value*>> static_type_;
-};
-
-}  // namespace Carbon
-
-#endif  // EXECUTABLE_SEMANTICS_AST_FUNCTION_DEFINITION_H_

+ 2 - 3
executable_semantics/interpreter/BUILD

@@ -24,8 +24,8 @@ cc_library(
         ":field_path",
         ":stack",
         "//common:ostream",
+        "//executable_semantics/ast:declaration",
         "//executable_semantics/ast:expression",
-        "//executable_semantics/ast:function_definition",
         "//executable_semantics/ast:statement",
         "//executable_semantics/common:arena",
         "//executable_semantics/common:error",
@@ -97,7 +97,6 @@ cc_library(
         "//common:ostream",
         "//executable_semantics/ast:declaration",
         "//executable_semantics/ast:expression",
-        "//executable_semantics/ast:function_definition",
         "//executable_semantics/common:arena",
         "//executable_semantics/common:tracing_flag",
         "@llvm-project//llvm:Support",
@@ -118,8 +117,8 @@ cc_library(
         ":dictionary",
         ":interpreter",
         "//common:ostream",
+        "//executable_semantics/ast:declaration",
         "//executable_semantics/ast:expression",
-        "//executable_semantics/ast:function_definition",
         "//executable_semantics/ast:statement",
         "//executable_semantics/common:arena",
         "//executable_semantics/common:tracing_flag",

+ 1 - 1
executable_semantics/interpreter/action.cpp

@@ -10,8 +10,8 @@
 #include <utility>
 #include <vector>
 
+#include "executable_semantics/ast/declaration.h"
 #include "executable_semantics/ast/expression.h"
-#include "executable_semantics/ast/function_definition.h"
 #include "executable_semantics/common/arena.h"
 #include "executable_semantics/interpreter/stack.h"
 #include "llvm/ADT/StringExtras.h"

+ 2 - 2
executable_semantics/interpreter/exec_program.cpp

@@ -26,11 +26,11 @@ static void AddIntrinsics(Nonnull<Arena*> arena,
       source_loc,
       arena->New<IntrinsicExpression>(IntrinsicExpression::Intrinsic::Print),
       false);
-  auto print = arena->New<FunctionDeclaration>(arena->New<FunctionDefinition>(
+  auto print = arena->New<FunctionDeclaration>(
       source_loc, "Print", std::vector<GenericBinding>(),
       arena->New<TuplePattern>(source_loc, print_params),
       arena->New<ExpressionPattern>(arena->New<TupleLiteral>(source_loc)),
-      /*is_omitted_return_type=*/false, print_return));
+      /*is_omitted_return_type=*/false, print_return);
   declarations->insert(declarations->begin(), print);
 }
 

+ 2 - 3
executable_semantics/interpreter/interpreter.cpp

@@ -12,8 +12,8 @@
 #include <vector>
 
 #include "common/check.h"
+#include "executable_semantics/ast/declaration.h"
 #include "executable_semantics/ast/expression.h"
-#include "executable_semantics/ast/function_definition.h"
 #include "executable_semantics/common/arena.h"
 #include "executable_semantics/common/error.h"
 #include "executable_semantics/common/tracing_flag.h"
@@ -109,8 +109,7 @@ auto Interpreter::EvalPrim(Operator op,
 void Interpreter::InitEnv(const Declaration& d, Env* env) {
   switch (d.kind()) {
     case Declaration::Kind::FunctionDeclaration: {
-      const FunctionDefinition& func_def =
-          cast<FunctionDeclaration>(d).definition();
+      const auto& func_def = cast<FunctionDeclaration>(d);
       Env new_env = *env;
       // Bring the deduced parameters into scope.
       for (const auto& deduced : func_def.deduced_parameters()) {

+ 7 - 8
executable_semantics/interpreter/type_checker.cpp

@@ -11,7 +11,7 @@
 #include <vector>
 
 #include "common/ostream.h"
-#include "executable_semantics/ast/function_definition.h"
+#include "executable_semantics/ast/declaration.h"
 #include "executable_semantics/common/arena.h"
 #include "executable_semantics/common/error.h"
 #include "executable_semantics/common/tracing_flag.h"
@@ -50,7 +50,7 @@ static void SetStaticType(Nonnull<Pattern*> pattern,
 
 // Sets the static type of `definition`. Can be called multiple times on
 // the same node, so long as the types are the same on each call.
-static void SetStaticType(Nonnull<FunctionDefinition*> definition,
+static void SetStaticType(Nonnull<FunctionDeclaration*> definition,
                           Nonnull<const Value*> type) {
   if (definition->has_static_type()) {
     CHECK(TypeEqual(&definition->static_type(), type));
@@ -1027,7 +1027,7 @@ void TypeChecker::ExpectReturnOnAllPaths(
 // a function.
 // TODO: Add checking to function definitions to ensure that
 //   all deduced type parameters will be deduced.
-auto TypeChecker::TypeCheckFunDef(FunctionDefinition* f, TypeEnv types,
+auto TypeChecker::TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types,
                                   Env values) -> TCResult {
   // Bring the deduced parameters into scope
   for (const auto& deduced : f->deduced_parameters()) {
@@ -1068,7 +1068,7 @@ auto TypeChecker::TypeCheckFunDef(FunctionDefinition* f, TypeEnv types,
 }
 
 auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
-                               FunctionDefinition* fun_def)
+                               FunctionDeclaration* fun_def)
     -> Nonnull<const Value*> {
   // Bring the deduced parameters into scope
   for (const auto& deduced : fun_def->deduced_parameters()) {
@@ -1120,7 +1120,7 @@ auto TypeChecker::TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
 static auto GetName(const Declaration& d) -> const std::string& {
   switch (d.kind()) {
     case Declaration::Kind::FunctionDeclaration:
-      return cast<FunctionDeclaration>(d).definition().name();
+      return cast<FunctionDeclaration>(d).name();
     case Declaration::Kind::ClassDeclaration:
       return cast<ClassDeclaration>(d).definition().name();
     case Declaration::Kind::ChoiceDeclaration:
@@ -1140,8 +1140,7 @@ void TypeChecker::TypeCheck(Nonnull<Declaration*> d, const TypeEnv& types,
                             const Env& values) {
   switch (d->kind()) {
     case Declaration::Kind::FunctionDeclaration:
-      TypeCheckFunDef(&cast<FunctionDeclaration>(*d).definition(), types,
-                      values);
+      TypeCheckFunDef(&cast<FunctionDeclaration>(*d), types, values);
       return;
     case Declaration::Kind::ClassDeclaration:
       // TODO
@@ -1176,7 +1175,7 @@ void TypeChecker::TypeCheck(Nonnull<Declaration*> d, const TypeEnv& types,
 void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
   switch (d->kind()) {
     case Declaration::Kind::FunctionDeclaration: {
-      FunctionDefinition& func_def = cast<FunctionDeclaration>(*d).definition();
+      FunctionDeclaration& func_def = cast<FunctionDeclaration>(*d);
       auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
       tops->types.Set(func_def.name(), t);
       interpreter.InitEnv(*d, &tops->values);

+ 2 - 2
executable_semantics/interpreter/type_checker.h

@@ -109,7 +109,7 @@ class TypeChecker {
                      Nonnull<ReturnTypeContext*> return_type_context)
       -> TCResult;
 
-  auto TypeCheckFunDef(FunctionDefinition* f, TypeEnv types, Env values)
+  auto TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types, Env values)
       -> TCResult;
 
   auto TypeCheckCase(Nonnull<const Value*> expected, Nonnull<Pattern*> pat,
@@ -117,7 +117,7 @@ class TypeChecker {
                      Nonnull<ReturnTypeContext*> return_type_context)
       -> Match::Clause;
 
-  auto TypeOfFunDef(TypeEnv types, Env values, FunctionDefinition* fun_def)
+  auto TypeOfFunDef(TypeEnv types, Env values, FunctionDeclaration* fun_def)
       -> Nonnull<const Value*>;
   auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/, Env ct_top)
       -> Nonnull<const Value*>;

+ 1 - 1
executable_semantics/interpreter/value.h

@@ -11,7 +11,7 @@
 #include <vector>
 
 #include "common/ostream.h"
-#include "executable_semantics/ast/function_definition.h"
+#include "executable_semantics/ast/declaration.h"
 #include "executable_semantics/ast/statement.h"
 #include "executable_semantics/common/nonnull.h"
 #include "executable_semantics/interpreter/address.h"

+ 8 - 14
executable_semantics/syntax/parser.ypp

@@ -68,7 +68,6 @@
   #include "executable_semantics/ast/ast.h"
   #include "executable_semantics/ast/declaration.h"
   #include "executable_semantics/ast/expression.h"
-  #include "executable_semantics/ast/function_definition.h"
   #include "executable_semantics/ast/paren_contents.h"
   #include "executable_semantics/ast/pattern.h"
   #include "executable_semantics/common/arena.h"
@@ -99,8 +98,7 @@
 %type <std::string> optional_library_path
 %type <bool> api_or_impl
 %type <Nonnull<Declaration*>> declaration
-%type <Nonnull<FunctionDefinition*>> function_declaration
-%type <Nonnull<FunctionDefinition*>> function_definition
+%type <Nonnull<FunctionDeclaration*>> function_declaration
 %type <std::vector<Nonnull<Declaration*>>> declaration_list
 %type <Nonnull<Statement*>> statement
 %type <Nonnull<Statement*>> if_statement
@@ -651,28 +649,26 @@ deduced_params:
 | LEFT_SQUARE_BRACKET deduced_param_list RIGHT_SQUARE_BRACKET
     { $$ = $2; }
 ;
-function_definition:
+function_declaration:
   FN identifier deduced_params maybe_empty_tuple_pattern return_type block
     {
       auto [return_exp, is_omitted_exp] = $5;
-      $$ = arena->New<FunctionDefinition>(
+      $$ = arena->New<FunctionDeclaration>(
           context.source_loc(), $2, $3, $4,
           arena->New<ExpressionPattern>(return_exp), is_omitted_exp, $6);
     }
 | FN identifier deduced_params maybe_empty_tuple_pattern ARROW AUTO block
     {
       // The return type is not considered "omitted" because it's `auto`.
-      $$ = arena->New<FunctionDefinition>(
+      $$ = arena->New<FunctionDeclaration>(
           context.source_loc(), $2, $3, $4,
           arena->New<AutoPattern>(context.source_loc()),
           /*is_omitted_exp=*/false, $7);
     }
-;
-function_declaration:
-  FN identifier deduced_params maybe_empty_tuple_pattern return_type SEMICOLON
+| FN identifier deduced_params maybe_empty_tuple_pattern return_type SEMICOLON
     {
       auto [return_exp, is_omitted_exp] = $5;
-      $$ = arena->New<FunctionDefinition>(
+      $$ = arena->New<FunctionDeclaration>(
           context.source_loc(), $2, $3, $4,
           arena->New<ExpressionPattern>(return_exp), is_omitted_exp,
           std::nullopt);
@@ -720,10 +716,8 @@ alternative_list_contents:
     }
 ;
 declaration:
-  function_definition
-    { $$ = arena->New<FunctionDeclaration>($1); }
-| function_declaration
-    { $$ = arena->New<FunctionDeclaration>($1); }
+  function_declaration
+    { $$ = $1; }
 | CLASS identifier LEFT_CURLY_BRACE member_list RIGHT_CURLY_BRACE
     { $$ = arena->New<ClassDeclaration>(context.source_loc(), $2, $4); }
 | CHOICE identifier LEFT_CURLY_BRACE alternative_list RIGHT_CURLY_BRACE