typecheck.cpp 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "executable_semantics/interpreter/typecheck.h"
  5. #include <algorithm>
  6. #include <iterator>
  7. #include <map>
  8. #include <set>
  9. #include <vector>
  10. #include "common/ostream.h"
  11. #include "executable_semantics/ast/function_definition.h"
  12. #include "executable_semantics/common/arena.h"
  13. #include "executable_semantics/common/error.h"
  14. #include "executable_semantics/common/tracing_flag.h"
  15. #include "executable_semantics/interpreter/interpreter.h"
  16. #include "executable_semantics/interpreter/value.h"
  17. #include "llvm/ADT/StringExtras.h"
  18. #include "llvm/Support/Casting.h"
  19. using llvm::cast;
  20. using llvm::dyn_cast;
  21. namespace Carbon {
  22. void PrintTypeEnv(TypeEnv types, llvm::raw_ostream& out) {
  23. llvm::ListSeparator sep;
  24. for (const auto& [name, type] : types) {
  25. out << sep << name << ": " << *type;
  26. }
  27. }
  28. static void ExpectType(int line_num, const std::string& context,
  29. const Value* expected, const Value* actual) {
  30. if (!TypeEqual(expected, actual)) {
  31. FATAL_COMPILATION_ERROR(line_num) << "type error in " << context << "\n"
  32. << "expected: " << *expected << "\n"
  33. << "actual: " << *actual;
  34. }
  35. }
  36. static void ExpectPointerType(int line_num, const std::string& context,
  37. const Value* actual) {
  38. if (actual->Tag() != Value::Kind::PointerType) {
  39. FATAL_COMPILATION_ERROR(line_num) << "type error in " << context << "\n"
  40. << "expected a pointer type\n"
  41. << "actual: " << *actual;
  42. }
  43. }
  44. // Reify type to type expression.
  45. static auto ReifyType(const Value* t, int line_num) -> const Expression* {
  46. switch (t->Tag()) {
  47. case Value::Kind::IntType:
  48. return global_arena->RawNew<IntTypeLiteral>(0);
  49. case Value::Kind::BoolType:
  50. return global_arena->RawNew<BoolTypeLiteral>(0);
  51. case Value::Kind::TypeType:
  52. return global_arena->RawNew<TypeTypeLiteral>(0);
  53. case Value::Kind::ContinuationType:
  54. return global_arena->RawNew<ContinuationTypeLiteral>(0);
  55. case Value::Kind::FunctionType: {
  56. const auto& fn_type = cast<FunctionType>(*t);
  57. return global_arena->RawNew<FunctionTypeLiteral>(
  58. 0, ReifyType(fn_type.Param(), line_num),
  59. ReifyType(fn_type.Ret(), line_num),
  60. /*is_omitted_return_type=*/false);
  61. }
  62. case Value::Kind::TupleValue: {
  63. std::vector<FieldInitializer> args;
  64. for (const TupleElement& field : cast<TupleValue>(*t).Elements()) {
  65. args.push_back(
  66. FieldInitializer(field.name, ReifyType(field.value, line_num)));
  67. }
  68. return global_arena->RawNew<TupleLiteral>(0, args);
  69. }
  70. case Value::Kind::StructType:
  71. return global_arena->RawNew<IdentifierExpression>(
  72. 0, cast<StructType>(*t).Name());
  73. case Value::Kind::ChoiceType:
  74. return global_arena->RawNew<IdentifierExpression>(
  75. 0, cast<ChoiceType>(*t).Name());
  76. case Value::Kind::PointerType:
  77. return global_arena->RawNew<PrimitiveOperatorExpression>(
  78. 0, Operator::Ptr,
  79. std::vector<const Expression*>(
  80. {ReifyType(cast<PointerType>(*t).Type(), line_num)}));
  81. case Value::Kind::VariableType:
  82. return global_arena->RawNew<IdentifierExpression>(
  83. 0, cast<VariableType>(*t).Name());
  84. case Value::Kind::StringType:
  85. return global_arena->RawNew<StringTypeLiteral>(0);
  86. case Value::Kind::AlternativeConstructorValue:
  87. case Value::Kind::AlternativeValue:
  88. case Value::Kind::AutoType:
  89. case Value::Kind::BindingPlaceholderValue:
  90. case Value::Kind::BoolValue:
  91. case Value::Kind::ContinuationValue:
  92. case Value::Kind::FunctionValue:
  93. case Value::Kind::IntValue:
  94. case Value::Kind::PointerValue:
  95. case Value::Kind::StringValue:
  96. case Value::Kind::StructValue:
  97. FATAL() << "expected a type, not " << *t;
  98. }
  99. }
  100. // Perform type argument deduction, matching the parameter type `param`
  101. // against the argument type `arg`. Whenever there is an VariableType
  102. // in the parameter type, it is deduced to be the corresponding type
  103. // inside the argument type.
  104. // The `deduced` parameter is an accumulator, that is, it holds the
  105. // results so-far.
  106. static auto ArgumentDeduction(int line_num, TypeEnv deduced, const Value* param,
  107. const Value* arg) -> TypeEnv {
  108. switch (param->Tag()) {
  109. case Value::Kind::VariableType: {
  110. const auto& var_type = cast<VariableType>(*param);
  111. std::optional<const Value*> d = deduced.Get(var_type.Name());
  112. if (!d) {
  113. deduced.Set(var_type.Name(), arg);
  114. } else {
  115. ExpectType(line_num, "argument deduction", *d, arg);
  116. }
  117. return deduced;
  118. }
  119. case Value::Kind::TupleValue: {
  120. if (arg->Tag() != Value::Kind::TupleValue) {
  121. ExpectType(line_num, "argument deduction", param, arg);
  122. }
  123. const auto& param_tup = cast<TupleValue>(*param);
  124. const auto& arg_tup = cast<TupleValue>(*arg);
  125. if (param_tup.Elements().size() != arg_tup.Elements().size()) {
  126. ExpectType(line_num, "argument deduction", param, arg);
  127. }
  128. for (size_t i = 0; i < param_tup.Elements().size(); ++i) {
  129. if (param_tup.Elements()[i].name != arg_tup.Elements()[i].name) {
  130. FATAL_COMPILATION_ERROR(line_num)
  131. << "mismatch in tuple names, " << param_tup.Elements()[i].name
  132. << " != " << arg_tup.Elements()[i].name;
  133. }
  134. deduced =
  135. ArgumentDeduction(line_num, deduced, param_tup.Elements()[i].value,
  136. arg_tup.Elements()[i].value);
  137. }
  138. return deduced;
  139. }
  140. case Value::Kind::FunctionType: {
  141. if (arg->Tag() != Value::Kind::FunctionType) {
  142. ExpectType(line_num, "argument deduction", param, arg);
  143. }
  144. const auto& param_fn = cast<FunctionType>(*param);
  145. const auto& arg_fn = cast<FunctionType>(*arg);
  146. // TODO: handle situation when arg has deduced parameters.
  147. deduced = ArgumentDeduction(line_num, deduced, param_fn.Param(),
  148. arg_fn.Param());
  149. deduced =
  150. ArgumentDeduction(line_num, deduced, param_fn.Ret(), arg_fn.Ret());
  151. return deduced;
  152. }
  153. case Value::Kind::PointerType: {
  154. if (arg->Tag() != Value::Kind::PointerType) {
  155. ExpectType(line_num, "argument deduction", param, arg);
  156. }
  157. return ArgumentDeduction(line_num, deduced,
  158. cast<PointerType>(*param).Type(),
  159. cast<PointerType>(*arg).Type());
  160. }
  161. // Nothing to do in the case for `auto`.
  162. case Value::Kind::AutoType: {
  163. return deduced;
  164. }
  165. // For the following cases, we check for type equality.
  166. case Value::Kind::ContinuationType:
  167. case Value::Kind::StructType:
  168. case Value::Kind::ChoiceType:
  169. case Value::Kind::IntType:
  170. case Value::Kind::BoolType:
  171. case Value::Kind::TypeType:
  172. case Value::Kind::StringType:
  173. ExpectType(line_num, "argument deduction", param, arg);
  174. return deduced;
  175. // The rest of these cases should never happen.
  176. case Value::Kind::IntValue:
  177. case Value::Kind::BoolValue:
  178. case Value::Kind::FunctionValue:
  179. case Value::Kind::PointerValue:
  180. case Value::Kind::StructValue:
  181. case Value::Kind::AlternativeValue:
  182. case Value::Kind::BindingPlaceholderValue:
  183. case Value::Kind::AlternativeConstructorValue:
  184. case Value::Kind::ContinuationValue:
  185. case Value::Kind::StringValue:
  186. FATAL() << "In ArgumentDeduction: expected type, not value " << *param;
  187. }
  188. }
  189. static auto Substitute(TypeEnv dict, const Value* type) -> const Value* {
  190. switch (type->Tag()) {
  191. case Value::Kind::VariableType: {
  192. std::optional<const Value*> t =
  193. dict.Get(cast<VariableType>(*type).Name());
  194. if (!t) {
  195. return type;
  196. } else {
  197. return *t;
  198. }
  199. }
  200. case Value::Kind::TupleValue: {
  201. std::vector<TupleElement> elts;
  202. for (const auto& elt : cast<TupleValue>(*type).Elements()) {
  203. auto t = Substitute(dict, elt.value);
  204. elts.push_back({.name = elt.name, .value = t});
  205. }
  206. return global_arena->RawNew<TupleValue>(elts);
  207. }
  208. case Value::Kind::FunctionType: {
  209. const auto& fn_type = cast<FunctionType>(*type);
  210. auto param = Substitute(dict, fn_type.Param());
  211. auto ret = Substitute(dict, fn_type.Ret());
  212. return global_arena->RawNew<FunctionType>(std::vector<GenericBinding>(),
  213. param, ret);
  214. }
  215. case Value::Kind::PointerType: {
  216. return global_arena->RawNew<PointerType>(
  217. Substitute(dict, cast<PointerType>(*type).Type()));
  218. }
  219. case Value::Kind::AutoType:
  220. case Value::Kind::IntType:
  221. case Value::Kind::BoolType:
  222. case Value::Kind::TypeType:
  223. case Value::Kind::StructType:
  224. case Value::Kind::ChoiceType:
  225. case Value::Kind::ContinuationType:
  226. case Value::Kind::StringType:
  227. return type;
  228. // The rest of these cases should never happen.
  229. case Value::Kind::IntValue:
  230. case Value::Kind::BoolValue:
  231. case Value::Kind::FunctionValue:
  232. case Value::Kind::PointerValue:
  233. case Value::Kind::StructValue:
  234. case Value::Kind::AlternativeValue:
  235. case Value::Kind::BindingPlaceholderValue:
  236. case Value::Kind::AlternativeConstructorValue:
  237. case Value::Kind::ContinuationValue:
  238. case Value::Kind::StringValue:
  239. FATAL() << "In Substitute: expected type, not value " << *type;
  240. }
  241. }
  242. // The TypeCheckExp function performs semantic analysis on an expression.
  243. // It returns a new version of the expression, its type, and an
  244. // updated environment which are bundled into a TCResult object.
  245. // The purpose of the updated environment is
  246. // to bring pattern variables into scope, for example, in a match case.
  247. // The new version of the expression may include more information,
  248. // for example, the type arguments deduced for the type parameters of a
  249. // generic.
  250. //
  251. // e is the expression to be analyzed.
  252. // types maps variable names to the type of their run-time value.
  253. // values maps variable names to their compile-time values. It is not
  254. // directly used in this function but is passed to InterExp.
  255. auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
  256. -> TCExpression {
  257. if (tracing_output) {
  258. llvm::outs() << "checking expression " << *e << "\ntypes: ";
  259. PrintTypeEnv(types, llvm::outs());
  260. llvm::outs() << "\nvalues: ";
  261. PrintEnv(values, llvm::outs());
  262. llvm::outs() << "\n";
  263. }
  264. switch (e->Tag()) {
  265. case Expression::Kind::IndexExpression: {
  266. const auto& index = cast<IndexExpression>(*e);
  267. auto res = TypeCheckExp(index.Aggregate(), types, values);
  268. auto t = res.type;
  269. switch (t->Tag()) {
  270. case Value::Kind::TupleValue: {
  271. auto i = cast<IntValue>(*InterpExp(values, index.Offset())).Val();
  272. std::string f = std::to_string(i);
  273. const Value* field_t = cast<TupleValue>(*t).FindField(f);
  274. if (field_t == nullptr) {
  275. FATAL_COMPILATION_ERROR(e->LineNumber())
  276. << "field " << f << " is not in the tuple " << *t;
  277. }
  278. auto new_e = global_arena->RawNew<IndexExpression>(
  279. e->LineNumber(), res.exp,
  280. global_arena->RawNew<IntLiteral>(e->LineNumber(), i));
  281. return TCExpression(new_e, field_t, res.types);
  282. }
  283. default:
  284. FATAL_COMPILATION_ERROR(e->LineNumber()) << "expected a tuple";
  285. }
  286. }
  287. case Expression::Kind::TupleLiteral: {
  288. std::vector<FieldInitializer> new_args;
  289. std::vector<TupleElement> arg_types;
  290. auto new_types = types;
  291. for (const auto& arg : cast<TupleLiteral>(*e).Fields()) {
  292. auto arg_res = TypeCheckExp(arg.expression, new_types, values);
  293. new_types = arg_res.types;
  294. new_args.push_back(FieldInitializer(arg.name, arg_res.exp));
  295. arg_types.push_back({.name = arg.name, .value = arg_res.type});
  296. }
  297. auto tuple_e =
  298. global_arena->RawNew<TupleLiteral>(e->LineNumber(), new_args);
  299. auto tuple_t = global_arena->RawNew<TupleValue>(std::move(arg_types));
  300. return TCExpression(tuple_e, tuple_t, new_types);
  301. }
  302. case Expression::Kind::FieldAccessExpression: {
  303. const auto& access = cast<FieldAccessExpression>(*e);
  304. auto res = TypeCheckExp(access.Aggregate(), types, values);
  305. auto t = res.type;
  306. switch (t->Tag()) {
  307. case Value::Kind::StructType: {
  308. const auto& t_struct = cast<StructType>(*t);
  309. // Search for a field
  310. for (auto& field : t_struct.Fields()) {
  311. if (access.Field() == field.first) {
  312. const Expression* new_e =
  313. global_arena->RawNew<FieldAccessExpression>(
  314. e->LineNumber(), res.exp, access.Field());
  315. return TCExpression(new_e, field.second, res.types);
  316. }
  317. }
  318. // Search for a method
  319. for (auto& method : t_struct.Methods()) {
  320. if (access.Field() == method.first) {
  321. const Expression* new_e =
  322. global_arena->RawNew<FieldAccessExpression>(
  323. e->LineNumber(), res.exp, access.Field());
  324. return TCExpression(new_e, method.second, res.types);
  325. }
  326. }
  327. FATAL_COMPILATION_ERROR(e->LineNumber())
  328. << "struct " << t_struct.Name() << " does not have a field named "
  329. << access.Field();
  330. }
  331. case Value::Kind::TupleValue: {
  332. const auto& tup = cast<TupleValue>(*t);
  333. for (const TupleElement& field : tup.Elements()) {
  334. if (access.Field() == field.name) {
  335. auto new_e = global_arena->RawNew<FieldAccessExpression>(
  336. e->LineNumber(), res.exp, access.Field());
  337. return TCExpression(new_e, field.value, res.types);
  338. }
  339. }
  340. FATAL_COMPILATION_ERROR(e->LineNumber())
  341. << "tuple " << tup << " does not have a field named "
  342. << access.Field();
  343. }
  344. case Value::Kind::ChoiceType: {
  345. const auto& choice = cast<ChoiceType>(*t);
  346. for (const auto& vt : choice.Alternatives()) {
  347. if (access.Field() == vt.first) {
  348. const Expression* new_e =
  349. global_arena->RawNew<FieldAccessExpression>(
  350. e->LineNumber(), res.exp, access.Field());
  351. auto fun_ty = global_arena->RawNew<FunctionType>(
  352. std::vector<GenericBinding>(), vt.second, t);
  353. return TCExpression(new_e, fun_ty, res.types);
  354. }
  355. }
  356. FATAL_COMPILATION_ERROR(e->LineNumber())
  357. << "choice " << choice.Name() << " does not have a field named "
  358. << access.Field();
  359. }
  360. default:
  361. FATAL_COMPILATION_ERROR(e->LineNumber())
  362. << "field access, expected a struct\n"
  363. << *e;
  364. }
  365. }
  366. case Expression::Kind::IdentifierExpression: {
  367. const auto& ident = cast<IdentifierExpression>(*e);
  368. std::optional<const Value*> type = types.Get(ident.Name());
  369. if (type) {
  370. return TCExpression(e, *type, types);
  371. } else {
  372. FATAL_COMPILATION_ERROR(e->LineNumber())
  373. << "could not find `" << ident.Name() << "`";
  374. }
  375. }
  376. case Expression::Kind::IntLiteral:
  377. return TCExpression(e, global_arena->RawNew<IntType>(), types);
  378. case Expression::Kind::BoolLiteral:
  379. return TCExpression(e, global_arena->RawNew<BoolType>(), types);
  380. case Expression::Kind::PrimitiveOperatorExpression: {
  381. const auto& op = cast<PrimitiveOperatorExpression>(*e);
  382. std::vector<const Expression*> es;
  383. std::vector<const Value*> ts;
  384. auto new_types = types;
  385. for (const Expression* argument : op.Arguments()) {
  386. auto res = TypeCheckExp(argument, types, values);
  387. new_types = res.types;
  388. es.push_back(res.exp);
  389. ts.push_back(res.type);
  390. }
  391. auto new_e = global_arena->RawNew<PrimitiveOperatorExpression>(
  392. e->LineNumber(), op.Op(), es);
  393. switch (op.Op()) {
  394. case Operator::Neg:
  395. ExpectType(e->LineNumber(), "negation",
  396. global_arena->RawNew<IntType>(), ts[0]);
  397. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  398. new_types);
  399. case Operator::Add:
  400. ExpectType(e->LineNumber(), "addition(1)",
  401. global_arena->RawNew<IntType>(), ts[0]);
  402. ExpectType(e->LineNumber(), "addition(2)",
  403. global_arena->RawNew<IntType>(), ts[1]);
  404. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  405. new_types);
  406. case Operator::Sub:
  407. ExpectType(e->LineNumber(), "subtraction(1)",
  408. global_arena->RawNew<IntType>(), ts[0]);
  409. ExpectType(e->LineNumber(), "subtraction(2)",
  410. global_arena->RawNew<IntType>(), ts[1]);
  411. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  412. new_types);
  413. case Operator::Mul:
  414. ExpectType(e->LineNumber(), "multiplication(1)",
  415. global_arena->RawNew<IntType>(), ts[0]);
  416. ExpectType(e->LineNumber(), "multiplication(2)",
  417. global_arena->RawNew<IntType>(), ts[1]);
  418. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  419. new_types);
  420. case Operator::And:
  421. ExpectType(e->LineNumber(), "&&(1)", global_arena->RawNew<BoolType>(),
  422. ts[0]);
  423. ExpectType(e->LineNumber(), "&&(2)", global_arena->RawNew<BoolType>(),
  424. ts[1]);
  425. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  426. new_types);
  427. case Operator::Or:
  428. ExpectType(e->LineNumber(), "||(1)", global_arena->RawNew<BoolType>(),
  429. ts[0]);
  430. ExpectType(e->LineNumber(), "||(2)", global_arena->RawNew<BoolType>(),
  431. ts[1]);
  432. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  433. new_types);
  434. case Operator::Not:
  435. ExpectType(e->LineNumber(), "!", global_arena->RawNew<BoolType>(),
  436. ts[0]);
  437. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  438. new_types);
  439. case Operator::Eq:
  440. ExpectType(e->LineNumber(), "==", ts[0], ts[1]);
  441. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  442. new_types);
  443. case Operator::Deref:
  444. ExpectPointerType(e->LineNumber(), "*", ts[0]);
  445. return TCExpression(new_e, cast<PointerType>(*ts[0]).Type(),
  446. new_types);
  447. case Operator::Ptr:
  448. ExpectType(e->LineNumber(), "*", global_arena->RawNew<TypeType>(),
  449. ts[0]);
  450. return TCExpression(new_e, global_arena->RawNew<TypeType>(),
  451. new_types);
  452. }
  453. break;
  454. }
  455. case Expression::Kind::CallExpression: {
  456. const auto& call = cast<CallExpression>(*e);
  457. auto fun_res = TypeCheckExp(call.Function(), types, values);
  458. switch (fun_res.type->Tag()) {
  459. case Value::Kind::FunctionType: {
  460. const auto& fun_t = cast<FunctionType>(*fun_res.type);
  461. auto arg_res = TypeCheckExp(call.Argument(), fun_res.types, values);
  462. auto parameter_type = fun_t.Param();
  463. auto return_type = fun_t.Ret();
  464. if (!fun_t.Deduced().empty()) {
  465. auto deduced_args = ArgumentDeduction(e->LineNumber(), TypeEnv(),
  466. parameter_type, arg_res.type);
  467. for (auto& deduced_param : fun_t.Deduced()) {
  468. // TODO: change the following to a CHECK once the real checking
  469. // has been added to the type checking of function signatures.
  470. if (!deduced_args.Get(deduced_param.name)) {
  471. FATAL_COMPILATION_ERROR(e->LineNumber())
  472. << "could not deduce type argument for type parameter "
  473. << deduced_param.name;
  474. }
  475. }
  476. parameter_type = Substitute(deduced_args, parameter_type);
  477. return_type = Substitute(deduced_args, return_type);
  478. } else {
  479. ExpectType(e->LineNumber(), "call", parameter_type, arg_res.type);
  480. }
  481. auto new_e = global_arena->RawNew<CallExpression>(
  482. e->LineNumber(), fun_res.exp, arg_res.exp);
  483. return TCExpression(new_e, return_type, arg_res.types);
  484. }
  485. default: {
  486. FATAL_COMPILATION_ERROR(e->LineNumber())
  487. << "in call, expected a function\n"
  488. << *e;
  489. }
  490. }
  491. break;
  492. }
  493. case Expression::Kind::FunctionTypeLiteral: {
  494. const auto& fn = cast<FunctionTypeLiteral>(*e);
  495. auto pt = InterpExp(values, fn.Parameter());
  496. auto rt = InterpExp(values, fn.ReturnType());
  497. auto new_e = global_arena->RawNew<FunctionTypeLiteral>(
  498. e->LineNumber(), ReifyType(pt, e->LineNumber()),
  499. ReifyType(rt, e->LineNumber()),
  500. /*is_omitted_return_type=*/false);
  501. return TCExpression(new_e, global_arena->RawNew<TypeType>(), types);
  502. }
  503. case Expression::Kind::StringLiteral:
  504. return TCExpression(e, global_arena->RawNew<StringType>(), types);
  505. case Expression::Kind::IntrinsicExpression:
  506. switch (cast<IntrinsicExpression>(*e).Intrinsic()) {
  507. case IntrinsicExpression::IntrinsicKind::Print:
  508. return TCExpression(e, &TupleValue::Empty(), types);
  509. }
  510. case Expression::Kind::IntTypeLiteral:
  511. case Expression::Kind::BoolTypeLiteral:
  512. case Expression::Kind::StringTypeLiteral:
  513. case Expression::Kind::TypeTypeLiteral:
  514. case Expression::Kind::ContinuationTypeLiteral:
  515. return TCExpression(e, global_arena->RawNew<TypeType>(), types);
  516. }
  517. }
  518. // Equivalent to TypeCheckExp, but operates on Patterns instead of Expressions.
  519. // `expected` is the type that this pattern is expected to have, if the
  520. // surrounding context gives us that information. Otherwise, it is null.
  521. auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
  522. const Value* expected) -> TCPattern {
  523. if (tracing_output) {
  524. llvm::outs() << "checking pattern " << *p;
  525. if (expected) {
  526. llvm::outs() << ", expecting " << *expected;
  527. }
  528. llvm::outs() << "\ntypes: ";
  529. PrintTypeEnv(types, llvm::outs());
  530. llvm::outs() << "\nvalues: ";
  531. PrintEnv(values, llvm::outs());
  532. llvm::outs() << "\n";
  533. }
  534. switch (p->Tag()) {
  535. case Pattern::Kind::AutoPattern: {
  536. return {.pattern = p,
  537. .type = global_arena->RawNew<TypeType>(),
  538. .types = types};
  539. }
  540. case Pattern::Kind::BindingPattern: {
  541. const auto& binding = cast<BindingPattern>(*p);
  542. TCPattern binding_type_result =
  543. TypeCheckPattern(binding.Type(), types, values, nullptr);
  544. const Value* type = InterpPattern(values, binding_type_result.pattern);
  545. if (expected != nullptr) {
  546. std::optional<Env> values =
  547. PatternMatch(type, expected, binding.Type()->LineNumber());
  548. if (values == std::nullopt) {
  549. FATAL_COMPILATION_ERROR(binding.Type()->LineNumber())
  550. << "Type pattern '" << *type << "' does not match actual type '"
  551. << *expected << "'";
  552. }
  553. CHECK(values->begin() == values->end())
  554. << "Name bindings within type patterns are unsupported";
  555. type = expected;
  556. }
  557. auto new_p = global_arena->RawNew<BindingPattern>(
  558. binding.LineNumber(), binding.Name(),
  559. global_arena->RawNew<ExpressionPattern>(
  560. ReifyType(type, binding.LineNumber())));
  561. if (binding.Name().has_value()) {
  562. types.Set(*binding.Name(), type);
  563. }
  564. return {.pattern = new_p, .type = type, .types = types};
  565. }
  566. case Pattern::Kind::TuplePattern: {
  567. const auto& tuple = cast<TuplePattern>(*p);
  568. std::vector<TuplePattern::Field> new_fields;
  569. std::vector<TupleElement> field_types;
  570. auto new_types = types;
  571. if (expected && expected->Tag() != Value::Kind::TupleValue) {
  572. FATAL_COMPILATION_ERROR(p->LineNumber()) << "didn't expect a tuple";
  573. }
  574. if (expected && tuple.Fields().size() !=
  575. cast<TupleValue>(*expected).Elements().size()) {
  576. FATAL_COMPILATION_ERROR(tuple.LineNumber())
  577. << "tuples of different length";
  578. }
  579. for (size_t i = 0; i < tuple.Fields().size(); ++i) {
  580. const TuplePattern::Field& field = tuple.Fields()[i];
  581. const Value* expected_field_type = nullptr;
  582. if (expected != nullptr) {
  583. const TupleElement& expected_element =
  584. cast<TupleValue>(*expected).Elements()[i];
  585. if (expected_element.name != field.name) {
  586. FATAL_COMPILATION_ERROR(tuple.LineNumber())
  587. << "field names do not match, expected "
  588. << expected_element.name << " but got " << field.name;
  589. }
  590. expected_field_type = expected_element.value;
  591. }
  592. auto field_result = TypeCheckPattern(field.pattern, new_types, values,
  593. expected_field_type);
  594. new_types = field_result.types;
  595. new_fields.push_back(
  596. TuplePattern::Field(field.name, field_result.pattern));
  597. field_types.push_back({.name = field.name, .value = field_result.type});
  598. }
  599. auto new_tuple =
  600. global_arena->RawNew<TuplePattern>(tuple.LineNumber(), new_fields);
  601. auto tuple_t = global_arena->RawNew<TupleValue>(std::move(field_types));
  602. return {.pattern = new_tuple, .type = tuple_t, .types = new_types};
  603. }
  604. case Pattern::Kind::AlternativePattern: {
  605. const auto& alternative = cast<AlternativePattern>(*p);
  606. const Value* choice_type = InterpExp(values, alternative.ChoiceType());
  607. if (choice_type->Tag() != Value::Kind::ChoiceType) {
  608. FATAL_COMPILATION_ERROR(alternative.LineNumber())
  609. << "alternative pattern does not name a choice type.";
  610. }
  611. if (expected != nullptr) {
  612. ExpectType(alternative.LineNumber(), "alternative pattern", expected,
  613. choice_type);
  614. }
  615. const Value* parameter_types =
  616. FindInVarValues(alternative.AlternativeName(),
  617. cast<ChoiceType>(*choice_type).Alternatives());
  618. if (parameter_types == nullptr) {
  619. FATAL_COMPILATION_ERROR(alternative.LineNumber())
  620. << "'" << alternative.AlternativeName()
  621. << "' is not an alternative of " << choice_type;
  622. }
  623. TCPattern arg_results = TypeCheckPattern(alternative.Arguments(), types,
  624. values, parameter_types);
  625. return {.pattern = global_arena->RawNew<AlternativePattern>(
  626. alternative.LineNumber(),
  627. ReifyType(choice_type, alternative.LineNumber()),
  628. alternative.AlternativeName(),
  629. cast<TuplePattern>(arg_results.pattern)),
  630. .type = choice_type,
  631. .types = arg_results.types};
  632. }
  633. case Pattern::Kind::ExpressionPattern: {
  634. TCExpression result =
  635. TypeCheckExp(cast<ExpressionPattern>(p)->Expression(), types, values);
  636. return {.pattern = global_arena->RawNew<ExpressionPattern>(result.exp),
  637. .type = result.type,
  638. .types = result.types};
  639. }
  640. }
  641. }
  642. static auto TypecheckCase(const Value* expected, const Pattern* pat,
  643. const Statement* body, TypeEnv types, Env values,
  644. const Value*& ret_type, bool is_omitted_ret_type)
  645. -> std::pair<const Pattern*, const Statement*> {
  646. auto pat_res = TypeCheckPattern(pat, types, values, expected);
  647. auto res =
  648. TypeCheckStmt(body, pat_res.types, values, ret_type, is_omitted_ret_type);
  649. return std::make_pair(pat, res.stmt);
  650. }
  651. // The TypeCheckStmt function performs semantic analysis on a statement.
  652. // It returns a new version of the statement and a new type environment.
  653. //
  654. // The ret_type parameter is used for analyzing return statements.
  655. // It is the declared return type of the enclosing function definition.
  656. // If the return type is "auto", then the return type is inferred from
  657. // the first return statement.
  658. auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
  659. const Value*& ret_type, bool is_omitted_ret_type)
  660. -> TCStatement {
  661. if (!s) {
  662. return TCStatement(s, types);
  663. }
  664. switch (s->Tag()) {
  665. case Statement::Kind::Match: {
  666. const auto& match = cast<Match>(*s);
  667. auto res = TypeCheckExp(match.Exp(), types, values);
  668. auto res_type = res.type;
  669. auto new_clauses = global_arena->RawNew<
  670. std::list<std::pair<const Pattern*, const Statement*>>>();
  671. for (auto& clause : *match.Clauses()) {
  672. new_clauses->push_back(TypecheckCase(res_type, clause.first,
  673. clause.second, types, values,
  674. ret_type, is_omitted_ret_type));
  675. }
  676. const Statement* new_s =
  677. global_arena->RawNew<Match>(s->LineNumber(), res.exp, new_clauses);
  678. return TCStatement(new_s, types);
  679. }
  680. case Statement::Kind::While: {
  681. const auto& while_stmt = cast<While>(*s);
  682. auto cnd_res = TypeCheckExp(while_stmt.Cond(), types, values);
  683. ExpectType(s->LineNumber(), "condition of `while`",
  684. global_arena->RawNew<BoolType>(), cnd_res.type);
  685. auto body_res = TypeCheckStmt(while_stmt.Body(), types, values, ret_type,
  686. is_omitted_ret_type);
  687. auto new_s = global_arena->RawNew<While>(s->LineNumber(), cnd_res.exp,
  688. body_res.stmt);
  689. return TCStatement(new_s, types);
  690. }
  691. case Statement::Kind::Break:
  692. case Statement::Kind::Continue:
  693. return TCStatement(s, types);
  694. case Statement::Kind::Block: {
  695. auto stmt_res = TypeCheckStmt(cast<Block>(*s).Stmt(), types, values,
  696. ret_type, is_omitted_ret_type);
  697. return TCStatement(
  698. global_arena->RawNew<Block>(s->LineNumber(), stmt_res.stmt), types);
  699. }
  700. case Statement::Kind::VariableDefinition: {
  701. const auto& var = cast<VariableDefinition>(*s);
  702. auto res = TypeCheckExp(var.Init(), types, values);
  703. const Value* rhs_ty = res.type;
  704. auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
  705. const Statement* new_s = global_arena->RawNew<VariableDefinition>(
  706. s->LineNumber(), var.Pat(), res.exp);
  707. return TCStatement(new_s, lhs_res.types);
  708. }
  709. case Statement::Kind::Sequence: {
  710. const auto& seq = cast<Sequence>(*s);
  711. auto stmt_res = TypeCheckStmt(seq.Stmt(), types, values, ret_type,
  712. is_omitted_ret_type);
  713. auto types2 = stmt_res.types;
  714. auto next_res = TypeCheckStmt(seq.Next(), types2, values, ret_type,
  715. is_omitted_ret_type);
  716. auto types3 = next_res.types;
  717. return TCStatement(global_arena->RawNew<Sequence>(
  718. s->LineNumber(), stmt_res.stmt, next_res.stmt),
  719. types3);
  720. }
  721. case Statement::Kind::Assign: {
  722. const auto& assign = cast<Assign>(*s);
  723. auto rhs_res = TypeCheckExp(assign.Rhs(), types, values);
  724. auto rhs_t = rhs_res.type;
  725. auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
  726. auto lhs_t = lhs_res.type;
  727. ExpectType(s->LineNumber(), "assign", lhs_t, rhs_t);
  728. auto new_s = global_arena->RawNew<Assign>(s->LineNumber(), lhs_res.exp,
  729. rhs_res.exp);
  730. return TCStatement(new_s, lhs_res.types);
  731. }
  732. case Statement::Kind::ExpressionStatement: {
  733. auto res =
  734. TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
  735. auto new_s =
  736. global_arena->RawNew<ExpressionStatement>(s->LineNumber(), res.exp);
  737. return TCStatement(new_s, types);
  738. }
  739. case Statement::Kind::If: {
  740. const auto& if_stmt = cast<If>(*s);
  741. auto cnd_res = TypeCheckExp(if_stmt.Cond(), types, values);
  742. ExpectType(s->LineNumber(), "condition of `if`",
  743. global_arena->RawNew<BoolType>(), cnd_res.type);
  744. auto then_res = TypeCheckStmt(if_stmt.ThenStmt(), types, values, ret_type,
  745. is_omitted_ret_type);
  746. auto else_res = TypeCheckStmt(if_stmt.ElseStmt(), types, values, ret_type,
  747. is_omitted_ret_type);
  748. auto new_s = global_arena->RawNew<If>(s->LineNumber(), cnd_res.exp,
  749. then_res.stmt, else_res.stmt);
  750. return TCStatement(new_s, types);
  751. }
  752. case Statement::Kind::Return: {
  753. const auto& ret = cast<Return>(*s);
  754. auto res = TypeCheckExp(ret.Exp(), types, values);
  755. if (ret_type->Tag() == Value::Kind::AutoType) {
  756. // The following infers the return type from the first 'return'
  757. // statement. This will get more difficult with subtyping, when we
  758. // should infer the least-upper bound of all the 'return' statements.
  759. ret_type = res.type;
  760. } else {
  761. ExpectType(s->LineNumber(), "return", ret_type, res.type);
  762. }
  763. if (ret.IsOmittedExp() != is_omitted_ret_type) {
  764. FATAL_COMPILATION_ERROR(s->LineNumber())
  765. << *s << " should" << (is_omitted_ret_type ? " not" : "")
  766. << " provide a return value, to match the function's signature.";
  767. }
  768. return TCStatement(global_arena->RawNew<Return>(s->LineNumber(), res.exp,
  769. ret.IsOmittedExp()),
  770. types);
  771. }
  772. case Statement::Kind::Continuation: {
  773. const auto& cont = cast<Continuation>(*s);
  774. TCStatement body_result = TypeCheckStmt(cont.Body(), types, values,
  775. ret_type, is_omitted_ret_type);
  776. const Statement* new_continuation = global_arena->RawNew<Continuation>(
  777. s->LineNumber(), cont.ContinuationVariable(), body_result.stmt);
  778. types.Set(cont.ContinuationVariable(),
  779. global_arena->RawNew<ContinuationType>());
  780. return TCStatement(new_continuation, types);
  781. }
  782. case Statement::Kind::Run: {
  783. TCExpression argument_result =
  784. TypeCheckExp(cast<Run>(*s).Argument(), types, values);
  785. ExpectType(s->LineNumber(), "argument of `run`",
  786. global_arena->RawNew<ContinuationType>(),
  787. argument_result.type);
  788. const Statement* new_run =
  789. global_arena->RawNew<Run>(s->LineNumber(), argument_result.exp);
  790. return TCStatement(new_run, types);
  791. }
  792. case Statement::Kind::Await: {
  793. // nothing to do here
  794. return TCStatement(s, types);
  795. }
  796. } // switch
  797. }
  798. static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
  799. int line_num) -> const Statement* {
  800. if (!stmt) {
  801. if (omitted_ret_type) {
  802. return global_arena->RawNew<Return>(line_num, nullptr,
  803. /*is_omitted_exp=*/true);
  804. } else {
  805. FATAL_COMPILATION_ERROR(line_num)
  806. << "control-flow reaches end of function that provides a `->` return "
  807. "type without reaching a return statement";
  808. }
  809. }
  810. switch (stmt->Tag()) {
  811. case Statement::Kind::Match: {
  812. const auto& match = cast<Match>(*stmt);
  813. auto new_clauses = global_arena->RawNew<
  814. std::list<std::pair<const Pattern*, const Statement*>>>();
  815. for (const auto& clause : *match.Clauses()) {
  816. auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
  817. stmt->LineNumber());
  818. new_clauses->push_back(std::make_pair(clause.first, s));
  819. }
  820. return global_arena->RawNew<Match>(stmt->LineNumber(), match.Exp(),
  821. new_clauses);
  822. }
  823. case Statement::Kind::Block:
  824. return global_arena->RawNew<Block>(
  825. stmt->LineNumber(),
  826. CheckOrEnsureReturn(cast<Block>(*stmt).Stmt(), omitted_ret_type,
  827. stmt->LineNumber()));
  828. case Statement::Kind::If: {
  829. const auto& if_stmt = cast<If>(*stmt);
  830. return global_arena->RawNew<If>(
  831. stmt->LineNumber(), if_stmt.Cond(),
  832. CheckOrEnsureReturn(if_stmt.ThenStmt(), omitted_ret_type,
  833. stmt->LineNumber()),
  834. CheckOrEnsureReturn(if_stmt.ElseStmt(), omitted_ret_type,
  835. stmt->LineNumber()));
  836. }
  837. case Statement::Kind::Return:
  838. return stmt;
  839. case Statement::Kind::Sequence: {
  840. const auto& seq = cast<Sequence>(*stmt);
  841. if (seq.Next()) {
  842. return global_arena->RawNew<Sequence>(
  843. stmt->LineNumber(), seq.Stmt(),
  844. CheckOrEnsureReturn(seq.Next(), omitted_ret_type,
  845. stmt->LineNumber()));
  846. } else {
  847. return CheckOrEnsureReturn(seq.Stmt(), omitted_ret_type,
  848. stmt->LineNumber());
  849. }
  850. }
  851. case Statement::Kind::Continuation:
  852. case Statement::Kind::Run:
  853. case Statement::Kind::Await:
  854. return stmt;
  855. case Statement::Kind::Assign:
  856. case Statement::Kind::ExpressionStatement:
  857. case Statement::Kind::While:
  858. case Statement::Kind::Break:
  859. case Statement::Kind::Continue:
  860. case Statement::Kind::VariableDefinition:
  861. if (omitted_ret_type) {
  862. return global_arena->RawNew<Sequence>(
  863. stmt->LineNumber(), stmt,
  864. global_arena->RawNew<Return>(line_num, nullptr,
  865. /*is_omitted_exp=*/true));
  866. } else {
  867. FATAL_COMPILATION_ERROR(stmt->LineNumber())
  868. << "control-flow reaches end of function that provides a `->` "
  869. "return type without reaching a return statement";
  870. }
  871. }
  872. }
  873. // TODO: factor common parts of TypeCheckFunDef and TypeOfFunDef into
  874. // a function.
  875. // TODO: Add checking to function definitions to ensure that
  876. // all deduced type parameters will be deduced.
  877. static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
  878. Env values) -> struct FunctionDefinition* {
  879. // Bring the deduced parameters into scope
  880. for (const auto& deduced : f->deduced_parameters) {
  881. // auto t = InterpExp(values, deduced.type);
  882. types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
  883. Address a = state->heap.AllocateValue(*types.Get(deduced.name));
  884. values.Set(deduced.name, a);
  885. }
  886. // Type check the parameter pattern
  887. auto param_res = TypeCheckPattern(f->param_pattern, types, values, nullptr);
  888. // Evaluate the return type expression
  889. auto return_type = InterpPattern(values, f->return_type);
  890. if (f->name == "main") {
  891. ExpectType(f->line_num, "return type of `main`",
  892. global_arena->RawNew<IntType>(), return_type);
  893. // TODO: Check that main doesn't have any parameters.
  894. }
  895. auto res = TypeCheckStmt(f->body, param_res.types, values, return_type,
  896. f->is_omitted_return_type);
  897. auto body =
  898. CheckOrEnsureReturn(res.stmt, f->is_omitted_return_type, f->line_num);
  899. return global_arena->RawNew<FunctionDefinition>(
  900. f->line_num, f->name, f->deduced_parameters, f->param_pattern,
  901. global_arena->RawNew<ExpressionPattern>(
  902. ReifyType(return_type, f->line_num)),
  903. /*is_omitted_return_type=*/false, body);
  904. }
  905. static auto TypeOfFunDef(TypeEnv types, Env values,
  906. const FunctionDefinition* fun_def) -> const Value* {
  907. // Bring the deduced parameters into scope
  908. for (const auto& deduced : fun_def->deduced_parameters) {
  909. // auto t = InterpExp(values, deduced.type);
  910. types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
  911. Address a = state->heap.AllocateValue(*types.Get(deduced.name));
  912. values.Set(deduced.name, a);
  913. }
  914. // Type check the parameter pattern
  915. auto param_res =
  916. TypeCheckPattern(fun_def->param_pattern, types, values, nullptr);
  917. // Evaluate the return type expression
  918. auto ret = InterpPattern(values, fun_def->return_type);
  919. if (ret->Tag() == Value::Kind::AutoType) {
  920. auto f = TypeCheckFunDef(fun_def, types, values);
  921. ret = InterpPattern(values, f->return_type);
  922. }
  923. return global_arena->RawNew<FunctionType>(fun_def->deduced_parameters,
  924. param_res.type, ret);
  925. }
  926. static auto TypeOfStructDef(const StructDefinition* sd, TypeEnv /*types*/,
  927. Env ct_top) -> const Value* {
  928. VarValues fields;
  929. VarValues methods;
  930. for (const Member* m : sd->members) {
  931. switch (m->Tag()) {
  932. case Member::Kind::FieldMember: {
  933. const BindingPattern* binding = cast<FieldMember>(*m).Binding();
  934. if (!binding->Name().has_value()) {
  935. FATAL_COMPILATION_ERROR(binding->LineNumber())
  936. << "Struct members must have names";
  937. }
  938. const Expression* type_expression =
  939. dyn_cast<ExpressionPattern>(binding->Type())->Expression();
  940. if (type_expression == nullptr) {
  941. FATAL_COMPILATION_ERROR(binding->LineNumber())
  942. << "Struct members must have explicit types";
  943. }
  944. auto type = InterpExp(ct_top, type_expression);
  945. fields.push_back(std::make_pair(*binding->Name(), type));
  946. break;
  947. }
  948. }
  949. }
  950. return global_arena->RawNew<StructType>(sd->name, std::move(fields),
  951. std::move(methods));
  952. }
  953. static auto GetName(const Declaration& d) -> const std::string& {
  954. switch (d.Tag()) {
  955. case Declaration::Kind::FunctionDeclaration:
  956. return cast<FunctionDeclaration>(d).Definition().name;
  957. case Declaration::Kind::StructDeclaration:
  958. return cast<StructDeclaration>(d).Definition().name;
  959. case Declaration::Kind::ChoiceDeclaration:
  960. return cast<ChoiceDeclaration>(d).Name();
  961. case Declaration::Kind::VariableDeclaration: {
  962. const BindingPattern* binding = cast<VariableDeclaration>(d).Binding();
  963. if (!binding->Name().has_value()) {
  964. FATAL_COMPILATION_ERROR(binding->LineNumber())
  965. << "Top-level variable declarations must have names";
  966. }
  967. return *binding->Name();
  968. }
  969. }
  970. }
  971. auto MakeTypeChecked(const Declaration& d, const TypeEnv& types,
  972. const Env& values) -> const Declaration* {
  973. switch (d.Tag()) {
  974. case Declaration::Kind::FunctionDeclaration:
  975. return global_arena->RawNew<FunctionDeclaration>(TypeCheckFunDef(
  976. &cast<FunctionDeclaration>(d).Definition(), types, values));
  977. case Declaration::Kind::StructDeclaration: {
  978. const StructDefinition& struct_def =
  979. cast<StructDeclaration>(d).Definition();
  980. std::list<Member*> fields;
  981. for (Member* m : struct_def.members) {
  982. switch (m->Tag()) {
  983. case Member::Kind::FieldMember:
  984. // TODO: Interpret the type expression and store the result.
  985. fields.push_back(m);
  986. break;
  987. }
  988. }
  989. return global_arena->RawNew<StructDeclaration>(
  990. struct_def.line_num, struct_def.name, std::move(fields));
  991. }
  992. case Declaration::Kind::ChoiceDeclaration:
  993. // TODO
  994. return &d;
  995. case Declaration::Kind::VariableDeclaration: {
  996. const auto& var = cast<VariableDeclaration>(d);
  997. // Signals a type error if the initializing expression does not have
  998. // the declared type of the variable, otherwise returns this
  999. // declaration with annotated types.
  1000. TCExpression type_checked_initializer =
  1001. TypeCheckExp(var.Initializer(), types, values);
  1002. const Expression* type =
  1003. dyn_cast<ExpressionPattern>(var.Binding()->Type())->Expression();
  1004. if (type == nullptr) {
  1005. // TODO: consider adding support for `auto`
  1006. FATAL_COMPILATION_ERROR(var.LineNumber())
  1007. << "Type of a top-level variable must be an expression.";
  1008. }
  1009. const Value* declared_type = InterpExp(values, type);
  1010. ExpectType(var.LineNumber(), "initializer of variable", declared_type,
  1011. type_checked_initializer.type);
  1012. return &d;
  1013. }
  1014. }
  1015. }
  1016. static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
  1017. switch (d.Tag()) {
  1018. case Declaration::Kind::FunctionDeclaration: {
  1019. const FunctionDefinition& func_def =
  1020. cast<FunctionDeclaration>(d).Definition();
  1021. auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
  1022. tops->types.Set(func_def.name, t);
  1023. InitEnv(d, &tops->values);
  1024. break;
  1025. }
  1026. case Declaration::Kind::StructDeclaration: {
  1027. const StructDefinition& struct_def =
  1028. cast<StructDeclaration>(d).Definition();
  1029. auto st = TypeOfStructDef(&struct_def, tops->types, tops->values);
  1030. Address a = state->heap.AllocateValue(st);
  1031. tops->values.Set(struct_def.name, a); // Is this obsolete?
  1032. std::vector<TupleElement> field_types;
  1033. for (const auto& [field_name, field_value] :
  1034. cast<StructType>(*st).Fields()) {
  1035. field_types.push_back({.name = field_name, .value = field_value});
  1036. }
  1037. auto fun_ty = global_arena->RawNew<FunctionType>(
  1038. std::vector<GenericBinding>(),
  1039. global_arena->RawNew<TupleValue>(std::move(field_types)), st);
  1040. tops->types.Set(struct_def.name, fun_ty);
  1041. break;
  1042. }
  1043. case Declaration::Kind::ChoiceDeclaration: {
  1044. const auto& choice = cast<ChoiceDeclaration>(d);
  1045. VarValues alts;
  1046. for (const auto& [name, signature] : choice.Alternatives()) {
  1047. auto t = InterpExp(tops->values, signature);
  1048. alts.push_back(std::make_pair(name, t));
  1049. }
  1050. auto ct =
  1051. global_arena->RawNew<ChoiceType>(choice.Name(), std::move(alts));
  1052. Address a = state->heap.AllocateValue(ct);
  1053. tops->values.Set(choice.Name(), a); // Is this obsolete?
  1054. tops->types.Set(choice.Name(), ct);
  1055. break;
  1056. }
  1057. case Declaration::Kind::VariableDeclaration: {
  1058. const auto& var = cast<VariableDeclaration>(d);
  1059. // Associate the variable name with it's declared type in the
  1060. // compile-time symbol table.
  1061. const Expression* type =
  1062. cast<ExpressionPattern>(var.Binding()->Type())->Expression();
  1063. const Value* declared_type = InterpExp(tops->values, type);
  1064. tops->types.Set(*var.Binding()->Name(), declared_type);
  1065. break;
  1066. }
  1067. }
  1068. }
  1069. auto TopLevel(const std::list<const Declaration*>& fs) -> TypeCheckContext {
  1070. TypeCheckContext tops;
  1071. bool found_main = false;
  1072. for (auto const& d : fs) {
  1073. if (GetName(*d) == "main") {
  1074. found_main = true;
  1075. }
  1076. TopLevel(*d, &tops);
  1077. }
  1078. if (found_main == false) {
  1079. FATAL_COMPILATION_ERROR_NO_LINE()
  1080. << "program must contain a function named `main`";
  1081. }
  1082. return tops;
  1083. }
  1084. } // namespace Carbon