Browse Source

Support complex type patterns in bindings. (#759)

Co-authored-by: Jon Meow <46229924+jonmeow@users.noreply.github.com>
Geoff Romer 4 years ago
parent
commit
7138ee400f

+ 54 - 41
executable_semantics/interpreter/interpreter.cpp

@@ -29,8 +29,6 @@ namespace Carbon {
 
 State* state = nullptr;
 
-auto PatternMatch(const Value* pat, const Value* val, Env,
-                  std::list<std::string>*, int) -> std::optional<Env>;
 void Step();
 //
 // Auxiliary Functions
@@ -201,12 +199,17 @@ void CallFunction(int line_num, std::vector<const Value*> operas,
     case Value::Kind::FunctionValue: {
       const auto& fn = cast<FunctionValue>(*operas[0]);
       // Bind arguments to parameters
-      std::list<std::string> params;
       std::optional<Env> matches =
-          PatternMatch(fn.Param(), operas[1], globals, &params, line_num);
+          PatternMatch(fn.Param(), operas[1], line_num);
       CHECK(matches) << "internal error in call_function, pattern match failed";
       // Create the new frame and push it on the stack
-      auto* scope = global_arena->RawNew<Scope>(*matches, params);
+      Env values = globals;
+      std::list<std::string> params;
+      for (const auto& [name, value] : *matches) {
+        values.Set(name, value);
+        params.push_back(name);
+      }
+      auto* scope = global_arena->RawNew<Scope>(values, params);
       auto* frame = global_arena->RawNew<Frame>(
           fn.Name(), Stack(scope),
           Stack<Action*>(global_arena->RawNew<StatementAction>(fn.Body())));
@@ -266,20 +269,14 @@ void CreateTuple(Frame* frame, Action* act, const Expression* exp) {
   frame->todo.Push(global_arena->RawNew<ValAction>(tv));
 }
 
-// Returns an updated environment that includes the bindings of
-//    pattern variables to their matched values, if matching succeeds.
-//
-// The names of the pattern variables are added to the vars parameter.
-// Returns nullopt if the value doesn't match the pattern.
-auto PatternMatch(const Value* p, const Value* v, Env values,
-                  std::list<std::string>* vars, int line_num)
+auto PatternMatch(const Value* p, const Value* v, int line_num)
     -> std::optional<Env> {
   switch (p->Tag()) {
     case Value::Kind::BindingPlaceholderValue: {
       const auto& placeholder = cast<BindingPlaceholderValue>(*p);
+      Env values;
       if (placeholder.Name().has_value()) {
         Address a = state->heap.AllocateValue(CopyVal(v, line_num));
-        vars->push_back(*placeholder.Name());
         values.Set(*placeholder.Name(), a);
       }
       return values;
@@ -290,22 +287,26 @@ auto PatternMatch(const Value* p, const Value* v, Env values,
           const auto& p_tup = cast<TupleValue>(*p);
           const auto& v_tup = cast<TupleValue>(*v);
           if (p_tup.Elements().size() != v_tup.Elements().size()) {
-            FATAL_RUNTIME_ERROR(line_num)
+            FATAL_PROGRAM_ERROR(line_num)
                 << "arity mismatch in tuple pattern match:\n  pattern: "
                 << p_tup << "\n  value: " << v_tup;
           }
-          for (const TupleElement& pattern_element : p_tup.Elements()) {
-            const Value* value_field = v_tup.FindField(pattern_element.name);
-            if (value_field == nullptr) {
-              FATAL_RUNTIME_ERROR(line_num)
-                  << "field " << pattern_element.name << "not in " << *v;
+          Env values;
+          for (size_t i = 0; i < p_tup.Elements().size(); ++i) {
+            if (p_tup.Elements()[i].name != v_tup.Elements()[i].name) {
+              FATAL_PROGRAM_ERROR(line_num)
+                  << "Tuple field name '" << v_tup.Elements()[i].name
+                  << "' does not match pattern field name '"
+                  << p_tup.Elements()[i].name << "'";
             }
             std::optional<Env> matches = PatternMatch(
-                pattern_element.value, value_field, values, vars, line_num);
+                p_tup.Elements()[i].value, v_tup.Elements()[i].value, line_num);
             if (!matches) {
               return std::nullopt;
             }
-            values = *matches;
+            for (const auto& [name, value] : *matches) {
+              values.Set(name, value);
+            }
           }  // for
           return values;
         }
@@ -321,12 +322,7 @@ auto PatternMatch(const Value* p, const Value* v, Env values,
               p_alt.AltName() != v_alt.AltName()) {
             return std::nullopt;
           }
-          std::optional<Env> matches = PatternMatch(
-              p_alt.Argument(), v_alt.Argument(), values, vars, line_num);
-          if (!matches) {
-            return std::nullopt;
-          }
-          return *matches;
+          return PatternMatch(p_alt.Argument(), v_alt.Argument(), line_num);
         }
         default:
           FATAL() << "expected a choice alternative in pattern, not " << *v;
@@ -336,19 +332,32 @@ auto PatternMatch(const Value* p, const Value* v, Env values,
         case Value::Kind::FunctionType: {
           const auto& p_fn = cast<FunctionType>(*p);
           const auto& v_fn = cast<FunctionType>(*v);
-          std::optional<Env> matches =
-              PatternMatch(p_fn.Param(), v_fn.Param(), values, vars, line_num);
-          if (!matches) {
+          std::optional<Env> param_matches =
+              PatternMatch(p_fn.Param(), v_fn.Param(), line_num);
+          if (!param_matches) {
+            return std::nullopt;
+          }
+          std::optional<Env> ret_matches =
+              PatternMatch(p_fn.Ret(), v_fn.Ret(), line_num);
+          if (!ret_matches) {
             return std::nullopt;
           }
-          return PatternMatch(p_fn.Ret(), v_fn.Ret(), *matches, vars, line_num);
+          Env values = *param_matches;
+          for (const auto& [name, value] : *ret_matches) {
+            values.Set(name, value);
+          }
+          return values;
         }
         default:
           return std::nullopt;
       }
+    case Value::Kind::AutoType:
+      // `auto` matches any type, without binding any new names. We rely
+      // on the typechecker to ensure that `v` is a type.
+      return Env();
     default:
       if (ValueEqual(p, v, line_num)) {
-        return values;
+        return Env();
       } else {
         return std::nullopt;
       }
@@ -913,12 +922,15 @@ void StepStmt() {
         } else {  // try to match
           auto v = act->Results()[0];
           auto pat = act->Results()[clause_num + 1];
-          auto values = CurrentEnv(state);
-          std::list<std::string> vars;
-          std::optional<Env> matches =
-              PatternMatch(pat, v, values, &vars, stmt->LineNumber());
+          std::optional<Env> matches = PatternMatch(pat, v, stmt->LineNumber());
           if (matches) {  // we have a match, start the body
-            auto* new_scope = global_arena->RawNew<Scope>(*matches, vars);
+            Env values = CurrentEnv(state);
+            std::list<std::string> vars;
+            for (const auto& [name, value] : *matches) {
+              values.Set(name, value);
+              vars.push_back(name);
+            }
+            auto* new_scope = global_arena->RawNew<Scope>(values, vars);
             frame->scopes.Push(new_scope);
             const Statement* body_block =
                 global_arena->RawNew<Block>(stmt->LineNumber(), c->second);
@@ -1025,13 +1037,14 @@ void StepStmt() {
         const Value* v = act->Results()[0];
         const Value* p = act->Results()[1];
 
-        std::optional<Env> matches =
-            PatternMatch(p, v, frame->scopes.Top()->values,
-                         &frame->scopes.Top()->locals, stmt->LineNumber());
+        std::optional<Env> matches = PatternMatch(p, v, stmt->LineNumber());
         CHECK(matches)
             << stmt->LineNumber()
             << ": internal error in variable definition, match failed";
-        frame->scopes.Top()->values = *matches;
+        for (const auto& [name, value] : *matches) {
+          frame->scopes.Top()->values.Set(name, value);
+          frame->scopes.Top()->locals.push_back(name);
+        }
         frame->todo.Pop(1);
       }
       break;

+ 5 - 0
executable_semantics/interpreter/interpreter.h

@@ -35,6 +35,11 @@ void PrintEnv(Env values, llvm::raw_ostream& out);
 
 /***** Interpreters *****/
 
+// Attempts to match `v` against the pattern `p`. If matching succeeds, returns
+// the bindings of pattern variables to their matched values.
+auto PatternMatch(const Value* p, const Value* v, int line_num)
+    -> std::optional<Env>;
+
 auto InterpProgram(const std::list<const Declaration*>& fs) -> int;
 auto InterpExp(Env values, const Expression* e) -> const Value*;
 auto InterpPattern(Env values, const Pattern* p) -> const Value*;

+ 17 - 29
executable_semantics/interpreter/typecheck.cpp

@@ -552,32 +552,20 @@ auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
     }
     case Pattern::Kind::BindingPattern: {
       const auto& binding = cast<BindingPattern>(*p);
-      const Value* type;
-      switch (binding.Type()->Tag()) {
-        case Pattern::Kind::AutoPattern: {
-          if (expected == nullptr) {
-            FATAL_COMPILATION_ERROR(binding.LineNumber())
-                << "auto not allowed here";
-          } else {
-            type = expected;
-          }
-          break;
-        }
-        case Pattern::Kind::ExpressionPattern: {
-          type = InterpExp(
-              values, cast<ExpressionPattern>(binding.Type())->Expression());
-          CHECK(type->Tag() != Value::Kind::AutoType);
-          if (expected != nullptr) {
-            ExpectType(binding.LineNumber(), "pattern variable", type,
-                       expected);
-          }
-          break;
+      TCPattern binding_type_result =
+          TypeCheckPattern(binding.Type(), types, values, nullptr);
+      const Value* type = InterpPattern(values, binding_type_result.pattern);
+      if (expected != nullptr) {
+        std::optional<Env> values =
+            PatternMatch(type, expected, binding.Type()->LineNumber());
+        if (values == std::nullopt) {
+          FATAL_COMPILATION_ERROR(binding.Type()->LineNumber())
+              << "Type pattern '" << *type << "' does not match actual type '"
+              << *expected << "'";
         }
-        case Pattern::Kind::TuplePattern:
-        case Pattern::Kind::BindingPattern:
-        case Pattern::Kind::AlternativePattern:
-          FATAL_COMPILATION_ERROR(binding.LineNumber())
-              << "Unsupported type pattern";
+        CHECK(values->begin() == values->end())
+            << "Name bindings within type patterns are unsupported";
+        type = expected;
       }
       auto new_p = global_arena->RawNew<BindingPattern>(
           binding.LineNumber(), binding.Name(),
@@ -908,8 +896,8 @@ static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
   // Bring the deduced parameters into scope
   for (const auto& deduced : f->deduced_parameters) {
     // auto t = InterpExp(values, deduced.type);
-    Address a = state->heap.AllocateValue(
-        global_arena->RawNew<VariableType>(deduced.name));
+    types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
+    Address a = state->heap.AllocateValue(*types.Get(deduced.name));
     values.Set(deduced.name, a);
   }
   // Type check the parameter pattern
@@ -937,8 +925,8 @@ static auto TypeOfFunDef(TypeEnv types, Env values,
   // Bring the deduced parameters into scope
   for (const auto& deduced : fun_def->deduced_parameters) {
     // auto t = InterpExp(values, deduced.type);
-    Address a = state->heap.AllocateValue(
-        global_arena->RawNew<VariableType>(deduced.name));
+    types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
+    Address a = state->heap.AllocateValue(*types.Get(deduced.name));
     values.Set(deduced.name, a);
   }
   // Type check the parameter pattern

+ 1 - 0
executable_semantics/test_list.bzl

@@ -105,6 +105,7 @@ TEST_LIST = [
     "type_compute",
     "type_compute2",
     "type_compute3",
+    "type_match",
     "while1",
     "zero",
 ]

+ 1 - 3
executable_semantics/testdata/tuple5.golden

@@ -1,4 +1,2 @@
-COMPILATION ERROR: 8: type error in pattern variable
-expected: (x = i32, y = i32)
-actual: (y = i32, x = i32)
+PROGRAM ERROR: 8: Tuple field name 'y' does not match pattern field name 'x'
 EXIT CODE: 255

+ 8 - 0
executable_semantics/testdata/type_match.carbon

@@ -0,0 +1,8 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+fn main() -> i32 {
+  var t: (auto, (i32, i32)) = ((1,2),(3,4));
+  return t[0][0] + t[1][1] - 5;
+}

+ 1 - 0
executable_semantics/testdata/type_match.golden

@@ -0,0 +1 @@
+result: 0