typecheck.cpp 47 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "executable_semantics/interpreter/typecheck.h"
  5. #include <algorithm>
  6. #include <iterator>
  7. #include <map>
  8. #include <set>
  9. #include <vector>
  10. #include "common/ostream.h"
  11. #include "executable_semantics/ast/function_definition.h"
  12. #include "executable_semantics/common/arena.h"
  13. #include "executable_semantics/common/error.h"
  14. #include "executable_semantics/common/tracing_flag.h"
  15. #include "executable_semantics/interpreter/interpreter.h"
  16. #include "executable_semantics/interpreter/value.h"
  17. #include "llvm/ADT/StringExtras.h"
  18. #include "llvm/Support/Casting.h"
  19. using llvm::cast;
  20. using llvm::dyn_cast;
  21. namespace Carbon {
  22. void PrintTypeEnv(TypeEnv types, llvm::raw_ostream& out) {
  23. llvm::ListSeparator sep;
  24. for (const auto& [name, type] : types) {
  25. out << sep << name << ": " << *type;
  26. }
  27. }
  28. static void ExpectType(SourceLocation loc, const std::string& context,
  29. const Value* expected, const Value* actual) {
  30. if (!TypeEqual(expected, actual)) {
  31. FATAL_COMPILATION_ERROR(loc) << "type error in " << context << "\n"
  32. << "expected: " << *expected << "\n"
  33. << "actual: " << *actual;
  34. }
  35. }
  36. static void ExpectPointerType(SourceLocation loc, const std::string& context,
  37. const Value* actual) {
  38. if (actual->Tag() != Value::Kind::PointerType) {
  39. FATAL_COMPILATION_ERROR(loc) << "type error in " << context << "\n"
  40. << "expected a pointer type\n"
  41. << "actual: " << *actual;
  42. }
  43. }
  44. static SourceLocation ReifyFakeSourceLoc() {
  45. return SourceLocation("<reify>", 0);
  46. }
  47. // Reify type to type expression.
  48. static auto ReifyType(const Value* t, SourceLocation loc)
  49. -> Ptr<const Expression> {
  50. switch (t->Tag()) {
  51. case Value::Kind::IntType:
  52. return global_arena->New<IntTypeLiteral>(ReifyFakeSourceLoc());
  53. case Value::Kind::BoolType:
  54. return global_arena->New<BoolTypeLiteral>(ReifyFakeSourceLoc());
  55. case Value::Kind::TypeType:
  56. return global_arena->New<TypeTypeLiteral>(ReifyFakeSourceLoc());
  57. case Value::Kind::ContinuationType:
  58. return global_arena->New<ContinuationTypeLiteral>(ReifyFakeSourceLoc());
  59. case Value::Kind::FunctionType: {
  60. const auto& fn_type = cast<FunctionType>(*t);
  61. return global_arena->New<FunctionTypeLiteral>(
  62. ReifyFakeSourceLoc(), ReifyType(fn_type.Param(), loc),
  63. ReifyType(fn_type.Ret(), loc),
  64. /*is_omitted_return_type=*/false);
  65. }
  66. case Value::Kind::TupleValue: {
  67. std::vector<FieldInitializer> args;
  68. for (const TupleElement& field : cast<TupleValue>(*t).Elements()) {
  69. args.push_back(
  70. FieldInitializer(field.name, ReifyType(field.value, loc)));
  71. }
  72. return global_arena->New<TupleLiteral>(ReifyFakeSourceLoc(), args);
  73. }
  74. case Value::Kind::ClassType:
  75. return global_arena->New<IdentifierExpression>(
  76. ReifyFakeSourceLoc(), cast<ClassType>(*t).Name());
  77. case Value::Kind::ChoiceType:
  78. return global_arena->New<IdentifierExpression>(
  79. ReifyFakeSourceLoc(), cast<ChoiceType>(*t).Name());
  80. case Value::Kind::PointerType:
  81. return global_arena->New<PrimitiveOperatorExpression>(
  82. ReifyFakeSourceLoc(), Operator::Ptr,
  83. std::vector<Ptr<const Expression>>(
  84. {ReifyType(cast<PointerType>(*t).Type(), loc)}));
  85. case Value::Kind::VariableType:
  86. return global_arena->New<IdentifierExpression>(
  87. ReifyFakeSourceLoc(), cast<VariableType>(*t).Name());
  88. case Value::Kind::StringType:
  89. return global_arena->New<StringTypeLiteral>(ReifyFakeSourceLoc());
  90. case Value::Kind::AlternativeConstructorValue:
  91. case Value::Kind::AlternativeValue:
  92. case Value::Kind::AutoType:
  93. case Value::Kind::BindingPlaceholderValue:
  94. case Value::Kind::BoolValue:
  95. case Value::Kind::ContinuationValue:
  96. case Value::Kind::FunctionValue:
  97. case Value::Kind::IntValue:
  98. case Value::Kind::PointerValue:
  99. case Value::Kind::StringValue:
  100. case Value::Kind::StructValue:
  101. FATAL() << "expected a type, not " << *t;
  102. }
  103. }
  104. // Perform type argument deduction, matching the parameter type `param`
  105. // against the argument type `arg`. Whenever there is an VariableType
  106. // in the parameter type, it is deduced to be the corresponding type
  107. // inside the argument type.
  108. // The `deduced` parameter is an accumulator, that is, it holds the
  109. // results so-far.
  110. static auto ArgumentDeduction(SourceLocation loc, TypeEnv deduced,
  111. const Value* param, const Value* arg) -> TypeEnv {
  112. switch (param->Tag()) {
  113. case Value::Kind::VariableType: {
  114. const auto& var_type = cast<VariableType>(*param);
  115. std::optional<const Value*> d = deduced.Get(var_type.Name());
  116. if (!d) {
  117. deduced.Set(var_type.Name(), arg);
  118. } else {
  119. ExpectType(loc, "argument deduction", *d, arg);
  120. }
  121. return deduced;
  122. }
  123. case Value::Kind::TupleValue: {
  124. if (arg->Tag() != Value::Kind::TupleValue) {
  125. ExpectType(loc, "argument deduction", param, arg);
  126. }
  127. const auto& param_tup = cast<TupleValue>(*param);
  128. const auto& arg_tup = cast<TupleValue>(*arg);
  129. if (param_tup.Elements().size() != arg_tup.Elements().size()) {
  130. ExpectType(loc, "argument deduction", param, arg);
  131. }
  132. for (size_t i = 0; i < param_tup.Elements().size(); ++i) {
  133. if (param_tup.Elements()[i].name != arg_tup.Elements()[i].name) {
  134. FATAL_COMPILATION_ERROR(loc)
  135. << "mismatch in tuple names, " << param_tup.Elements()[i].name
  136. << " != " << arg_tup.Elements()[i].name;
  137. }
  138. deduced = ArgumentDeduction(loc, deduced, param_tup.Elements()[i].value,
  139. arg_tup.Elements()[i].value);
  140. }
  141. return deduced;
  142. }
  143. case Value::Kind::FunctionType: {
  144. if (arg->Tag() != Value::Kind::FunctionType) {
  145. ExpectType(loc, "argument deduction", param, arg);
  146. }
  147. const auto& param_fn = cast<FunctionType>(*param);
  148. const auto& arg_fn = cast<FunctionType>(*arg);
  149. // TODO: handle situation when arg has deduced parameters.
  150. deduced =
  151. ArgumentDeduction(loc, deduced, param_fn.Param(), arg_fn.Param());
  152. deduced = ArgumentDeduction(loc, deduced, param_fn.Ret(), arg_fn.Ret());
  153. return deduced;
  154. }
  155. case Value::Kind::PointerType: {
  156. if (arg->Tag() != Value::Kind::PointerType) {
  157. ExpectType(loc, "argument deduction", param, arg);
  158. }
  159. return ArgumentDeduction(loc, deduced, cast<PointerType>(*param).Type(),
  160. cast<PointerType>(*arg).Type());
  161. }
  162. // Nothing to do in the case for `auto`.
  163. case Value::Kind::AutoType: {
  164. return deduced;
  165. }
  166. // For the following cases, we check for type equality.
  167. case Value::Kind::ContinuationType:
  168. case Value::Kind::ClassType:
  169. case Value::Kind::ChoiceType:
  170. case Value::Kind::IntType:
  171. case Value::Kind::BoolType:
  172. case Value::Kind::TypeType:
  173. case Value::Kind::StringType:
  174. ExpectType(loc, "argument deduction", param, arg);
  175. return deduced;
  176. // The rest of these cases should never happen.
  177. case Value::Kind::IntValue:
  178. case Value::Kind::BoolValue:
  179. case Value::Kind::FunctionValue:
  180. case Value::Kind::PointerValue:
  181. case Value::Kind::StructValue:
  182. case Value::Kind::AlternativeValue:
  183. case Value::Kind::BindingPlaceholderValue:
  184. case Value::Kind::AlternativeConstructorValue:
  185. case Value::Kind::ContinuationValue:
  186. case Value::Kind::StringValue:
  187. FATAL() << "In ArgumentDeduction: expected type, not value " << *param;
  188. }
  189. }
  190. static auto Substitute(TypeEnv dict, const Value* type) -> const Value* {
  191. switch (type->Tag()) {
  192. case Value::Kind::VariableType: {
  193. std::optional<const Value*> t =
  194. dict.Get(cast<VariableType>(*type).Name());
  195. if (!t) {
  196. return type;
  197. } else {
  198. return *t;
  199. }
  200. }
  201. case Value::Kind::TupleValue: {
  202. std::vector<TupleElement> elts;
  203. for (const auto& elt : cast<TupleValue>(*type).Elements()) {
  204. auto t = Substitute(dict, elt.value);
  205. elts.push_back({.name = elt.name, .value = t});
  206. }
  207. return global_arena->RawNew<TupleValue>(elts);
  208. }
  209. case Value::Kind::FunctionType: {
  210. const auto& fn_type = cast<FunctionType>(*type);
  211. auto param = Substitute(dict, fn_type.Param());
  212. auto ret = Substitute(dict, fn_type.Ret());
  213. return global_arena->RawNew<FunctionType>(std::vector<GenericBinding>(),
  214. param, ret);
  215. }
  216. case Value::Kind::PointerType: {
  217. return global_arena->RawNew<PointerType>(
  218. Substitute(dict, cast<PointerType>(*type).Type()));
  219. }
  220. case Value::Kind::AutoType:
  221. case Value::Kind::IntType:
  222. case Value::Kind::BoolType:
  223. case Value::Kind::TypeType:
  224. case Value::Kind::ClassType:
  225. case Value::Kind::ChoiceType:
  226. case Value::Kind::ContinuationType:
  227. case Value::Kind::StringType:
  228. return type;
  229. // The rest of these cases should never happen.
  230. case Value::Kind::IntValue:
  231. case Value::Kind::BoolValue:
  232. case Value::Kind::FunctionValue:
  233. case Value::Kind::PointerValue:
  234. case Value::Kind::StructValue:
  235. case Value::Kind::AlternativeValue:
  236. case Value::Kind::BindingPlaceholderValue:
  237. case Value::Kind::AlternativeConstructorValue:
  238. case Value::Kind::ContinuationValue:
  239. case Value::Kind::StringValue:
  240. FATAL() << "In Substitute: expected type, not value " << *type;
  241. }
  242. }
  243. // The TypeCheckExp function performs semantic analysis on an expression.
  244. // It returns a new version of the expression, its type, and an
  245. // updated environment which are bundled into a TCResult object.
  246. // The purpose of the updated environment is
  247. // to bring pattern variables into scope, for example, in a match case.
  248. // The new version of the expression may include more information,
  249. // for example, the type arguments deduced for the type parameters of a
  250. // generic.
  251. //
  252. // e is the expression to be analyzed.
  253. // types maps variable names to the type of their run-time value.
  254. // values maps variable names to their compile-time values. It is not
  255. // directly used in this function but is passed to InterExp.
  256. auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
  257. -> TCExpression {
  258. if (tracing_output) {
  259. llvm::outs() << "checking expression " << *e << "\ntypes: ";
  260. PrintTypeEnv(types, llvm::outs());
  261. llvm::outs() << "\nvalues: ";
  262. PrintEnv(values, llvm::outs());
  263. llvm::outs() << "\n";
  264. }
  265. switch (e->Tag()) {
  266. case Expression::Kind::IndexExpression: {
  267. const auto& index = cast<IndexExpression>(*e);
  268. auto res = TypeCheckExp(index.Aggregate(), types, values);
  269. auto t = res.type;
  270. switch (t->Tag()) {
  271. case Value::Kind::TupleValue: {
  272. auto i = cast<IntValue>(*InterpExp(values, index.Offset())).Val();
  273. std::string f = std::to_string(i);
  274. const Value* field_t = cast<TupleValue>(*t).FindField(f);
  275. if (field_t == nullptr) {
  276. FATAL_COMPILATION_ERROR(e->SourceLoc())
  277. << "field " << f << " is not in the tuple " << *t;
  278. }
  279. auto new_e = global_arena->New<IndexExpression>(
  280. e->SourceLoc(), res.exp,
  281. global_arena->New<IntLiteral>(e->SourceLoc(), i));
  282. return TCExpression(new_e, field_t, res.types);
  283. }
  284. default:
  285. FATAL_COMPILATION_ERROR(e->SourceLoc()) << "expected a tuple";
  286. }
  287. }
  288. case Expression::Kind::TupleLiteral: {
  289. std::vector<FieldInitializer> new_args;
  290. std::vector<TupleElement> arg_types;
  291. auto new_types = types;
  292. for (const auto& arg : cast<TupleLiteral>(*e).Fields()) {
  293. auto arg_res = TypeCheckExp(arg.expression, new_types, values);
  294. new_types = arg_res.types;
  295. new_args.push_back(FieldInitializer(arg.name, arg_res.exp));
  296. arg_types.push_back({.name = arg.name, .value = arg_res.type});
  297. }
  298. auto tuple_e = global_arena->New<TupleLiteral>(e->SourceLoc(), new_args);
  299. auto tuple_t = global_arena->RawNew<TupleValue>(std::move(arg_types));
  300. return TCExpression(tuple_e, tuple_t, new_types);
  301. }
  302. case Expression::Kind::FieldAccessExpression: {
  303. const auto& access = cast<FieldAccessExpression>(*e);
  304. auto res = TypeCheckExp(access.Aggregate(), types, values);
  305. auto t = res.type;
  306. switch (t->Tag()) {
  307. case Value::Kind::ClassType: {
  308. const auto& t_class = cast<ClassType>(*t);
  309. // Search for a field
  310. for (auto& field : t_class.Fields()) {
  311. if (access.Field() == field.first) {
  312. Ptr<const Expression> new_e =
  313. global_arena->New<FieldAccessExpression>(
  314. e->SourceLoc(), res.exp, access.Field());
  315. return TCExpression(new_e, field.second, res.types);
  316. }
  317. }
  318. // Search for a method
  319. for (auto& method : t_class.Methods()) {
  320. if (access.Field() == method.first) {
  321. Ptr<const Expression> new_e =
  322. global_arena->New<FieldAccessExpression>(
  323. e->SourceLoc(), res.exp, access.Field());
  324. return TCExpression(new_e, method.second, res.types);
  325. }
  326. }
  327. FATAL_COMPILATION_ERROR(e->SourceLoc())
  328. << "class " << t_class.Name() << " does not have a field named "
  329. << access.Field();
  330. }
  331. case Value::Kind::TupleValue: {
  332. const auto& tup = cast<TupleValue>(*t);
  333. for (const TupleElement& field : tup.Elements()) {
  334. if (access.Field() == field.name) {
  335. auto new_e = global_arena->New<FieldAccessExpression>(
  336. e->SourceLoc(), res.exp, access.Field());
  337. return TCExpression(new_e, field.value, res.types);
  338. }
  339. }
  340. FATAL_COMPILATION_ERROR(e->SourceLoc())
  341. << "tuple " << tup << " does not have a field named "
  342. << access.Field();
  343. }
  344. case Value::Kind::ChoiceType: {
  345. const auto& choice = cast<ChoiceType>(*t);
  346. for (const auto& vt : choice.Alternatives()) {
  347. if (access.Field() == vt.first) {
  348. Ptr<const Expression> new_e =
  349. global_arena->New<FieldAccessExpression>(
  350. e->SourceLoc(), res.exp, access.Field());
  351. auto fun_ty = global_arena->RawNew<FunctionType>(
  352. std::vector<GenericBinding>(), vt.second, t);
  353. return TCExpression(new_e, fun_ty, res.types);
  354. }
  355. }
  356. FATAL_COMPILATION_ERROR(e->SourceLoc())
  357. << "choice " << choice.Name() << " does not have a field named "
  358. << access.Field();
  359. }
  360. default:
  361. FATAL_COMPILATION_ERROR(e->SourceLoc())
  362. << "field access, expected a struct\n"
  363. << *e;
  364. }
  365. }
  366. case Expression::Kind::IdentifierExpression: {
  367. const auto& ident = cast<IdentifierExpression>(*e);
  368. std::optional<const Value*> type = types.Get(ident.Name());
  369. if (type) {
  370. return TCExpression(e, *type, types);
  371. } else {
  372. FATAL_COMPILATION_ERROR(e->SourceLoc())
  373. << "could not find `" << ident.Name() << "`";
  374. }
  375. }
  376. case Expression::Kind::IntLiteral:
  377. return TCExpression(e, global_arena->RawNew<IntType>(), types);
  378. case Expression::Kind::BoolLiteral:
  379. return TCExpression(e, global_arena->RawNew<BoolType>(), types);
  380. case Expression::Kind::PrimitiveOperatorExpression: {
  381. const auto& op = cast<PrimitiveOperatorExpression>(*e);
  382. std::vector<Ptr<const Expression>> es;
  383. std::vector<const Value*> ts;
  384. auto new_types = types;
  385. for (Ptr<const Expression> argument : op.Arguments()) {
  386. auto res = TypeCheckExp(argument, types, values);
  387. new_types = res.types;
  388. es.push_back(res.exp);
  389. ts.push_back(res.type);
  390. }
  391. auto new_e = global_arena->New<PrimitiveOperatorExpression>(
  392. e->SourceLoc(), op.Op(), es);
  393. switch (op.Op()) {
  394. case Operator::Neg:
  395. ExpectType(e->SourceLoc(), "negation",
  396. global_arena->RawNew<IntType>(), ts[0]);
  397. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  398. new_types);
  399. case Operator::Add:
  400. ExpectType(e->SourceLoc(), "addition(1)",
  401. global_arena->RawNew<IntType>(), ts[0]);
  402. ExpectType(e->SourceLoc(), "addition(2)",
  403. global_arena->RawNew<IntType>(), ts[1]);
  404. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  405. new_types);
  406. case Operator::Sub:
  407. ExpectType(e->SourceLoc(), "subtraction(1)",
  408. global_arena->RawNew<IntType>(), ts[0]);
  409. ExpectType(e->SourceLoc(), "subtraction(2)",
  410. global_arena->RawNew<IntType>(), ts[1]);
  411. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  412. new_types);
  413. case Operator::Mul:
  414. ExpectType(e->SourceLoc(), "multiplication(1)",
  415. global_arena->RawNew<IntType>(), ts[0]);
  416. ExpectType(e->SourceLoc(), "multiplication(2)",
  417. global_arena->RawNew<IntType>(), ts[1]);
  418. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  419. new_types);
  420. case Operator::And:
  421. ExpectType(e->SourceLoc(), "&&(1)", global_arena->RawNew<BoolType>(),
  422. ts[0]);
  423. ExpectType(e->SourceLoc(), "&&(2)", global_arena->RawNew<BoolType>(),
  424. ts[1]);
  425. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  426. new_types);
  427. case Operator::Or:
  428. ExpectType(e->SourceLoc(), "||(1)", global_arena->RawNew<BoolType>(),
  429. ts[0]);
  430. ExpectType(e->SourceLoc(), "||(2)", global_arena->RawNew<BoolType>(),
  431. ts[1]);
  432. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  433. new_types);
  434. case Operator::Not:
  435. ExpectType(e->SourceLoc(), "!", global_arena->RawNew<BoolType>(),
  436. ts[0]);
  437. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  438. new_types);
  439. case Operator::Eq:
  440. ExpectType(e->SourceLoc(), "==", ts[0], ts[1]);
  441. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  442. new_types);
  443. case Operator::Deref:
  444. ExpectPointerType(e->SourceLoc(), "*", ts[0]);
  445. return TCExpression(new_e, cast<PointerType>(*ts[0]).Type(),
  446. new_types);
  447. case Operator::Ptr:
  448. ExpectType(e->SourceLoc(), "*", global_arena->RawNew<TypeType>(),
  449. ts[0]);
  450. return TCExpression(new_e, global_arena->RawNew<TypeType>(),
  451. new_types);
  452. }
  453. break;
  454. }
  455. case Expression::Kind::CallExpression: {
  456. const auto& call = cast<CallExpression>(*e);
  457. auto fun_res = TypeCheckExp(call.Function(), types, values);
  458. switch (fun_res.type->Tag()) {
  459. case Value::Kind::FunctionType: {
  460. const auto& fun_t = cast<FunctionType>(*fun_res.type);
  461. auto arg_res = TypeCheckExp(call.Argument(), fun_res.types, values);
  462. auto parameter_type = fun_t.Param();
  463. auto return_type = fun_t.Ret();
  464. if (!fun_t.Deduced().empty()) {
  465. auto deduced_args = ArgumentDeduction(e->SourceLoc(), TypeEnv(),
  466. parameter_type, arg_res.type);
  467. for (auto& deduced_param : fun_t.Deduced()) {
  468. // TODO: change the following to a CHECK once the real checking
  469. // has been added to the type checking of function signatures.
  470. if (!deduced_args.Get(deduced_param.name)) {
  471. FATAL_COMPILATION_ERROR(e->SourceLoc())
  472. << "could not deduce type argument for type parameter "
  473. << deduced_param.name;
  474. }
  475. }
  476. parameter_type = Substitute(deduced_args, parameter_type);
  477. return_type = Substitute(deduced_args, return_type);
  478. } else {
  479. ExpectType(e->SourceLoc(), "call", parameter_type, arg_res.type);
  480. }
  481. auto new_e = global_arena->New<CallExpression>(
  482. e->SourceLoc(), fun_res.exp, arg_res.exp);
  483. return TCExpression(new_e, return_type, arg_res.types);
  484. }
  485. default: {
  486. FATAL_COMPILATION_ERROR(e->SourceLoc())
  487. << "in call, expected a function\n"
  488. << *e;
  489. }
  490. }
  491. break;
  492. }
  493. case Expression::Kind::FunctionTypeLiteral: {
  494. const auto& fn = cast<FunctionTypeLiteral>(*e);
  495. auto pt = InterpExp(values, fn.Parameter());
  496. auto rt = InterpExp(values, fn.ReturnType());
  497. auto new_e = global_arena->New<FunctionTypeLiteral>(
  498. e->SourceLoc(), ReifyType(pt, e->SourceLoc()),
  499. ReifyType(rt, e->SourceLoc()),
  500. /*is_omitted_return_type=*/false);
  501. return TCExpression(new_e, global_arena->RawNew<TypeType>(), types);
  502. }
  503. case Expression::Kind::StringLiteral:
  504. return TCExpression(e, global_arena->RawNew<StringType>(), types);
  505. case Expression::Kind::IntrinsicExpression:
  506. switch (cast<IntrinsicExpression>(*e).Intrinsic()) {
  507. case IntrinsicExpression::IntrinsicKind::Print:
  508. return TCExpression(e, &TupleValue::Empty(), types);
  509. }
  510. case Expression::Kind::IntTypeLiteral:
  511. case Expression::Kind::BoolTypeLiteral:
  512. case Expression::Kind::StringTypeLiteral:
  513. case Expression::Kind::TypeTypeLiteral:
  514. case Expression::Kind::ContinuationTypeLiteral:
  515. return TCExpression(e, global_arena->RawNew<TypeType>(), types);
  516. }
  517. }
  518. // Equivalent to TypeCheckExp, but operates on Patterns instead of Expressions.
  519. // `expected` is the type that this pattern is expected to have, if the
  520. // surrounding context gives us that information. Otherwise, it is null.
  521. auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
  522. const Value* expected) -> TCPattern {
  523. if (tracing_output) {
  524. llvm::outs() << "checking pattern " << *p;
  525. if (expected) {
  526. llvm::outs() << ", expecting " << *expected;
  527. }
  528. llvm::outs() << "\ntypes: ";
  529. PrintTypeEnv(types, llvm::outs());
  530. llvm::outs() << "\nvalues: ";
  531. PrintEnv(values, llvm::outs());
  532. llvm::outs() << "\n";
  533. }
  534. switch (p->Tag()) {
  535. case Pattern::Kind::AutoPattern: {
  536. return {.pattern = p,
  537. .type = global_arena->RawNew<TypeType>(),
  538. .types = types};
  539. }
  540. case Pattern::Kind::BindingPattern: {
  541. const auto& binding = cast<BindingPattern>(*p);
  542. TCPattern binding_type_result =
  543. TypeCheckPattern(binding.Type(), types, values, nullptr);
  544. const Value* type = InterpPattern(values, binding_type_result.pattern);
  545. if (expected != nullptr) {
  546. std::optional<Env> values =
  547. PatternMatch(type, expected, binding.Type()->SourceLoc());
  548. if (values == std::nullopt) {
  549. FATAL_COMPILATION_ERROR(binding.Type()->SourceLoc())
  550. << "Type pattern '" << *type << "' does not match actual type '"
  551. << *expected << "'";
  552. }
  553. CHECK(values->begin() == values->end())
  554. << "Name bindings within type patterns are unsupported";
  555. type = expected;
  556. }
  557. auto new_p = global_arena->New<BindingPattern>(
  558. binding.SourceLoc(), binding.Name(),
  559. global_arena->New<ExpressionPattern>(
  560. ReifyType(type, binding.SourceLoc())));
  561. if (binding.Name().has_value()) {
  562. types.Set(*binding.Name(), type);
  563. }
  564. return {.pattern = new_p, .type = type, .types = types};
  565. }
  566. case Pattern::Kind::TuplePattern: {
  567. const auto& tuple = cast<TuplePattern>(*p);
  568. std::vector<TuplePattern::Field> new_fields;
  569. std::vector<TupleElement> field_types;
  570. auto new_types = types;
  571. if (expected && expected->Tag() != Value::Kind::TupleValue) {
  572. FATAL_COMPILATION_ERROR(p->SourceLoc()) << "didn't expect a tuple";
  573. }
  574. if (expected && tuple.Fields().size() !=
  575. cast<TupleValue>(*expected).Elements().size()) {
  576. FATAL_COMPILATION_ERROR(tuple.SourceLoc())
  577. << "tuples of different length";
  578. }
  579. for (size_t i = 0; i < tuple.Fields().size(); ++i) {
  580. const TuplePattern::Field& field = tuple.Fields()[i];
  581. const Value* expected_field_type = nullptr;
  582. if (expected != nullptr) {
  583. const TupleElement& expected_element =
  584. cast<TupleValue>(*expected).Elements()[i];
  585. if (expected_element.name != field.name) {
  586. FATAL_COMPILATION_ERROR(tuple.SourceLoc())
  587. << "field names do not match, expected "
  588. << expected_element.name << " but got " << field.name;
  589. }
  590. expected_field_type = expected_element.value;
  591. }
  592. auto field_result = TypeCheckPattern(field.pattern, new_types, values,
  593. expected_field_type);
  594. new_types = field_result.types;
  595. new_fields.push_back(
  596. TuplePattern::Field(field.name, field_result.pattern));
  597. field_types.push_back({.name = field.name, .value = field_result.type});
  598. }
  599. auto new_tuple =
  600. global_arena->New<TuplePattern>(tuple.SourceLoc(), new_fields);
  601. auto tuple_t = global_arena->RawNew<TupleValue>(std::move(field_types));
  602. return {.pattern = new_tuple, .type = tuple_t, .types = new_types};
  603. }
  604. case Pattern::Kind::AlternativePattern: {
  605. const auto& alternative = cast<AlternativePattern>(*p);
  606. const Value* choice_type = InterpExp(values, alternative.ChoiceType());
  607. if (choice_type->Tag() != Value::Kind::ChoiceType) {
  608. FATAL_COMPILATION_ERROR(alternative.SourceLoc())
  609. << "alternative pattern does not name a choice type.";
  610. }
  611. if (expected != nullptr) {
  612. ExpectType(alternative.SourceLoc(), "alternative pattern", expected,
  613. choice_type);
  614. }
  615. const Value* parameter_types =
  616. FindInVarValues(alternative.AlternativeName(),
  617. cast<ChoiceType>(*choice_type).Alternatives());
  618. if (parameter_types == nullptr) {
  619. FATAL_COMPILATION_ERROR(alternative.SourceLoc())
  620. << "'" << alternative.AlternativeName()
  621. << "' is not an alternative of " << choice_type;
  622. }
  623. TCPattern arg_results = TypeCheckPattern(alternative.Arguments(), types,
  624. values, parameter_types);
  625. // TODO: Think about a cleaner way to cast between Ptr types.
  626. auto arguments = Ptr<const TuplePattern>(
  627. cast<const TuplePattern>(arg_results.pattern.Get()));
  628. return {.pattern = global_arena->New<AlternativePattern>(
  629. alternative.SourceLoc(),
  630. ReifyType(choice_type, alternative.SourceLoc()),
  631. alternative.AlternativeName(), arguments),
  632. .type = choice_type,
  633. .types = arg_results.types};
  634. }
  635. case Pattern::Kind::ExpressionPattern: {
  636. TCExpression result =
  637. TypeCheckExp(cast<ExpressionPattern>(*p).Expression(), types, values);
  638. return {.pattern = global_arena->New<ExpressionPattern>(result.exp),
  639. .type = result.type,
  640. .types = result.types};
  641. }
  642. }
  643. }
  644. static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
  645. const Statement* body, TypeEnv types, Env values,
  646. const Value*& ret_type, bool is_omitted_ret_type)
  647. -> std::pair<Ptr<const Pattern>, const Statement*> {
  648. auto pat_res = TypeCheckPattern(pat, types, values, expected);
  649. auto res =
  650. TypeCheckStmt(body, pat_res.types, values, ret_type, is_omitted_ret_type);
  651. return std::make_pair(pat, res.stmt);
  652. }
  653. // The TypeCheckStmt function performs semantic analysis on a statement.
  654. // It returns a new version of the statement and a new type environment.
  655. //
  656. // The ret_type parameter is used for analyzing return statements.
  657. // It is the declared return type of the enclosing function definition.
  658. // If the return type is "auto", then the return type is inferred from
  659. // the first return statement.
  660. auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
  661. const Value*& ret_type, bool is_omitted_ret_type)
  662. -> TCStatement {
  663. if (!s) {
  664. return TCStatement(s, types);
  665. }
  666. switch (s->Tag()) {
  667. case Statement::Kind::Match: {
  668. const auto& match = cast<Match>(*s);
  669. auto res = TypeCheckExp(match.Exp(), types, values);
  670. auto res_type = res.type;
  671. auto new_clauses = global_arena->RawNew<
  672. std::list<std::pair<Ptr<const Pattern>, const Statement*>>>();
  673. for (auto& clause : *match.Clauses()) {
  674. new_clauses->push_back(TypecheckCase(res_type, clause.first,
  675. clause.second, types, values,
  676. ret_type, is_omitted_ret_type));
  677. }
  678. const Statement* new_s =
  679. global_arena->RawNew<Match>(s->SourceLoc(), res.exp, new_clauses);
  680. return TCStatement(new_s, types);
  681. }
  682. case Statement::Kind::While: {
  683. const auto& while_stmt = cast<While>(*s);
  684. auto cnd_res = TypeCheckExp(while_stmt.Cond(), types, values);
  685. ExpectType(s->SourceLoc(), "condition of `while`",
  686. global_arena->RawNew<BoolType>(), cnd_res.type);
  687. auto body_res = TypeCheckStmt(while_stmt.Body(), types, values, ret_type,
  688. is_omitted_ret_type);
  689. auto new_s = global_arena->RawNew<While>(s->SourceLoc(), cnd_res.exp,
  690. body_res.stmt);
  691. return TCStatement(new_s, types);
  692. }
  693. case Statement::Kind::Break:
  694. case Statement::Kind::Continue:
  695. return TCStatement(s, types);
  696. case Statement::Kind::Block: {
  697. auto stmt_res = TypeCheckStmt(cast<Block>(*s).Stmt(), types, values,
  698. ret_type, is_omitted_ret_type);
  699. return TCStatement(
  700. global_arena->RawNew<Block>(s->SourceLoc(), stmt_res.stmt), types);
  701. }
  702. case Statement::Kind::VariableDefinition: {
  703. const auto& var = cast<VariableDefinition>(*s);
  704. auto res = TypeCheckExp(var.Init(), types, values);
  705. const Value* rhs_ty = res.type;
  706. auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
  707. const Statement* new_s = global_arena->RawNew<VariableDefinition>(
  708. s->SourceLoc(), var.Pat(), res.exp);
  709. return TCStatement(new_s, lhs_res.types);
  710. }
  711. case Statement::Kind::Sequence: {
  712. const auto& seq = cast<Sequence>(*s);
  713. auto stmt_res = TypeCheckStmt(seq.Stmt(), types, values, ret_type,
  714. is_omitted_ret_type);
  715. auto types2 = stmt_res.types;
  716. auto next_res = TypeCheckStmt(seq.Next(), types2, values, ret_type,
  717. is_omitted_ret_type);
  718. auto types3 = next_res.types;
  719. return TCStatement(global_arena->RawNew<Sequence>(
  720. s->SourceLoc(), stmt_res.stmt, next_res.stmt),
  721. types3);
  722. }
  723. case Statement::Kind::Assign: {
  724. const auto& assign = cast<Assign>(*s);
  725. auto rhs_res = TypeCheckExp(assign.Rhs(), types, values);
  726. auto rhs_t = rhs_res.type;
  727. auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
  728. auto lhs_t = lhs_res.type;
  729. ExpectType(s->SourceLoc(), "assign", lhs_t, rhs_t);
  730. auto new_s = global_arena->RawNew<Assign>(s->SourceLoc(), lhs_res.exp,
  731. rhs_res.exp);
  732. return TCStatement(new_s, lhs_res.types);
  733. }
  734. case Statement::Kind::ExpressionStatement: {
  735. auto res =
  736. TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
  737. auto new_s =
  738. global_arena->RawNew<ExpressionStatement>(s->SourceLoc(), res.exp);
  739. return TCStatement(new_s, types);
  740. }
  741. case Statement::Kind::If: {
  742. const auto& if_stmt = cast<If>(*s);
  743. auto cnd_res = TypeCheckExp(if_stmt.Cond(), types, values);
  744. ExpectType(s->SourceLoc(), "condition of `if`",
  745. global_arena->RawNew<BoolType>(), cnd_res.type);
  746. auto then_res = TypeCheckStmt(if_stmt.ThenStmt(), types, values, ret_type,
  747. is_omitted_ret_type);
  748. auto else_res = TypeCheckStmt(if_stmt.ElseStmt(), types, values, ret_type,
  749. is_omitted_ret_type);
  750. auto new_s = global_arena->RawNew<If>(s->SourceLoc(), cnd_res.exp,
  751. then_res.stmt, else_res.stmt);
  752. return TCStatement(new_s, types);
  753. }
  754. case Statement::Kind::Return: {
  755. const auto& ret = cast<Return>(*s);
  756. auto res = TypeCheckExp(ret.Exp(), types, values);
  757. if (ret_type->Tag() == Value::Kind::AutoType) {
  758. // The following infers the return type from the first 'return'
  759. // statement. This will get more difficult with subtyping, when we
  760. // should infer the least-upper bound of all the 'return' statements.
  761. ret_type = res.type;
  762. } else {
  763. ExpectType(s->SourceLoc(), "return", ret_type, res.type);
  764. }
  765. if (ret.IsOmittedExp() != is_omitted_ret_type) {
  766. FATAL_COMPILATION_ERROR(s->SourceLoc())
  767. << *s << " should" << (is_omitted_ret_type ? " not" : "")
  768. << " provide a return value, to match the function's signature.";
  769. }
  770. return TCStatement(global_arena->RawNew<Return>(s->SourceLoc(), res.exp,
  771. ret.IsOmittedExp()),
  772. types);
  773. }
  774. case Statement::Kind::Continuation: {
  775. const auto& cont = cast<Continuation>(*s);
  776. TCStatement body_result = TypeCheckStmt(cont.Body(), types, values,
  777. ret_type, is_omitted_ret_type);
  778. const Statement* new_continuation = global_arena->RawNew<Continuation>(
  779. s->SourceLoc(), cont.ContinuationVariable(), body_result.stmt);
  780. types.Set(cont.ContinuationVariable(),
  781. global_arena->RawNew<ContinuationType>());
  782. return TCStatement(new_continuation, types);
  783. }
  784. case Statement::Kind::Run: {
  785. TCExpression argument_result =
  786. TypeCheckExp(cast<Run>(*s).Argument(), types, values);
  787. ExpectType(s->SourceLoc(), "argument of `run`",
  788. global_arena->RawNew<ContinuationType>(),
  789. argument_result.type);
  790. const Statement* new_run =
  791. global_arena->RawNew<Run>(s->SourceLoc(), argument_result.exp);
  792. return TCStatement(new_run, types);
  793. }
  794. case Statement::Kind::Await: {
  795. // nothing to do here
  796. return TCStatement(s, types);
  797. }
  798. } // switch
  799. }
  800. static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
  801. SourceLocation loc) -> const Statement* {
  802. if (!stmt) {
  803. if (omitted_ret_type) {
  804. return global_arena->RawNew<Return>(loc);
  805. } else {
  806. FATAL_COMPILATION_ERROR(loc)
  807. << "control-flow reaches end of function that provides a `->` return "
  808. "type without reaching a return statement";
  809. }
  810. }
  811. switch (stmt->Tag()) {
  812. case Statement::Kind::Match: {
  813. const auto& match = cast<Match>(*stmt);
  814. auto new_clauses = global_arena->RawNew<
  815. std::list<std::pair<Ptr<const Pattern>, const Statement*>>>();
  816. for (const auto& clause : *match.Clauses()) {
  817. auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
  818. stmt->SourceLoc());
  819. new_clauses->push_back(std::make_pair(clause.first, s));
  820. }
  821. return global_arena->RawNew<Match>(stmt->SourceLoc(), match.Exp(),
  822. new_clauses);
  823. }
  824. case Statement::Kind::Block:
  825. return global_arena->RawNew<Block>(
  826. stmt->SourceLoc(),
  827. CheckOrEnsureReturn(cast<Block>(*stmt).Stmt(), omitted_ret_type,
  828. stmt->SourceLoc()));
  829. case Statement::Kind::If: {
  830. const auto& if_stmt = cast<If>(*stmt);
  831. return global_arena->RawNew<If>(
  832. stmt->SourceLoc(), if_stmt.Cond(),
  833. CheckOrEnsureReturn(if_stmt.ThenStmt(), omitted_ret_type,
  834. stmt->SourceLoc()),
  835. CheckOrEnsureReturn(if_stmt.ElseStmt(), omitted_ret_type,
  836. stmt->SourceLoc()));
  837. }
  838. case Statement::Kind::Return:
  839. return stmt;
  840. case Statement::Kind::Sequence: {
  841. const auto& seq = cast<Sequence>(*stmt);
  842. if (seq.Next()) {
  843. return global_arena->RawNew<Sequence>(
  844. stmt->SourceLoc(), seq.Stmt(),
  845. CheckOrEnsureReturn(seq.Next(), omitted_ret_type,
  846. stmt->SourceLoc()));
  847. } else {
  848. return CheckOrEnsureReturn(seq.Stmt(), omitted_ret_type,
  849. stmt->SourceLoc());
  850. }
  851. }
  852. case Statement::Kind::Continuation:
  853. case Statement::Kind::Run:
  854. case Statement::Kind::Await:
  855. return stmt;
  856. case Statement::Kind::Assign:
  857. case Statement::Kind::ExpressionStatement:
  858. case Statement::Kind::While:
  859. case Statement::Kind::Break:
  860. case Statement::Kind::Continue:
  861. case Statement::Kind::VariableDefinition:
  862. if (omitted_ret_type) {
  863. return global_arena->RawNew<Sequence>(
  864. stmt->SourceLoc(), stmt, global_arena->RawNew<Return>(loc));
  865. } else {
  866. FATAL_COMPILATION_ERROR(stmt->SourceLoc())
  867. << "control-flow reaches end of function that provides a `->` "
  868. "return type without reaching a return statement";
  869. }
  870. }
  871. }
  872. // TODO: factor common parts of TypeCheckFunDef and TypeOfFunDef into
  873. // a function.
  874. // TODO: Add checking to function definitions to ensure that
  875. // all deduced type parameters will be deduced.
  876. static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
  877. Env values) -> Ptr<const FunctionDefinition> {
  878. // Bring the deduced parameters into scope
  879. for (const auto& deduced : f->deduced_parameters) {
  880. // auto t = InterpExp(values, deduced.type);
  881. types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
  882. Address a = state->heap.AllocateValue(*types.Get(deduced.name));
  883. values.Set(deduced.name, a);
  884. }
  885. // Type check the parameter pattern
  886. auto param_res = TypeCheckPattern(f->param_pattern, types, values, nullptr);
  887. // Evaluate the return type expression
  888. auto return_type = InterpPattern(values, f->return_type);
  889. if (f->name == "main") {
  890. ExpectType(f->source_location, "return type of `main`",
  891. global_arena->RawNew<IntType>(), return_type);
  892. // TODO: Check that main doesn't have any parameters.
  893. }
  894. auto res = TypeCheckStmt(f->body, param_res.types, values, return_type,
  895. f->is_omitted_return_type);
  896. auto body = CheckOrEnsureReturn(res.stmt, f->is_omitted_return_type,
  897. f->source_location);
  898. return global_arena->New<FunctionDefinition>(
  899. f->source_location, f->name, f->deduced_parameters, f->param_pattern,
  900. global_arena->New<ExpressionPattern>(
  901. ReifyType(return_type, f->source_location)),
  902. /*is_omitted_return_type=*/false, body);
  903. }
  904. static auto TypeOfFunDef(TypeEnv types, Env values,
  905. const FunctionDefinition* fun_def) -> const Value* {
  906. // Bring the deduced parameters into scope
  907. for (const auto& deduced : fun_def->deduced_parameters) {
  908. // auto t = InterpExp(values, deduced.type);
  909. types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
  910. Address a = state->heap.AllocateValue(*types.Get(deduced.name));
  911. values.Set(deduced.name, a);
  912. }
  913. // Type check the parameter pattern
  914. auto param_res =
  915. TypeCheckPattern(fun_def->param_pattern, types, values, nullptr);
  916. // Evaluate the return type expression
  917. auto ret = InterpPattern(values, fun_def->return_type);
  918. if (ret->Tag() == Value::Kind::AutoType) {
  919. auto f = TypeCheckFunDef(fun_def, types, values);
  920. ret = InterpPattern(values, f->return_type);
  921. }
  922. return global_arena->RawNew<FunctionType>(fun_def->deduced_parameters,
  923. param_res.type, ret);
  924. }
  925. static auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
  926. Env ct_top) -> const Value* {
  927. VarValues fields;
  928. VarValues methods;
  929. for (Ptr<const Member> m : sd->members) {
  930. switch (m->Tag()) {
  931. case Member::Kind::FieldMember: {
  932. Ptr<const BindingPattern> binding = cast<FieldMember>(*m).Binding();
  933. if (!binding->Name().has_value()) {
  934. FATAL_COMPILATION_ERROR(binding->SourceLoc())
  935. << "Struct members must have names";
  936. }
  937. const auto* binding_type =
  938. dyn_cast<ExpressionPattern>(binding->Type().Get());
  939. if (binding_type == nullptr) {
  940. FATAL_COMPILATION_ERROR(binding->SourceLoc())
  941. << "Struct members must have explicit types";
  942. }
  943. auto type = InterpExp(ct_top, binding_type->Expression());
  944. fields.push_back(std::make_pair(*binding->Name(), type));
  945. break;
  946. }
  947. }
  948. }
  949. return global_arena->RawNew<ClassType>(sd->name, std::move(fields),
  950. std::move(methods));
  951. }
  952. static auto GetName(const Declaration& d) -> const std::string& {
  953. switch (d.Tag()) {
  954. case Declaration::Kind::FunctionDeclaration:
  955. return cast<FunctionDeclaration>(d).Definition().name;
  956. case Declaration::Kind::ClassDeclaration:
  957. return cast<ClassDeclaration>(d).Definition().name;
  958. case Declaration::Kind::ChoiceDeclaration:
  959. return cast<ChoiceDeclaration>(d).Name();
  960. case Declaration::Kind::VariableDeclaration: {
  961. Ptr<const BindingPattern> binding =
  962. cast<VariableDeclaration>(d).Binding();
  963. if (!binding->Name().has_value()) {
  964. FATAL_COMPILATION_ERROR(binding->SourceLoc())
  965. << "Top-level variable declarations must have names";
  966. }
  967. return *binding->Name();
  968. }
  969. }
  970. }
  971. auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
  972. const Env& values) -> Ptr<const Declaration> {
  973. switch (d->Tag()) {
  974. case Declaration::Kind::FunctionDeclaration:
  975. return global_arena->New<FunctionDeclaration>(TypeCheckFunDef(
  976. &cast<FunctionDeclaration>(*d).Definition(), types, values));
  977. case Declaration::Kind::ClassDeclaration: {
  978. const ClassDefinition& class_def =
  979. cast<ClassDeclaration>(*d).Definition();
  980. std::list<Ptr<Member>> fields;
  981. for (Ptr<Member> m : class_def.members) {
  982. switch (m->Tag()) {
  983. case Member::Kind::FieldMember:
  984. // TODO: Interpret the type expression and store the result.
  985. fields.push_back(m);
  986. break;
  987. }
  988. }
  989. return global_arena->New<ClassDeclaration>(class_def.loc, class_def.name,
  990. std::move(fields));
  991. }
  992. case Declaration::Kind::ChoiceDeclaration:
  993. // TODO
  994. return d;
  995. case Declaration::Kind::VariableDeclaration: {
  996. const auto& var = cast<VariableDeclaration>(*d);
  997. // Signals a type error if the initializing expression does not have
  998. // the declared type of the variable, otherwise returns this
  999. // declaration with annotated types.
  1000. TCExpression type_checked_initializer =
  1001. TypeCheckExp(var.Initializer(), types, values);
  1002. const auto* binding_type =
  1003. dyn_cast<ExpressionPattern>(var.Binding()->Type().Get());
  1004. if (binding_type == nullptr) {
  1005. // TODO: consider adding support for `auto`
  1006. FATAL_COMPILATION_ERROR(var.SourceLoc())
  1007. << "Type of a top-level variable must be an expression.";
  1008. }
  1009. const Value* declared_type =
  1010. InterpExp(values, binding_type->Expression());
  1011. ExpectType(var.SourceLoc(), "initializer of variable", declared_type,
  1012. type_checked_initializer.type);
  1013. return d;
  1014. }
  1015. }
  1016. }
  1017. static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
  1018. switch (d.Tag()) {
  1019. case Declaration::Kind::FunctionDeclaration: {
  1020. const FunctionDefinition& func_def =
  1021. cast<FunctionDeclaration>(d).Definition();
  1022. auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
  1023. tops->types.Set(func_def.name, t);
  1024. InitEnv(d, &tops->values);
  1025. break;
  1026. }
  1027. case Declaration::Kind::ClassDeclaration: {
  1028. const ClassDefinition& class_def = cast<ClassDeclaration>(d).Definition();
  1029. auto st = TypeOfClassDef(&class_def, tops->types, tops->values);
  1030. Address a = state->heap.AllocateValue(st);
  1031. tops->values.Set(class_def.name, a); // Is this obsolete?
  1032. std::vector<TupleElement> field_types;
  1033. for (const auto& [field_name, field_value] :
  1034. cast<ClassType>(*st).Fields()) {
  1035. field_types.push_back({.name = field_name, .value = field_value});
  1036. }
  1037. auto fun_ty = global_arena->RawNew<FunctionType>(
  1038. std::vector<GenericBinding>(),
  1039. global_arena->RawNew<TupleValue>(std::move(field_types)), st);
  1040. tops->types.Set(class_def.name, fun_ty);
  1041. break;
  1042. }
  1043. case Declaration::Kind::ChoiceDeclaration: {
  1044. const auto& choice = cast<ChoiceDeclaration>(d);
  1045. VarValues alts;
  1046. for (const auto& [name, signature] : choice.Alternatives()) {
  1047. auto t = InterpExp(tops->values, signature);
  1048. alts.push_back(std::make_pair(name, t));
  1049. }
  1050. auto ct =
  1051. global_arena->RawNew<ChoiceType>(choice.Name(), std::move(alts));
  1052. Address a = state->heap.AllocateValue(ct);
  1053. tops->values.Set(choice.Name(), a); // Is this obsolete?
  1054. tops->types.Set(choice.Name(), ct);
  1055. break;
  1056. }
  1057. case Declaration::Kind::VariableDeclaration: {
  1058. const auto& var = cast<VariableDeclaration>(d);
  1059. // Associate the variable name with it's declared type in the
  1060. // compile-time symbol table.
  1061. Ptr<const Expression> type =
  1062. cast<ExpressionPattern>(*var.Binding()->Type()).Expression();
  1063. const Value* declared_type = InterpExp(tops->values, type);
  1064. tops->types.Set(*var.Binding()->Name(), declared_type);
  1065. break;
  1066. }
  1067. }
  1068. }
  1069. auto TopLevel(const std::list<Ptr<const Declaration>>& fs) -> TypeCheckContext {
  1070. TypeCheckContext tops;
  1071. bool found_main = false;
  1072. for (auto const& d : fs) {
  1073. if (GetName(*d) == "main") {
  1074. found_main = true;
  1075. }
  1076. TopLevel(*d, &tops);
  1077. }
  1078. if (found_main == false) {
  1079. FATAL_COMPILATION_ERROR_NO_LINE()
  1080. << "program must contain a function named `main`";
  1081. }
  1082. return tops;
  1083. }
  1084. } // namespace Carbon