typecheck.cpp 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "executable_semantics/interpreter/typecheck.h"
  5. #include <algorithm>
  6. #include <iterator>
  7. #include <map>
  8. #include <set>
  9. #include <vector>
  10. #include "common/ostream.h"
  11. #include "executable_semantics/ast/function_definition.h"
  12. #include "executable_semantics/common/arena.h"
  13. #include "executable_semantics/common/error.h"
  14. #include "executable_semantics/common/tracing_flag.h"
  15. #include "executable_semantics/interpreter/interpreter.h"
  16. #include "executable_semantics/interpreter/value.h"
  17. #include "llvm/ADT/StringExtras.h"
  18. #include "llvm/Support/Casting.h"
  19. using llvm::cast;
  20. using llvm::dyn_cast;
  21. namespace Carbon {
  22. void PrintTypeEnv(TypeEnv types, llvm::raw_ostream& out) {
  23. llvm::ListSeparator sep;
  24. for (const auto& [name, type] : types) {
  25. out << sep << name << ": " << *type;
  26. }
  27. }
  28. static void ExpectType(SourceLocation loc, const std::string& context,
  29. const Value* expected, const Value* actual) {
  30. if (!TypeEqual(expected, actual)) {
  31. FATAL_COMPILATION_ERROR(loc) << "type error in " << context << "\n"
  32. << "expected: " << *expected << "\n"
  33. << "actual: " << *actual;
  34. }
  35. }
  36. static void ExpectPointerType(SourceLocation loc, const std::string& context,
  37. const Value* actual) {
  38. if (actual->Tag() != Value::Kind::PointerType) {
  39. FATAL_COMPILATION_ERROR(loc) << "type error in " << context << "\n"
  40. << "expected a pointer type\n"
  41. << "actual: " << *actual;
  42. }
  43. }
  44. static SourceLocation ReifyFakeSourceLoc() {
  45. return SourceLocation("<reify>", 0);
  46. }
  47. // Reify type to type expression.
  48. static auto ReifyType(const Value* t, SourceLocation loc)
  49. -> Ptr<const Expression> {
  50. switch (t->Tag()) {
  51. case Value::Kind::IntType:
  52. return global_arena->New<IntTypeLiteral>(ReifyFakeSourceLoc());
  53. case Value::Kind::BoolType:
  54. return global_arena->New<BoolTypeLiteral>(ReifyFakeSourceLoc());
  55. case Value::Kind::TypeType:
  56. return global_arena->New<TypeTypeLiteral>(ReifyFakeSourceLoc());
  57. case Value::Kind::ContinuationType:
  58. return global_arena->New<ContinuationTypeLiteral>(ReifyFakeSourceLoc());
  59. case Value::Kind::FunctionType: {
  60. const auto& fn_type = cast<FunctionType>(*t);
  61. return global_arena->New<FunctionTypeLiteral>(
  62. ReifyFakeSourceLoc(), ReifyType(fn_type.Param(), loc),
  63. ReifyType(fn_type.Ret(), loc),
  64. /*is_omitted_return_type=*/false);
  65. }
  66. case Value::Kind::TupleValue: {
  67. std::vector<FieldInitializer> args;
  68. for (const TupleElement& field : cast<TupleValue>(*t).Elements()) {
  69. args.push_back(
  70. FieldInitializer(field.name, ReifyType(field.value, loc)));
  71. }
  72. return global_arena->New<TupleLiteral>(ReifyFakeSourceLoc(), args);
  73. }
  74. case Value::Kind::ClassType:
  75. return global_arena->New<IdentifierExpression>(
  76. ReifyFakeSourceLoc(), cast<ClassType>(*t).Name());
  77. case Value::Kind::ChoiceType:
  78. return global_arena->New<IdentifierExpression>(
  79. ReifyFakeSourceLoc(), cast<ChoiceType>(*t).Name());
  80. case Value::Kind::PointerType:
  81. return global_arena->New<PrimitiveOperatorExpression>(
  82. ReifyFakeSourceLoc(), Operator::Ptr,
  83. std::vector<Ptr<const Expression>>(
  84. {ReifyType(cast<PointerType>(*t).Type(), loc)}));
  85. case Value::Kind::VariableType:
  86. return global_arena->New<IdentifierExpression>(
  87. ReifyFakeSourceLoc(), cast<VariableType>(*t).Name());
  88. case Value::Kind::StringType:
  89. return global_arena->New<StringTypeLiteral>(ReifyFakeSourceLoc());
  90. case Value::Kind::AlternativeConstructorValue:
  91. case Value::Kind::AlternativeValue:
  92. case Value::Kind::AutoType:
  93. case Value::Kind::BindingPlaceholderValue:
  94. case Value::Kind::BoolValue:
  95. case Value::Kind::ContinuationValue:
  96. case Value::Kind::FunctionValue:
  97. case Value::Kind::IntValue:
  98. case Value::Kind::PointerValue:
  99. case Value::Kind::StringValue:
  100. case Value::Kind::StructValue:
  101. FATAL() << "expected a type, not " << *t;
  102. }
  103. }
  104. // Perform type argument deduction, matching the parameter type `param`
  105. // against the argument type `arg`. Whenever there is an VariableType
  106. // in the parameter type, it is deduced to be the corresponding type
  107. // inside the argument type.
  108. // The `deduced` parameter is an accumulator, that is, it holds the
  109. // results so-far.
  110. static auto ArgumentDeduction(SourceLocation loc, TypeEnv deduced,
  111. const Value* param, const Value* arg) -> TypeEnv {
  112. switch (param->Tag()) {
  113. case Value::Kind::VariableType: {
  114. const auto& var_type = cast<VariableType>(*param);
  115. std::optional<const Value*> d = deduced.Get(var_type.Name());
  116. if (!d) {
  117. deduced.Set(var_type.Name(), arg);
  118. } else {
  119. ExpectType(loc, "argument deduction", *d, arg);
  120. }
  121. return deduced;
  122. }
  123. case Value::Kind::TupleValue: {
  124. if (arg->Tag() != Value::Kind::TupleValue) {
  125. ExpectType(loc, "argument deduction", param, arg);
  126. }
  127. const auto& param_tup = cast<TupleValue>(*param);
  128. const auto& arg_tup = cast<TupleValue>(*arg);
  129. if (param_tup.Elements().size() != arg_tup.Elements().size()) {
  130. ExpectType(loc, "argument deduction", param, arg);
  131. }
  132. for (size_t i = 0; i < param_tup.Elements().size(); ++i) {
  133. if (param_tup.Elements()[i].name != arg_tup.Elements()[i].name) {
  134. FATAL_COMPILATION_ERROR(loc)
  135. << "mismatch in tuple names, " << param_tup.Elements()[i].name
  136. << " != " << arg_tup.Elements()[i].name;
  137. }
  138. deduced = ArgumentDeduction(loc, deduced, param_tup.Elements()[i].value,
  139. arg_tup.Elements()[i].value);
  140. }
  141. return deduced;
  142. }
  143. case Value::Kind::FunctionType: {
  144. if (arg->Tag() != Value::Kind::FunctionType) {
  145. ExpectType(loc, "argument deduction", param, arg);
  146. }
  147. const auto& param_fn = cast<FunctionType>(*param);
  148. const auto& arg_fn = cast<FunctionType>(*arg);
  149. // TODO: handle situation when arg has deduced parameters.
  150. deduced =
  151. ArgumentDeduction(loc, deduced, param_fn.Param(), arg_fn.Param());
  152. deduced = ArgumentDeduction(loc, deduced, param_fn.Ret(), arg_fn.Ret());
  153. return deduced;
  154. }
  155. case Value::Kind::PointerType: {
  156. if (arg->Tag() != Value::Kind::PointerType) {
  157. ExpectType(loc, "argument deduction", param, arg);
  158. }
  159. return ArgumentDeduction(loc, deduced, cast<PointerType>(*param).Type(),
  160. cast<PointerType>(*arg).Type());
  161. }
  162. // Nothing to do in the case for `auto`.
  163. case Value::Kind::AutoType: {
  164. return deduced;
  165. }
  166. // For the following cases, we check for type equality.
  167. case Value::Kind::ContinuationType:
  168. case Value::Kind::ClassType:
  169. case Value::Kind::ChoiceType:
  170. case Value::Kind::IntType:
  171. case Value::Kind::BoolType:
  172. case Value::Kind::TypeType:
  173. case Value::Kind::StringType:
  174. ExpectType(loc, "argument deduction", param, arg);
  175. return deduced;
  176. // The rest of these cases should never happen.
  177. case Value::Kind::IntValue:
  178. case Value::Kind::BoolValue:
  179. case Value::Kind::FunctionValue:
  180. case Value::Kind::PointerValue:
  181. case Value::Kind::StructValue:
  182. case Value::Kind::AlternativeValue:
  183. case Value::Kind::BindingPlaceholderValue:
  184. case Value::Kind::AlternativeConstructorValue:
  185. case Value::Kind::ContinuationValue:
  186. case Value::Kind::StringValue:
  187. FATAL() << "In ArgumentDeduction: expected type, not value " << *param;
  188. }
  189. }
  190. static auto Substitute(TypeEnv dict, const Value* type) -> const Value* {
  191. switch (type->Tag()) {
  192. case Value::Kind::VariableType: {
  193. std::optional<const Value*> t =
  194. dict.Get(cast<VariableType>(*type).Name());
  195. if (!t) {
  196. return type;
  197. } else {
  198. return *t;
  199. }
  200. }
  201. case Value::Kind::TupleValue: {
  202. std::vector<TupleElement> elts;
  203. for (const auto& elt : cast<TupleValue>(*type).Elements()) {
  204. auto t = Substitute(dict, elt.value);
  205. elts.push_back({.name = elt.name, .value = t});
  206. }
  207. return global_arena->RawNew<TupleValue>(elts);
  208. }
  209. case Value::Kind::FunctionType: {
  210. const auto& fn_type = cast<FunctionType>(*type);
  211. auto param = Substitute(dict, fn_type.Param());
  212. auto ret = Substitute(dict, fn_type.Ret());
  213. return global_arena->RawNew<FunctionType>(std::vector<GenericBinding>(),
  214. param, ret);
  215. }
  216. case Value::Kind::PointerType: {
  217. return global_arena->RawNew<PointerType>(
  218. Substitute(dict, cast<PointerType>(*type).Type()));
  219. }
  220. case Value::Kind::AutoType:
  221. case Value::Kind::IntType:
  222. case Value::Kind::BoolType:
  223. case Value::Kind::TypeType:
  224. case Value::Kind::ClassType:
  225. case Value::Kind::ChoiceType:
  226. case Value::Kind::ContinuationType:
  227. case Value::Kind::StringType:
  228. return type;
  229. // The rest of these cases should never happen.
  230. case Value::Kind::IntValue:
  231. case Value::Kind::BoolValue:
  232. case Value::Kind::FunctionValue:
  233. case Value::Kind::PointerValue:
  234. case Value::Kind::StructValue:
  235. case Value::Kind::AlternativeValue:
  236. case Value::Kind::BindingPlaceholderValue:
  237. case Value::Kind::AlternativeConstructorValue:
  238. case Value::Kind::ContinuationValue:
  239. case Value::Kind::StringValue:
  240. FATAL() << "In Substitute: expected type, not value " << *type;
  241. }
  242. }
  243. // The TypeCheckExp function performs semantic analysis on an expression.
  244. // It returns a new version of the expression, its type, and an
  245. // updated environment which are bundled into a TCResult object.
  246. // The purpose of the updated environment is
  247. // to bring pattern variables into scope, for example, in a match case.
  248. // The new version of the expression may include more information,
  249. // for example, the type arguments deduced for the type parameters of a
  250. // generic.
  251. //
  252. // e is the expression to be analyzed.
  253. // types maps variable names to the type of their run-time value.
  254. // values maps variable names to their compile-time values. It is not
  255. // directly used in this function but is passed to InterExp.
  256. auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
  257. -> TCExpression {
  258. if (tracing_output) {
  259. llvm::outs() << "checking expression " << *e << "\ntypes: ";
  260. PrintTypeEnv(types, llvm::outs());
  261. llvm::outs() << "\nvalues: ";
  262. PrintEnv(values, llvm::outs());
  263. llvm::outs() << "\n";
  264. }
  265. switch (e->Tag()) {
  266. case Expression::Kind::IndexExpression: {
  267. const auto& index = cast<IndexExpression>(*e);
  268. auto res = TypeCheckExp(index.Aggregate(), types, values);
  269. auto t = res.type;
  270. switch (t->Tag()) {
  271. case Value::Kind::TupleValue: {
  272. auto i = cast<IntValue>(*InterpExp(values, index.Offset())).Val();
  273. std::string f = std::to_string(i);
  274. const Value* field_t = cast<TupleValue>(*t).FindField(f);
  275. if (field_t == nullptr) {
  276. FATAL_COMPILATION_ERROR(e->SourceLoc())
  277. << "field " << f << " is not in the tuple " << *t;
  278. }
  279. auto new_e = global_arena->New<IndexExpression>(
  280. e->SourceLoc(), res.exp,
  281. global_arena->New<IntLiteral>(e->SourceLoc(), i));
  282. return TCExpression(new_e, field_t, res.types);
  283. }
  284. default:
  285. FATAL_COMPILATION_ERROR(e->SourceLoc()) << "expected a tuple";
  286. }
  287. }
  288. case Expression::Kind::TupleLiteral: {
  289. std::vector<FieldInitializer> new_args;
  290. std::vector<TupleElement> arg_types;
  291. auto new_types = types;
  292. for (const auto& arg : cast<TupleLiteral>(*e).Fields()) {
  293. auto arg_res = TypeCheckExp(arg.expression, new_types, values);
  294. new_types = arg_res.types;
  295. new_args.push_back(FieldInitializer(arg.name, arg_res.exp));
  296. arg_types.push_back({.name = arg.name, .value = arg_res.type});
  297. }
  298. auto tuple_e = global_arena->New<TupleLiteral>(e->SourceLoc(), new_args);
  299. auto tuple_t = global_arena->RawNew<TupleValue>(std::move(arg_types));
  300. return TCExpression(tuple_e, tuple_t, new_types);
  301. }
  302. case Expression::Kind::FieldAccessExpression: {
  303. const auto& access = cast<FieldAccessExpression>(*e);
  304. auto res = TypeCheckExp(access.Aggregate(), types, values);
  305. auto t = res.type;
  306. switch (t->Tag()) {
  307. case Value::Kind::ClassType: {
  308. const auto& t_class = cast<ClassType>(*t);
  309. // Search for a field
  310. for (auto& field : t_class.Fields()) {
  311. if (access.Field() == field.first) {
  312. Ptr<const Expression> new_e =
  313. global_arena->New<FieldAccessExpression>(
  314. e->SourceLoc(), res.exp, access.Field());
  315. return TCExpression(new_e, field.second, res.types);
  316. }
  317. }
  318. // Search for a method
  319. for (auto& method : t_class.Methods()) {
  320. if (access.Field() == method.first) {
  321. Ptr<const Expression> new_e =
  322. global_arena->New<FieldAccessExpression>(
  323. e->SourceLoc(), res.exp, access.Field());
  324. return TCExpression(new_e, method.second, res.types);
  325. }
  326. }
  327. FATAL_COMPILATION_ERROR(e->SourceLoc())
  328. << "class " << t_class.Name() << " does not have a field named "
  329. << access.Field();
  330. }
  331. case Value::Kind::TupleValue: {
  332. const auto& tup = cast<TupleValue>(*t);
  333. for (const TupleElement& field : tup.Elements()) {
  334. if (access.Field() == field.name) {
  335. auto new_e = global_arena->New<FieldAccessExpression>(
  336. e->SourceLoc(), res.exp, access.Field());
  337. return TCExpression(new_e, field.value, res.types);
  338. }
  339. }
  340. FATAL_COMPILATION_ERROR(e->SourceLoc())
  341. << "tuple " << tup << " does not have a field named "
  342. << access.Field();
  343. }
  344. case Value::Kind::ChoiceType: {
  345. const auto& choice = cast<ChoiceType>(*t);
  346. for (const auto& vt : choice.Alternatives()) {
  347. if (access.Field() == vt.first) {
  348. Ptr<const Expression> new_e =
  349. global_arena->New<FieldAccessExpression>(
  350. e->SourceLoc(), res.exp, access.Field());
  351. auto fun_ty = global_arena->RawNew<FunctionType>(
  352. std::vector<GenericBinding>(), vt.second, t);
  353. return TCExpression(new_e, fun_ty, res.types);
  354. }
  355. }
  356. FATAL_COMPILATION_ERROR(e->SourceLoc())
  357. << "choice " << choice.Name() << " does not have a field named "
  358. << access.Field();
  359. }
  360. default:
  361. FATAL_COMPILATION_ERROR(e->SourceLoc())
  362. << "field access, expected a struct\n"
  363. << *e;
  364. }
  365. }
  366. case Expression::Kind::IdentifierExpression: {
  367. const auto& ident = cast<IdentifierExpression>(*e);
  368. std::optional<const Value*> type = types.Get(ident.Name());
  369. if (type) {
  370. return TCExpression(e, *type, types);
  371. } else {
  372. FATAL_COMPILATION_ERROR(e->SourceLoc())
  373. << "could not find `" << ident.Name() << "`";
  374. }
  375. }
  376. case Expression::Kind::IntLiteral:
  377. return TCExpression(e, global_arena->RawNew<IntType>(), types);
  378. case Expression::Kind::BoolLiteral:
  379. return TCExpression(e, global_arena->RawNew<BoolType>(), types);
  380. case Expression::Kind::PrimitiveOperatorExpression: {
  381. const auto& op = cast<PrimitiveOperatorExpression>(*e);
  382. std::vector<Ptr<const Expression>> es;
  383. std::vector<const Value*> ts;
  384. auto new_types = types;
  385. for (Ptr<const Expression> argument : op.Arguments()) {
  386. auto res = TypeCheckExp(argument, types, values);
  387. new_types = res.types;
  388. es.push_back(res.exp);
  389. ts.push_back(res.type);
  390. }
  391. auto new_e = global_arena->New<PrimitiveOperatorExpression>(
  392. e->SourceLoc(), op.Op(), es);
  393. switch (op.Op()) {
  394. case Operator::Neg:
  395. ExpectType(e->SourceLoc(), "negation",
  396. global_arena->RawNew<IntType>(), ts[0]);
  397. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  398. new_types);
  399. case Operator::Add:
  400. ExpectType(e->SourceLoc(), "addition(1)",
  401. global_arena->RawNew<IntType>(), ts[0]);
  402. ExpectType(e->SourceLoc(), "addition(2)",
  403. global_arena->RawNew<IntType>(), ts[1]);
  404. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  405. new_types);
  406. case Operator::Sub:
  407. ExpectType(e->SourceLoc(), "subtraction(1)",
  408. global_arena->RawNew<IntType>(), ts[0]);
  409. ExpectType(e->SourceLoc(), "subtraction(2)",
  410. global_arena->RawNew<IntType>(), ts[1]);
  411. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  412. new_types);
  413. case Operator::Mul:
  414. ExpectType(e->SourceLoc(), "multiplication(1)",
  415. global_arena->RawNew<IntType>(), ts[0]);
  416. ExpectType(e->SourceLoc(), "multiplication(2)",
  417. global_arena->RawNew<IntType>(), ts[1]);
  418. return TCExpression(new_e, global_arena->RawNew<IntType>(),
  419. new_types);
  420. case Operator::And:
  421. ExpectType(e->SourceLoc(), "&&(1)", global_arena->RawNew<BoolType>(),
  422. ts[0]);
  423. ExpectType(e->SourceLoc(), "&&(2)", global_arena->RawNew<BoolType>(),
  424. ts[1]);
  425. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  426. new_types);
  427. case Operator::Or:
  428. ExpectType(e->SourceLoc(), "||(1)", global_arena->RawNew<BoolType>(),
  429. ts[0]);
  430. ExpectType(e->SourceLoc(), "||(2)", global_arena->RawNew<BoolType>(),
  431. ts[1]);
  432. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  433. new_types);
  434. case Operator::Not:
  435. ExpectType(e->SourceLoc(), "!", global_arena->RawNew<BoolType>(),
  436. ts[0]);
  437. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  438. new_types);
  439. case Operator::Eq:
  440. ExpectType(e->SourceLoc(), "==", ts[0], ts[1]);
  441. return TCExpression(new_e, global_arena->RawNew<BoolType>(),
  442. new_types);
  443. case Operator::Deref:
  444. ExpectPointerType(e->SourceLoc(), "*", ts[0]);
  445. return TCExpression(new_e, cast<PointerType>(*ts[0]).Type(),
  446. new_types);
  447. case Operator::Ptr:
  448. ExpectType(e->SourceLoc(), "*", global_arena->RawNew<TypeType>(),
  449. ts[0]);
  450. return TCExpression(new_e, global_arena->RawNew<TypeType>(),
  451. new_types);
  452. }
  453. break;
  454. }
  455. case Expression::Kind::CallExpression: {
  456. const auto& call = cast<CallExpression>(*e);
  457. auto fun_res = TypeCheckExp(call.Function(), types, values);
  458. switch (fun_res.type->Tag()) {
  459. case Value::Kind::FunctionType: {
  460. const auto& fun_t = cast<FunctionType>(*fun_res.type);
  461. auto arg_res = TypeCheckExp(call.Argument(), fun_res.types, values);
  462. auto parameter_type = fun_t.Param();
  463. auto return_type = fun_t.Ret();
  464. if (!fun_t.Deduced().empty()) {
  465. auto deduced_args = ArgumentDeduction(e->SourceLoc(), TypeEnv(),
  466. parameter_type, arg_res.type);
  467. for (auto& deduced_param : fun_t.Deduced()) {
  468. // TODO: change the following to a CHECK once the real checking
  469. // has been added to the type checking of function signatures.
  470. if (!deduced_args.Get(deduced_param.name)) {
  471. FATAL_COMPILATION_ERROR(e->SourceLoc())
  472. << "could not deduce type argument for type parameter "
  473. << deduced_param.name;
  474. }
  475. }
  476. parameter_type = Substitute(deduced_args, parameter_type);
  477. return_type = Substitute(deduced_args, return_type);
  478. } else {
  479. ExpectType(e->SourceLoc(), "call", parameter_type, arg_res.type);
  480. }
  481. auto new_e = global_arena->New<CallExpression>(
  482. e->SourceLoc(), fun_res.exp, arg_res.exp);
  483. return TCExpression(new_e, return_type, arg_res.types);
  484. }
  485. default: {
  486. FATAL_COMPILATION_ERROR(e->SourceLoc())
  487. << "in call, expected a function\n"
  488. << *e;
  489. }
  490. }
  491. break;
  492. }
  493. case Expression::Kind::FunctionTypeLiteral: {
  494. const auto& fn = cast<FunctionTypeLiteral>(*e);
  495. auto pt = InterpExp(values, fn.Parameter());
  496. auto rt = InterpExp(values, fn.ReturnType());
  497. auto new_e = global_arena->New<FunctionTypeLiteral>(
  498. e->SourceLoc(), ReifyType(pt, e->SourceLoc()),
  499. ReifyType(rt, e->SourceLoc()),
  500. /*is_omitted_return_type=*/false);
  501. return TCExpression(new_e, global_arena->RawNew<TypeType>(), types);
  502. }
  503. case Expression::Kind::StringLiteral:
  504. return TCExpression(e, global_arena->RawNew<StringType>(), types);
  505. case Expression::Kind::IntrinsicExpression:
  506. switch (cast<IntrinsicExpression>(*e).Intrinsic()) {
  507. case IntrinsicExpression::IntrinsicKind::Print:
  508. return TCExpression(e, &TupleValue::Empty(), types);
  509. }
  510. case Expression::Kind::IntTypeLiteral:
  511. case Expression::Kind::BoolTypeLiteral:
  512. case Expression::Kind::StringTypeLiteral:
  513. case Expression::Kind::TypeTypeLiteral:
  514. case Expression::Kind::ContinuationTypeLiteral:
  515. return TCExpression(e, global_arena->RawNew<TypeType>(), types);
  516. }
  517. }
  518. // Equivalent to TypeCheckExp, but operates on Patterns instead of Expressions.
  519. // `expected` is the type that this pattern is expected to have, if the
  520. // surrounding context gives us that information. Otherwise, it is null.
  521. auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
  522. const Value* expected) -> TCPattern {
  523. if (tracing_output) {
  524. llvm::outs() << "checking pattern " << *p;
  525. if (expected) {
  526. llvm::outs() << ", expecting " << *expected;
  527. }
  528. llvm::outs() << "\ntypes: ";
  529. PrintTypeEnv(types, llvm::outs());
  530. llvm::outs() << "\nvalues: ";
  531. PrintEnv(values, llvm::outs());
  532. llvm::outs() << "\n";
  533. }
  534. switch (p->Tag()) {
  535. case Pattern::Kind::AutoPattern: {
  536. return {.pattern = p,
  537. .type = global_arena->RawNew<TypeType>(),
  538. .types = types};
  539. }
  540. case Pattern::Kind::BindingPattern: {
  541. const auto& binding = cast<BindingPattern>(*p);
  542. TCPattern binding_type_result =
  543. TypeCheckPattern(binding.Type(), types, values, nullptr);
  544. const Value* type = InterpPattern(values, binding_type_result.pattern);
  545. if (expected != nullptr) {
  546. std::optional<Env> values =
  547. PatternMatch(type, expected, binding.Type()->SourceLoc());
  548. if (values == std::nullopt) {
  549. FATAL_COMPILATION_ERROR(binding.Type()->SourceLoc())
  550. << "Type pattern '" << *type << "' does not match actual type '"
  551. << *expected << "'";
  552. }
  553. CHECK(values->begin() == values->end())
  554. << "Name bindings within type patterns are unsupported";
  555. type = expected;
  556. }
  557. auto new_p = global_arena->New<BindingPattern>(
  558. binding.SourceLoc(), binding.Name(),
  559. global_arena->New<ExpressionPattern>(
  560. ReifyType(type, binding.SourceLoc())));
  561. if (binding.Name().has_value()) {
  562. types.Set(*binding.Name(), type);
  563. }
  564. return {.pattern = new_p, .type = type, .types = types};
  565. }
  566. case Pattern::Kind::TuplePattern: {
  567. const auto& tuple = cast<TuplePattern>(*p);
  568. std::vector<TuplePattern::Field> new_fields;
  569. std::vector<TupleElement> field_types;
  570. auto new_types = types;
  571. if (expected && expected->Tag() != Value::Kind::TupleValue) {
  572. FATAL_COMPILATION_ERROR(p->SourceLoc()) << "didn't expect a tuple";
  573. }
  574. if (expected && tuple.Fields().size() !=
  575. cast<TupleValue>(*expected).Elements().size()) {
  576. FATAL_COMPILATION_ERROR(tuple.SourceLoc())
  577. << "tuples of different length";
  578. }
  579. for (size_t i = 0; i < tuple.Fields().size(); ++i) {
  580. const TuplePattern::Field& field = tuple.Fields()[i];
  581. const Value* expected_field_type = nullptr;
  582. if (expected != nullptr) {
  583. const TupleElement& expected_element =
  584. cast<TupleValue>(*expected).Elements()[i];
  585. if (expected_element.name != field.name) {
  586. FATAL_COMPILATION_ERROR(tuple.SourceLoc())
  587. << "field names do not match, expected "
  588. << expected_element.name << " but got " << field.name;
  589. }
  590. expected_field_type = expected_element.value;
  591. }
  592. auto field_result = TypeCheckPattern(field.pattern, new_types, values,
  593. expected_field_type);
  594. new_types = field_result.types;
  595. new_fields.push_back(
  596. TuplePattern::Field(field.name, field_result.pattern));
  597. field_types.push_back({.name = field.name, .value = field_result.type});
  598. }
  599. auto new_tuple =
  600. global_arena->New<TuplePattern>(tuple.SourceLoc(), new_fields);
  601. auto tuple_t = global_arena->RawNew<TupleValue>(std::move(field_types));
  602. return {.pattern = new_tuple, .type = tuple_t, .types = new_types};
  603. }
  604. case Pattern::Kind::AlternativePattern: {
  605. const auto& alternative = cast<AlternativePattern>(*p);
  606. const Value* choice_type = InterpExp(values, alternative.ChoiceType());
  607. if (choice_type->Tag() != Value::Kind::ChoiceType) {
  608. FATAL_COMPILATION_ERROR(alternative.SourceLoc())
  609. << "alternative pattern does not name a choice type.";
  610. }
  611. if (expected != nullptr) {
  612. ExpectType(alternative.SourceLoc(), "alternative pattern", expected,
  613. choice_type);
  614. }
  615. const Value* parameter_types =
  616. FindInVarValues(alternative.AlternativeName(),
  617. cast<ChoiceType>(*choice_type).Alternatives());
  618. if (parameter_types == nullptr) {
  619. FATAL_COMPILATION_ERROR(alternative.SourceLoc())
  620. << "'" << alternative.AlternativeName()
  621. << "' is not an alternative of " << choice_type;
  622. }
  623. TCPattern arg_results = TypeCheckPattern(alternative.Arguments(), types,
  624. values, parameter_types);
  625. // TODO: Think about a cleaner way to cast between Ptr types.
  626. auto arguments = Ptr<const TuplePattern>(
  627. cast<const TuplePattern>(arg_results.pattern.Get()));
  628. return {.pattern = global_arena->New<AlternativePattern>(
  629. alternative.SourceLoc(),
  630. ReifyType(choice_type, alternative.SourceLoc()),
  631. alternative.AlternativeName(), arguments),
  632. .type = choice_type,
  633. .types = arg_results.types};
  634. }
  635. case Pattern::Kind::ExpressionPattern: {
  636. TCExpression result =
  637. TypeCheckExp(cast<ExpressionPattern>(*p).Expression(), types, values);
  638. return {.pattern = global_arena->New<ExpressionPattern>(result.exp),
  639. .type = result.type,
  640. .types = result.types};
  641. }
  642. }
  643. }
  644. static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
  645. Ptr<const Statement> body, TypeEnv types, Env values,
  646. const Value*& ret_type, bool is_omitted_ret_type)
  647. -> std::pair<Ptr<const Pattern>, Ptr<const Statement>> {
  648. auto pat_res = TypeCheckPattern(pat, types, values, expected);
  649. auto res =
  650. TypeCheckStmt(body, pat_res.types, values, ret_type, is_omitted_ret_type);
  651. return std::make_pair(pat, res.stmt);
  652. }
  653. // The TypeCheckStmt function performs semantic analysis on a statement.
  654. // It returns a new version of the statement and a new type environment.
  655. //
  656. // The ret_type parameter is used for analyzing return statements.
  657. // It is the declared return type of the enclosing function definition.
  658. // If the return type is "auto", then the return type is inferred from
  659. // the first return statement.
  660. auto TypeCheckStmt(Ptr<const Statement> s, TypeEnv types, Env values,
  661. const Value*& ret_type, bool is_omitted_ret_type)
  662. -> TCStatement {
  663. switch (s->Tag()) {
  664. case Statement::Kind::Match: {
  665. const auto& match = cast<Match>(*s);
  666. auto res = TypeCheckExp(match.Exp(), types, values);
  667. auto res_type = res.type;
  668. auto new_clauses = global_arena->RawNew<
  669. std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>>();
  670. for (auto& clause : *match.Clauses()) {
  671. new_clauses->push_back(TypecheckCase(res_type, clause.first,
  672. clause.second, types, values,
  673. ret_type, is_omitted_ret_type));
  674. }
  675. auto new_s =
  676. global_arena->New<Match>(s->SourceLoc(), res.exp, new_clauses);
  677. return TCStatement(new_s, types);
  678. }
  679. case Statement::Kind::While: {
  680. const auto& while_stmt = cast<While>(*s);
  681. auto cnd_res = TypeCheckExp(while_stmt.Cond(), types, values);
  682. ExpectType(s->SourceLoc(), "condition of `while`",
  683. global_arena->RawNew<BoolType>(), cnd_res.type);
  684. auto body_res = TypeCheckStmt(while_stmt.Body(), types, values, ret_type,
  685. is_omitted_ret_type);
  686. auto new_s =
  687. global_arena->New<While>(s->SourceLoc(), cnd_res.exp, body_res.stmt);
  688. return TCStatement(new_s, types);
  689. }
  690. case Statement::Kind::Break:
  691. case Statement::Kind::Continue:
  692. return TCStatement(s, types);
  693. case Statement::Kind::Block: {
  694. const auto& block = cast<Block>(*s);
  695. if (block.Stmt()) {
  696. auto stmt_res = TypeCheckStmt(*block.Stmt(), types, values, ret_type,
  697. is_omitted_ret_type);
  698. return TCStatement(
  699. global_arena->New<Block>(s->SourceLoc(), stmt_res.stmt), types);
  700. } else {
  701. return TCStatement(s, types);
  702. }
  703. }
  704. case Statement::Kind::VariableDefinition: {
  705. const auto& var = cast<VariableDefinition>(*s);
  706. auto res = TypeCheckExp(var.Init(), types, values);
  707. const Value* rhs_ty = res.type;
  708. auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
  709. auto new_s = global_arena->New<VariableDefinition>(s->SourceLoc(),
  710. var.Pat(), res.exp);
  711. return TCStatement(new_s, lhs_res.types);
  712. }
  713. case Statement::Kind::Sequence: {
  714. const auto& seq = cast<Sequence>(*s);
  715. auto stmt_res = TypeCheckStmt(seq.Stmt(), types, values, ret_type,
  716. is_omitted_ret_type);
  717. auto checked_types = stmt_res.types;
  718. std::optional<Ptr<const Statement>> next_stmt;
  719. if (seq.Next()) {
  720. auto next_res = TypeCheckStmt(*seq.Next(), checked_types, values,
  721. ret_type, is_omitted_ret_type);
  722. next_stmt = next_res.stmt;
  723. checked_types = next_res.types;
  724. }
  725. return TCStatement(
  726. global_arena->New<Sequence>(s->SourceLoc(), stmt_res.stmt, next_stmt),
  727. checked_types);
  728. }
  729. case Statement::Kind::Assign: {
  730. const auto& assign = cast<Assign>(*s);
  731. auto rhs_res = TypeCheckExp(assign.Rhs(), types, values);
  732. auto rhs_t = rhs_res.type;
  733. auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
  734. auto lhs_t = lhs_res.type;
  735. ExpectType(s->SourceLoc(), "assign", lhs_t, rhs_t);
  736. auto new_s =
  737. global_arena->New<Assign>(s->SourceLoc(), lhs_res.exp, rhs_res.exp);
  738. return TCStatement(new_s, lhs_res.types);
  739. }
  740. case Statement::Kind::ExpressionStatement: {
  741. auto res =
  742. TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
  743. auto new_s =
  744. global_arena->New<ExpressionStatement>(s->SourceLoc(), res.exp);
  745. return TCStatement(new_s, types);
  746. }
  747. case Statement::Kind::If: {
  748. const auto& if_stmt = cast<If>(*s);
  749. auto cnd_res = TypeCheckExp(if_stmt.Cond(), types, values);
  750. ExpectType(s->SourceLoc(), "condition of `if`",
  751. global_arena->RawNew<BoolType>(), cnd_res.type);
  752. auto then_res = TypeCheckStmt(if_stmt.ThenStmt(), types, values, ret_type,
  753. is_omitted_ret_type);
  754. std::optional<Ptr<const Statement>> else_stmt;
  755. if (if_stmt.ElseStmt()) {
  756. auto else_res = TypeCheckStmt(*if_stmt.ElseStmt(), types, values,
  757. ret_type, is_omitted_ret_type);
  758. else_stmt = else_res.stmt;
  759. }
  760. auto new_s = global_arena->New<If>(s->SourceLoc(), cnd_res.exp,
  761. then_res.stmt, else_stmt);
  762. return TCStatement(new_s, types);
  763. }
  764. case Statement::Kind::Return: {
  765. const auto& ret = cast<Return>(*s);
  766. auto res = TypeCheckExp(ret.Exp(), types, values);
  767. if (ret_type->Tag() == Value::Kind::AutoType) {
  768. // The following infers the return type from the first 'return'
  769. // statement. This will get more difficult with subtyping, when we
  770. // should infer the least-upper bound of all the 'return' statements.
  771. ret_type = res.type;
  772. } else {
  773. ExpectType(s->SourceLoc(), "return", ret_type, res.type);
  774. }
  775. if (ret.IsOmittedExp() != is_omitted_ret_type) {
  776. FATAL_COMPILATION_ERROR(s->SourceLoc())
  777. << *s << " should" << (is_omitted_ret_type ? " not" : "")
  778. << " provide a return value, to match the function's signature.";
  779. }
  780. return TCStatement(global_arena->New<Return>(s->SourceLoc(), res.exp,
  781. ret.IsOmittedExp()),
  782. types);
  783. }
  784. case Statement::Kind::Continuation: {
  785. const auto& cont = cast<Continuation>(*s);
  786. TCStatement body_result = TypeCheckStmt(cont.Body(), types, values,
  787. ret_type, is_omitted_ret_type);
  788. auto new_continuation = global_arena->New<Continuation>(
  789. s->SourceLoc(), cont.ContinuationVariable(), body_result.stmt);
  790. types.Set(cont.ContinuationVariable(),
  791. global_arena->RawNew<ContinuationType>());
  792. return TCStatement(new_continuation, types);
  793. }
  794. case Statement::Kind::Run: {
  795. TCExpression argument_result =
  796. TypeCheckExp(cast<Run>(*s).Argument(), types, values);
  797. ExpectType(s->SourceLoc(), "argument of `run`",
  798. global_arena->RawNew<ContinuationType>(),
  799. argument_result.type);
  800. auto new_run =
  801. global_arena->New<Run>(s->SourceLoc(), argument_result.exp);
  802. return TCStatement(new_run, types);
  803. }
  804. case Statement::Kind::Await: {
  805. // nothing to do here
  806. return TCStatement(s, types);
  807. }
  808. } // switch
  809. }
  810. static auto CheckOrEnsureReturn(std::optional<Ptr<const Statement>> opt_stmt,
  811. bool omitted_ret_type, SourceLocation loc)
  812. -> Ptr<const Statement> {
  813. if (!opt_stmt) {
  814. if (omitted_ret_type) {
  815. return global_arena->New<Return>(loc);
  816. } else {
  817. FATAL_COMPILATION_ERROR(loc)
  818. << "control-flow reaches end of function that provides a `->` return "
  819. "type without reaching a return statement";
  820. }
  821. }
  822. Ptr<const Statement> stmt = *opt_stmt;
  823. switch (stmt->Tag()) {
  824. case Statement::Kind::Match: {
  825. const auto& match = cast<Match>(*stmt);
  826. auto new_clauses = global_arena->RawNew<
  827. std::list<std::pair<Ptr<const Pattern>, Ptr<const Statement>>>>();
  828. for (const auto& clause : *match.Clauses()) {
  829. auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
  830. stmt->SourceLoc());
  831. new_clauses->push_back(std::make_pair(clause.first, s));
  832. }
  833. return global_arena->New<Match>(stmt->SourceLoc(), match.Exp(),
  834. new_clauses);
  835. }
  836. case Statement::Kind::Block:
  837. return global_arena->New<Block>(
  838. stmt->SourceLoc(),
  839. CheckOrEnsureReturn(cast<Block>(*stmt).Stmt(), omitted_ret_type,
  840. stmt->SourceLoc()));
  841. case Statement::Kind::If: {
  842. const auto& if_stmt = cast<If>(*stmt);
  843. return global_arena->New<If>(
  844. stmt->SourceLoc(), if_stmt.Cond(),
  845. CheckOrEnsureReturn(if_stmt.ThenStmt(), omitted_ret_type,
  846. stmt->SourceLoc()),
  847. CheckOrEnsureReturn(if_stmt.ElseStmt(), omitted_ret_type,
  848. stmt->SourceLoc()));
  849. }
  850. case Statement::Kind::Return:
  851. return stmt;
  852. case Statement::Kind::Sequence: {
  853. const auto& seq = cast<Sequence>(*stmt);
  854. if (seq.Next()) {
  855. return global_arena->New<Sequence>(
  856. stmt->SourceLoc(), seq.Stmt(),
  857. CheckOrEnsureReturn(seq.Next(), omitted_ret_type,
  858. stmt->SourceLoc()));
  859. } else {
  860. return CheckOrEnsureReturn(seq.Stmt(), omitted_ret_type,
  861. stmt->SourceLoc());
  862. }
  863. }
  864. case Statement::Kind::Continuation:
  865. case Statement::Kind::Run:
  866. case Statement::Kind::Await:
  867. return stmt;
  868. case Statement::Kind::Assign:
  869. case Statement::Kind::ExpressionStatement:
  870. case Statement::Kind::While:
  871. case Statement::Kind::Break:
  872. case Statement::Kind::Continue:
  873. case Statement::Kind::VariableDefinition:
  874. if (omitted_ret_type) {
  875. return global_arena->New<Sequence>(stmt->SourceLoc(), stmt,
  876. global_arena->New<Return>(loc));
  877. } else {
  878. FATAL_COMPILATION_ERROR(stmt->SourceLoc())
  879. << "control-flow reaches end of function that provides a `->` "
  880. "return type without reaching a return statement";
  881. }
  882. }
  883. }
  884. // TODO: factor common parts of TypeCheckFunDef and TypeOfFunDef into
  885. // a function.
  886. // TODO: Add checking to function definitions to ensure that
  887. // all deduced type parameters will be deduced.
  888. static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
  889. Env values) -> Ptr<const FunctionDefinition> {
  890. // Bring the deduced parameters into scope
  891. for (const auto& deduced : f->deduced_parameters) {
  892. // auto t = InterpExp(values, deduced.type);
  893. types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
  894. Address a = state->heap.AllocateValue(*types.Get(deduced.name));
  895. values.Set(deduced.name, a);
  896. }
  897. // Type check the parameter pattern
  898. auto param_res = TypeCheckPattern(f->param_pattern, types, values, nullptr);
  899. // Evaluate the return type expression
  900. auto return_type = InterpPattern(values, f->return_type);
  901. if (f->name == "main") {
  902. ExpectType(f->source_location, "return type of `main`",
  903. global_arena->RawNew<IntType>(), return_type);
  904. // TODO: Check that main doesn't have any parameters.
  905. }
  906. std::optional<Ptr<const Statement>> body_stmt;
  907. if (f->body) {
  908. auto res = TypeCheckStmt(*f->body, param_res.types, values, return_type,
  909. f->is_omitted_return_type);
  910. body_stmt = res.stmt;
  911. }
  912. auto body = CheckOrEnsureReturn(body_stmt, f->is_omitted_return_type,
  913. f->source_location);
  914. return global_arena->New<FunctionDefinition>(
  915. f->source_location, f->name, f->deduced_parameters, f->param_pattern,
  916. global_arena->New<ExpressionPattern>(
  917. ReifyType(return_type, f->source_location)),
  918. /*is_omitted_return_type=*/false, body);
  919. }
  920. static auto TypeOfFunDef(TypeEnv types, Env values,
  921. const FunctionDefinition* fun_def) -> const Value* {
  922. // Bring the deduced parameters into scope
  923. for (const auto& deduced : fun_def->deduced_parameters) {
  924. // auto t = InterpExp(values, deduced.type);
  925. types.Set(deduced.name, global_arena->RawNew<VariableType>(deduced.name));
  926. Address a = state->heap.AllocateValue(*types.Get(deduced.name));
  927. values.Set(deduced.name, a);
  928. }
  929. // Type check the parameter pattern
  930. auto param_res =
  931. TypeCheckPattern(fun_def->param_pattern, types, values, nullptr);
  932. // Evaluate the return type expression
  933. auto ret = InterpPattern(values, fun_def->return_type);
  934. if (ret->Tag() == Value::Kind::AutoType) {
  935. auto f = TypeCheckFunDef(fun_def, types, values);
  936. ret = InterpPattern(values, f->return_type);
  937. }
  938. return global_arena->RawNew<FunctionType>(fun_def->deduced_parameters,
  939. param_res.type, ret);
  940. }
  941. static auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
  942. Env ct_top) -> const Value* {
  943. VarValues fields;
  944. VarValues methods;
  945. for (Ptr<const Member> m : sd->members) {
  946. switch (m->Tag()) {
  947. case Member::Kind::FieldMember: {
  948. Ptr<const BindingPattern> binding = cast<FieldMember>(*m).Binding();
  949. if (!binding->Name().has_value()) {
  950. FATAL_COMPILATION_ERROR(binding->SourceLoc())
  951. << "Struct members must have names";
  952. }
  953. const auto* binding_type =
  954. dyn_cast<ExpressionPattern>(binding->Type().Get());
  955. if (binding_type == nullptr) {
  956. FATAL_COMPILATION_ERROR(binding->SourceLoc())
  957. << "Struct members must have explicit types";
  958. }
  959. auto type = InterpExp(ct_top, binding_type->Expression());
  960. fields.push_back(std::make_pair(*binding->Name(), type));
  961. break;
  962. }
  963. }
  964. }
  965. return global_arena->RawNew<ClassType>(sd->name, std::move(fields),
  966. std::move(methods));
  967. }
  968. static auto GetName(const Declaration& d) -> const std::string& {
  969. switch (d.Tag()) {
  970. case Declaration::Kind::FunctionDeclaration:
  971. return cast<FunctionDeclaration>(d).Definition().name;
  972. case Declaration::Kind::ClassDeclaration:
  973. return cast<ClassDeclaration>(d).Definition().name;
  974. case Declaration::Kind::ChoiceDeclaration:
  975. return cast<ChoiceDeclaration>(d).Name();
  976. case Declaration::Kind::VariableDeclaration: {
  977. Ptr<const BindingPattern> binding =
  978. cast<VariableDeclaration>(d).Binding();
  979. if (!binding->Name().has_value()) {
  980. FATAL_COMPILATION_ERROR(binding->SourceLoc())
  981. << "Top-level variable declarations must have names";
  982. }
  983. return *binding->Name();
  984. }
  985. }
  986. }
  987. auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
  988. const Env& values) -> Ptr<const Declaration> {
  989. switch (d->Tag()) {
  990. case Declaration::Kind::FunctionDeclaration:
  991. return global_arena->New<FunctionDeclaration>(TypeCheckFunDef(
  992. &cast<FunctionDeclaration>(*d).Definition(), types, values));
  993. case Declaration::Kind::ClassDeclaration: {
  994. const ClassDefinition& class_def =
  995. cast<ClassDeclaration>(*d).Definition();
  996. std::list<Ptr<Member>> fields;
  997. for (Ptr<Member> m : class_def.members) {
  998. switch (m->Tag()) {
  999. case Member::Kind::FieldMember:
  1000. // TODO: Interpret the type expression and store the result.
  1001. fields.push_back(m);
  1002. break;
  1003. }
  1004. }
  1005. return global_arena->New<ClassDeclaration>(class_def.loc, class_def.name,
  1006. std::move(fields));
  1007. }
  1008. case Declaration::Kind::ChoiceDeclaration:
  1009. // TODO
  1010. return d;
  1011. case Declaration::Kind::VariableDeclaration: {
  1012. const auto& var = cast<VariableDeclaration>(*d);
  1013. // Signals a type error if the initializing expression does not have
  1014. // the declared type of the variable, otherwise returns this
  1015. // declaration with annotated types.
  1016. TCExpression type_checked_initializer =
  1017. TypeCheckExp(var.Initializer(), types, values);
  1018. const auto* binding_type =
  1019. dyn_cast<ExpressionPattern>(var.Binding()->Type().Get());
  1020. if (binding_type == nullptr) {
  1021. // TODO: consider adding support for `auto`
  1022. FATAL_COMPILATION_ERROR(var.SourceLoc())
  1023. << "Type of a top-level variable must be an expression.";
  1024. }
  1025. const Value* declared_type =
  1026. InterpExp(values, binding_type->Expression());
  1027. ExpectType(var.SourceLoc(), "initializer of variable", declared_type,
  1028. type_checked_initializer.type);
  1029. return d;
  1030. }
  1031. }
  1032. }
  1033. static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
  1034. switch (d.Tag()) {
  1035. case Declaration::Kind::FunctionDeclaration: {
  1036. const FunctionDefinition& func_def =
  1037. cast<FunctionDeclaration>(d).Definition();
  1038. auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
  1039. tops->types.Set(func_def.name, t);
  1040. InitEnv(d, &tops->values);
  1041. break;
  1042. }
  1043. case Declaration::Kind::ClassDeclaration: {
  1044. const ClassDefinition& class_def = cast<ClassDeclaration>(d).Definition();
  1045. auto st = TypeOfClassDef(&class_def, tops->types, tops->values);
  1046. Address a = state->heap.AllocateValue(st);
  1047. tops->values.Set(class_def.name, a); // Is this obsolete?
  1048. std::vector<TupleElement> field_types;
  1049. for (const auto& [field_name, field_value] :
  1050. cast<ClassType>(*st).Fields()) {
  1051. field_types.push_back({.name = field_name, .value = field_value});
  1052. }
  1053. auto fun_ty = global_arena->RawNew<FunctionType>(
  1054. std::vector<GenericBinding>(),
  1055. global_arena->RawNew<TupleValue>(std::move(field_types)), st);
  1056. tops->types.Set(class_def.name, fun_ty);
  1057. break;
  1058. }
  1059. case Declaration::Kind::ChoiceDeclaration: {
  1060. const auto& choice = cast<ChoiceDeclaration>(d);
  1061. VarValues alts;
  1062. for (const auto& [name, signature] : choice.Alternatives()) {
  1063. auto t = InterpExp(tops->values, signature);
  1064. alts.push_back(std::make_pair(name, t));
  1065. }
  1066. auto ct =
  1067. global_arena->RawNew<ChoiceType>(choice.Name(), std::move(alts));
  1068. Address a = state->heap.AllocateValue(ct);
  1069. tops->values.Set(choice.Name(), a); // Is this obsolete?
  1070. tops->types.Set(choice.Name(), ct);
  1071. break;
  1072. }
  1073. case Declaration::Kind::VariableDeclaration: {
  1074. const auto& var = cast<VariableDeclaration>(d);
  1075. // Associate the variable name with it's declared type in the
  1076. // compile-time symbol table.
  1077. Ptr<const Expression> type =
  1078. cast<ExpressionPattern>(*var.Binding()->Type()).Expression();
  1079. const Value* declared_type = InterpExp(tops->values, type);
  1080. tops->types.Set(*var.Binding()->Name(), declared_type);
  1081. break;
  1082. }
  1083. }
  1084. }
  1085. auto TopLevel(const std::list<Ptr<const Declaration>>& fs) -> TypeCheckContext {
  1086. TypeCheckContext tops;
  1087. bool found_main = false;
  1088. for (auto const& d : fs) {
  1089. if (GetName(*d) == "main") {
  1090. found_main = true;
  1091. }
  1092. TopLevel(*d, &tops);
  1093. }
  1094. if (found_main == false) {
  1095. FATAL_COMPILATION_ERROR_NO_LINE()
  1096. << "program must contain a function named `main`";
  1097. }
  1098. return tops;
  1099. }
  1100. } // namespace Carbon