typecheck.cpp 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "executable_semantics/interpreter/typecheck.h"
  5. #include <algorithm>
  6. #include <iterator>
  7. #include <map>
  8. #include <set>
  9. #include <vector>
  10. #include "common/ostream.h"
  11. #include "executable_semantics/ast/function_definition.h"
  12. #include "executable_semantics/common/error.h"
  13. #include "executable_semantics/common/tracing_flag.h"
  14. #include "executable_semantics/interpreter/interpreter.h"
  15. #include "executable_semantics/interpreter/value.h"
  16. #include "llvm/Support/Casting.h"
  17. using llvm::cast;
  18. using llvm::dyn_cast;
  19. namespace Carbon {
  20. void ExpectType(int line_num, const std::string& context, const Value* expected,
  21. const Value* actual) {
  22. if (!TypeEqual(expected, actual)) {
  23. FATAL_COMPILATION_ERROR(line_num) << "type error in " << context << "\n"
  24. << "expected: " << *expected << "\n"
  25. << "actual: " << *actual;
  26. }
  27. }
  28. void ExpectPointerType(int line_num, const std::string& context,
  29. const Value* actual) {
  30. if (actual->tag() != ValKind::PointerType) {
  31. FATAL_COMPILATION_ERROR(line_num) << "type error in " << context << "\n"
  32. << "expected a pointer type\n"
  33. << "actual: " << *actual;
  34. }
  35. }
  36. // Reify type to type expression.
  37. auto ReifyType(const Value* t, int line_num) -> const Expression* {
  38. switch (t->tag()) {
  39. case ValKind::IntType:
  40. return Expression::MakeIntTypeLiteral(0);
  41. case ValKind::BoolType:
  42. return Expression::MakeBoolTypeLiteral(0);
  43. case ValKind::TypeType:
  44. return Expression::MakeTypeTypeLiteral(0);
  45. case ValKind::ContinuationType:
  46. return Expression::MakeContinuationTypeLiteral(0);
  47. case ValKind::FunctionType:
  48. return Expression::MakeFunctionTypeLiteral(
  49. 0, ReifyType(t->GetFunctionType().param, line_num),
  50. ReifyType(t->GetFunctionType().ret, line_num));
  51. case ValKind::TupleValue: {
  52. std::vector<FieldInitializer> args;
  53. for (const TupleElement& field : t->GetTupleValue().elements) {
  54. args.push_back(
  55. FieldInitializer(field.name, ReifyType(field.value, line_num)));
  56. }
  57. return Expression::MakeTupleLiteral(0, args);
  58. }
  59. case ValKind::StructType:
  60. return Expression::MakeIdentifierExpression(0, t->GetStructType().name);
  61. case ValKind::ChoiceType:
  62. return Expression::MakeIdentifierExpression(0, t->GetChoiceType().name);
  63. case ValKind::PointerType:
  64. return Expression::MakePrimitiveOperatorExpression(
  65. 0, Operator::Ptr, {ReifyType(t->GetPointerType().type, line_num)});
  66. case ValKind::VariableType:
  67. return Expression::MakeIdentifierExpression(0, t->GetVariableType().name);
  68. default:
  69. llvm::errs() << line_num << ": expected a type, not " << *t << "\n";
  70. exit(-1);
  71. }
  72. }
  73. // Perform type argument deduction, matching the parameter type `param`
  74. // against the argument type `arg`. Whenever there is an VariableType
  75. // in the parameter type, it is deduced to be the corresponding type
  76. // inside the argument type.
  77. // The `deduced` parameter is an accumulator, that is, it holds the
  78. // results so-far.
  79. auto ArgumentDeduction(int line_num, TypeEnv deduced, const Value* param,
  80. const Value* arg) -> TypeEnv {
  81. switch (param->tag()) {
  82. case ValKind::VariableType: {
  83. std::optional<const Value*> d =
  84. deduced.Get(param->GetVariableType().name);
  85. if (!d) {
  86. deduced.Set(param->GetVariableType().name, arg);
  87. } else {
  88. ExpectType(line_num, "argument deduction", *d, arg);
  89. }
  90. return deduced;
  91. }
  92. case ValKind::TupleValue: {
  93. if (arg->tag() != ValKind::TupleValue) {
  94. ExpectType(line_num, "argument deduction", param, arg);
  95. }
  96. if (param->GetTupleValue().elements.size() !=
  97. arg->GetTupleValue().elements.size()) {
  98. ExpectType(line_num, "argument deduction", param, arg);
  99. }
  100. for (size_t i = 0; i < param->GetTupleValue().elements.size(); ++i) {
  101. if (param->GetTupleValue().elements[i].name !=
  102. arg->GetTupleValue().elements[i].name) {
  103. std::cerr << line_num << ": mismatch in tuple names, "
  104. << param->GetTupleValue().elements[i].name
  105. << " != " << arg->GetTupleValue().elements[i].name
  106. << std::endl;
  107. exit(-1);
  108. }
  109. deduced = ArgumentDeduction(line_num, deduced,
  110. param->GetTupleValue().elements[i].value,
  111. arg->GetTupleValue().elements[i].value);
  112. }
  113. return deduced;
  114. }
  115. case ValKind::FunctionType: {
  116. if (arg->tag() != ValKind::FunctionType) {
  117. ExpectType(line_num, "argument deduction", param, arg);
  118. }
  119. // TODO: handle situation when arg has deduced parameters.
  120. deduced =
  121. ArgumentDeduction(line_num, deduced, param->GetFunctionType().param,
  122. arg->GetFunctionType().param);
  123. deduced =
  124. ArgumentDeduction(line_num, deduced, param->GetFunctionType().ret,
  125. arg->GetFunctionType().ret);
  126. return deduced;
  127. }
  128. case ValKind::PointerType: {
  129. if (arg->tag() != ValKind::PointerType) {
  130. ExpectType(line_num, "argument deduction", param, arg);
  131. }
  132. return ArgumentDeduction(line_num, deduced, param->GetPointerType().type,
  133. arg->GetPointerType().type);
  134. }
  135. // Nothing to do in the case for `auto`.
  136. case ValKind::AutoType: {
  137. return deduced;
  138. }
  139. // For the following cases, we check for type equality.
  140. case ValKind::ContinuationType:
  141. case ValKind::StructType:
  142. case ValKind::ChoiceType:
  143. case ValKind::IntType:
  144. case ValKind::BoolType:
  145. case ValKind::TypeType: {
  146. ExpectType(line_num, "argument deduction", param, arg);
  147. return deduced;
  148. }
  149. // The rest of these cases should never happen.
  150. case ValKind::IntValue:
  151. case ValKind::BoolValue:
  152. case ValKind::FunctionValue:
  153. case ValKind::PointerValue:
  154. case ValKind::StructValue:
  155. case ValKind::AlternativeValue:
  156. case ValKind::BindingPlaceholderValue:
  157. case ValKind::AlternativeConstructorValue:
  158. case ValKind::ContinuationValue:
  159. llvm::errs() << line_num
  160. << ": internal error in ArgumentDeduction: expected type, "
  161. << "not value " << *param << "\n";
  162. exit(-1);
  163. }
  164. }
  165. auto Substitute(TypeEnv dict, const Value* type) -> const Value* {
  166. switch (type->tag()) {
  167. case ValKind::VariableType: {
  168. std::optional<const Value*> t = dict.Get(type->GetVariableType().name);
  169. if (!t) {
  170. return type;
  171. } else {
  172. return *t;
  173. }
  174. }
  175. case ValKind::TupleValue: {
  176. std::vector<TupleElement> elts;
  177. for (const auto& elt : type->GetTupleValue().elements) {
  178. auto t = Substitute(dict, elt.value);
  179. elts.push_back({.name = elt.name, .value = t});
  180. }
  181. return Value::MakeTupleValue(elts);
  182. }
  183. case ValKind::FunctionType: {
  184. auto param = Substitute(dict, type->GetFunctionType().param);
  185. auto ret = Substitute(dict, type->GetFunctionType().ret);
  186. return Value::MakeFunctionType({}, param, ret);
  187. }
  188. case ValKind::PointerType: {
  189. return Value::MakePointerType(
  190. Substitute(dict, type->GetPointerType().type));
  191. }
  192. case ValKind::AutoType:
  193. case ValKind::IntType:
  194. case ValKind::BoolType:
  195. case ValKind::TypeType:
  196. case ValKind::StructType:
  197. case ValKind::ChoiceType:
  198. case ValKind::ContinuationType:
  199. return type;
  200. // The rest of these cases should never happen.
  201. case ValKind::IntValue:
  202. case ValKind::BoolValue:
  203. case ValKind::FunctionValue:
  204. case ValKind::PointerValue:
  205. case ValKind::StructValue:
  206. case ValKind::AlternativeValue:
  207. case ValKind::BindingPlaceholderValue:
  208. case ValKind::AlternativeConstructorValue:
  209. case ValKind::ContinuationValue:
  210. llvm::errs() << "internal error in Substitute: expected type, "
  211. << "not value " << *type << "\n";
  212. exit(-1);
  213. }
  214. }
  215. // The TypeCheckExp function performs semantic analysis on an expression.
  216. // It returns a new version of the expression, its type, and an
  217. // updated environment which are bundled into a TCResult object.
  218. // The purpose of the updated environment is
  219. // to bring pattern variables into scope, for example, in a match case.
  220. // The new version of the expression may include more information,
  221. // for example, the type arguments deduced for the type parameters of a
  222. // generic.
  223. //
  224. // e is the expression to be analyzed.
  225. // types maps variable names to the type of their run-time value.
  226. // values maps variable names to their compile-time values. It is not
  227. // directly used in this function but is passed to InterExp.
  228. auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
  229. -> TCExpression {
  230. if (tracing_output) {
  231. llvm::outs() << "checking expression " << *e << "\n";
  232. }
  233. switch (e->tag()) {
  234. case ExpressionKind::IndexExpression: {
  235. auto res = TypeCheckExp(e->GetIndexExpression().aggregate, types, values);
  236. auto t = res.type;
  237. switch (t->tag()) {
  238. case ValKind::TupleValue: {
  239. auto i =
  240. InterpExp(values, e->GetIndexExpression().offset)->GetIntValue();
  241. std::string f = std::to_string(i);
  242. const Value* field_t = t->GetTupleValue().FindField(f);
  243. if (field_t == nullptr) {
  244. FATAL_COMPILATION_ERROR(e->line_num)
  245. << "field " << f << " is not in the tuple " << *t;
  246. }
  247. auto new_e = Expression::MakeIndexExpression(
  248. e->line_num, res.exp, Expression::MakeIntLiteral(e->line_num, i));
  249. return TCExpression(new_e, field_t, res.types);
  250. }
  251. default:
  252. FATAL_COMPILATION_ERROR(e->line_num) << "expected a tuple";
  253. }
  254. }
  255. case ExpressionKind::TupleLiteral: {
  256. std::vector<FieldInitializer> new_args;
  257. std::vector<TupleElement> arg_types;
  258. auto new_types = types;
  259. int i = 0;
  260. for (auto arg = e->GetTupleLiteral().fields.begin();
  261. arg != e->GetTupleLiteral().fields.end(); ++arg, ++i) {
  262. auto arg_res = TypeCheckExp(arg->expression, new_types, values);
  263. new_types = arg_res.types;
  264. new_args.push_back(FieldInitializer(arg->name, arg_res.exp));
  265. arg_types.push_back({.name = arg->name, .value = arg_res.type});
  266. }
  267. auto tuple_e = Expression::MakeTupleLiteral(e->line_num, new_args);
  268. auto tuple_t = Value::MakeTupleValue(std::move(arg_types));
  269. return TCExpression(tuple_e, tuple_t, new_types);
  270. }
  271. case ExpressionKind::FieldAccessExpression: {
  272. auto res =
  273. TypeCheckExp(e->GetFieldAccessExpression().aggregate, types, values);
  274. auto t = res.type;
  275. switch (t->tag()) {
  276. case ValKind::StructType:
  277. // Search for a field
  278. for (auto& field : t->GetStructType().fields) {
  279. if (e->GetFieldAccessExpression().field == field.first) {
  280. const Expression* new_e = Expression::MakeFieldAccessExpression(
  281. e->line_num, res.exp, e->GetFieldAccessExpression().field);
  282. return TCExpression(new_e, field.second, res.types);
  283. }
  284. }
  285. // Search for a method
  286. for (auto& method : t->GetStructType().methods) {
  287. if (e->GetFieldAccessExpression().field == method.first) {
  288. const Expression* new_e = Expression::MakeFieldAccessExpression(
  289. e->line_num, res.exp, e->GetFieldAccessExpression().field);
  290. return TCExpression(new_e, method.second, res.types);
  291. }
  292. }
  293. FATAL_COMPILATION_ERROR(e->line_num)
  294. << "struct " << t->GetStructType().name
  295. << " does not have a field named "
  296. << e->GetFieldAccessExpression().field;
  297. case ValKind::TupleValue:
  298. for (const TupleElement& field : t->GetTupleValue().elements) {
  299. if (e->GetFieldAccessExpression().field == field.name) {
  300. auto new_e = Expression::MakeFieldAccessExpression(
  301. e->line_num, res.exp, e->GetFieldAccessExpression().field);
  302. return TCExpression(new_e, field.value, res.types);
  303. }
  304. }
  305. FATAL_COMPILATION_ERROR(e->line_num)
  306. << "struct " << t->GetStructType().name
  307. << " does not have a field named "
  308. << e->GetFieldAccessExpression().field;
  309. case ValKind::ChoiceType:
  310. for (auto vt = t->GetChoiceType().alternatives.begin();
  311. vt != t->GetChoiceType().alternatives.end(); ++vt) {
  312. if (e->GetFieldAccessExpression().field == vt->first) {
  313. const Expression* new_e = Expression::MakeFieldAccessExpression(
  314. e->line_num, res.exp, e->GetFieldAccessExpression().field);
  315. auto fun_ty = Value::MakeFunctionType({}, vt->second, t);
  316. return TCExpression(new_e, fun_ty, res.types);
  317. }
  318. }
  319. FATAL_COMPILATION_ERROR(e->line_num)
  320. << "struct " << t->GetStructType().name
  321. << " does not have a field named "
  322. << e->GetFieldAccessExpression().field;
  323. default:
  324. FATAL_COMPILATION_ERROR(e->line_num)
  325. << "field access, expected a struct\n"
  326. << *e;
  327. }
  328. }
  329. case ExpressionKind::IdentifierExpression: {
  330. std::optional<const Value*> type =
  331. types.Get(e->GetIdentifierExpression().name);
  332. if (type) {
  333. return TCExpression(e, *type, types);
  334. } else {
  335. FATAL_COMPILATION_ERROR(e->line_num)
  336. << "could not find `" << e->GetIdentifierExpression().name << "`";
  337. }
  338. }
  339. case ExpressionKind::IntLiteral:
  340. return TCExpression(e, Value::MakeIntType(), types);
  341. case ExpressionKind::BoolLiteral:
  342. return TCExpression(e, Value::MakeBoolType(), types);
  343. case ExpressionKind::PrimitiveOperatorExpression: {
  344. std::vector<const Expression*> es;
  345. std::vector<const Value*> ts;
  346. auto new_types = types;
  347. for (const Expression* argument :
  348. e->GetPrimitiveOperatorExpression().arguments) {
  349. auto res = TypeCheckExp(argument, types, values);
  350. new_types = res.types;
  351. es.push_back(res.exp);
  352. ts.push_back(res.type);
  353. }
  354. auto new_e = Expression::MakePrimitiveOperatorExpression(
  355. e->line_num, e->GetPrimitiveOperatorExpression().op, es);
  356. switch (e->GetPrimitiveOperatorExpression().op) {
  357. case Operator::Neg:
  358. ExpectType(e->line_num, "negation", Value::MakeIntType(), ts[0]);
  359. return TCExpression(new_e, Value::MakeIntType(), new_types);
  360. case Operator::Add:
  361. ExpectType(e->line_num, "addition(1)", Value::MakeIntType(), ts[0]);
  362. ExpectType(e->line_num, "addition(2)", Value::MakeIntType(), ts[1]);
  363. return TCExpression(new_e, Value::MakeIntType(), new_types);
  364. case Operator::Sub:
  365. ExpectType(e->line_num, "subtraction(1)", Value::MakeIntType(),
  366. ts[0]);
  367. ExpectType(e->line_num, "subtraction(2)", Value::MakeIntType(),
  368. ts[1]);
  369. return TCExpression(new_e, Value::MakeIntType(), new_types);
  370. case Operator::Mul:
  371. ExpectType(e->line_num, "multiplication(1)", Value::MakeIntType(),
  372. ts[0]);
  373. ExpectType(e->line_num, "multiplication(2)", Value::MakeIntType(),
  374. ts[1]);
  375. return TCExpression(new_e, Value::MakeIntType(), new_types);
  376. case Operator::And:
  377. ExpectType(e->line_num, "&&(1)", Value::MakeBoolType(), ts[0]);
  378. ExpectType(e->line_num, "&&(2)", Value::MakeBoolType(), ts[1]);
  379. return TCExpression(new_e, Value::MakeBoolType(), new_types);
  380. case Operator::Or:
  381. ExpectType(e->line_num, "||(1)", Value::MakeBoolType(), ts[0]);
  382. ExpectType(e->line_num, "||(2)", Value::MakeBoolType(), ts[1]);
  383. return TCExpression(new_e, Value::MakeBoolType(), new_types);
  384. case Operator::Not:
  385. ExpectType(e->line_num, "!", Value::MakeBoolType(), ts[0]);
  386. return TCExpression(new_e, Value::MakeBoolType(), new_types);
  387. case Operator::Eq:
  388. ExpectType(e->line_num, "==", ts[0], ts[1]);
  389. return TCExpression(new_e, Value::MakeBoolType(), new_types);
  390. case Operator::Deref:
  391. ExpectPointerType(e->line_num, "*", ts[0]);
  392. return TCExpression(new_e, ts[0]->GetPointerType().type, new_types);
  393. case Operator::Ptr:
  394. ExpectType(e->line_num, "*", Value::MakeTypeType(), ts[0]);
  395. return TCExpression(new_e, Value::MakeTypeType(), new_types);
  396. }
  397. break;
  398. }
  399. case ExpressionKind::CallExpression: {
  400. auto fun_res =
  401. TypeCheckExp(e->GetCallExpression().function, types, values);
  402. switch (fun_res.type->tag()) {
  403. case ValKind::FunctionType: {
  404. auto fun_t = fun_res.type;
  405. auto arg_res = TypeCheckExp(e->GetCallExpression().argument,
  406. fun_res.types, values);
  407. auto parameter_type = fun_t->GetFunctionType().param;
  408. auto return_type = fun_t->GetFunctionType().ret;
  409. if (fun_t->GetFunctionType().deduced.size() > 0) {
  410. auto deduced_args = ArgumentDeduction(e->line_num, TypeEnv(),
  411. parameter_type, arg_res.type);
  412. for (auto& deduced_param : fun_t->GetFunctionType().deduced) {
  413. // TODO: change the following to a CHECK once the real checking
  414. // has been added to the type checking of function signatures.
  415. if (!deduced_args.Get(deduced_param.name)) {
  416. std::cerr << e->line_num
  417. << ": error, could not deduce type argument for type "
  418. "parameter "
  419. << deduced_param.name << std::endl;
  420. exit(-1);
  421. }
  422. }
  423. parameter_type = Substitute(deduced_args, parameter_type);
  424. return_type = Substitute(deduced_args, return_type);
  425. } else {
  426. ExpectType(e->line_num, "call", parameter_type, arg_res.type);
  427. }
  428. auto new_e = Expression::MakeCallExpression(e->line_num, fun_res.exp,
  429. arg_res.exp);
  430. return TCExpression(new_e, return_type, arg_res.types);
  431. }
  432. default: {
  433. FATAL_COMPILATION_ERROR(e->line_num)
  434. << "in call, expected a function\n"
  435. << *e;
  436. }
  437. }
  438. break;
  439. }
  440. case ExpressionKind::FunctionTypeLiteral: {
  441. auto pt = InterpExp(values, e->GetFunctionTypeLiteral().parameter);
  442. auto rt = InterpExp(values, e->GetFunctionTypeLiteral().return_type);
  443. auto new_e = Expression::MakeFunctionTypeLiteral(
  444. e->line_num, ReifyType(pt, e->line_num), ReifyType(rt, e->line_num));
  445. return TCExpression(new_e, Value::MakeTypeType(), types);
  446. }
  447. case ExpressionKind::IntTypeLiteral:
  448. return TCExpression(e, Value::MakeTypeType(), types);
  449. case ExpressionKind::BoolTypeLiteral:
  450. return TCExpression(e, Value::MakeTypeType(), types);
  451. case ExpressionKind::TypeTypeLiteral:
  452. return TCExpression(e, Value::MakeTypeType(), types);
  453. case ExpressionKind::ContinuationTypeLiteral:
  454. return TCExpression(e, Value::MakeTypeType(), types);
  455. }
  456. }
  457. // Equivalent to TypeCheckExp, but operates on Patterns instead of Expressions.
  458. // `expected` is the type that this pattern is expected to have, if the
  459. // surrounding context gives us that information. Otherwise, it is null.
  460. auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
  461. const Value* expected) -> TCPattern {
  462. if (tracing_output) {
  463. llvm::outs() << "checking pattern, ";
  464. if (expected) {
  465. llvm::outs() << "expecting " << *expected;
  466. }
  467. llvm::outs() << ", " << *p << "\n";
  468. }
  469. switch (p->Tag()) {
  470. case Pattern::Kind::AutoPattern: {
  471. return {.pattern = p, .type = Value::MakeTypeType(), .types = types};
  472. }
  473. case Pattern::Kind::BindingPattern: {
  474. const auto& binding = cast<BindingPattern>(*p);
  475. const Value* type;
  476. switch (binding.Type()->Tag()) {
  477. case Pattern::Kind::AutoPattern: {
  478. if (expected == nullptr) {
  479. FATAL_COMPILATION_ERROR(binding.LineNumber())
  480. << "auto not allowed here";
  481. } else {
  482. type = expected;
  483. }
  484. break;
  485. }
  486. case Pattern::Kind::ExpressionPattern: {
  487. type = InterpExp(
  488. values, cast<ExpressionPattern>(binding.Type())->Expression());
  489. CHECK(type->tag() != ValKind::AutoType);
  490. if (expected != nullptr) {
  491. ExpectType(binding.LineNumber(), "pattern variable", type,
  492. expected);
  493. }
  494. break;
  495. }
  496. case Pattern::Kind::TuplePattern:
  497. case Pattern::Kind::BindingPattern:
  498. case Pattern::Kind::AlternativePattern:
  499. FATAL_COMPILATION_ERROR(binding.LineNumber())
  500. << "Unsupported type pattern";
  501. }
  502. auto new_p = new BindingPattern(
  503. binding.LineNumber(), binding.Name(),
  504. new ExpressionPattern(ReifyType(type, binding.LineNumber())));
  505. if (binding.Name().has_value()) {
  506. types.Set(*binding.Name(), type);
  507. }
  508. return {.pattern = new_p, .type = type, .types = types};
  509. }
  510. case Pattern::Kind::TuplePattern: {
  511. const auto& tuple = cast<TuplePattern>(*p);
  512. std::vector<TuplePattern::Field> new_fields;
  513. std::vector<TupleElement> field_types;
  514. auto new_types = types;
  515. if (expected && expected->tag() != ValKind::TupleValue) {
  516. FATAL_COMPILATION_ERROR(p->LineNumber()) << "didn't expect a tuple";
  517. }
  518. if (expected &&
  519. tuple.Fields().size() != expected->GetTupleValue().elements.size()) {
  520. FATAL_COMPILATION_ERROR(tuple.LineNumber())
  521. << "tuples of different length";
  522. }
  523. for (size_t i = 0; i < tuple.Fields().size(); ++i) {
  524. const TuplePattern::Field& field = tuple.Fields()[i];
  525. const Value* expected_field_type = nullptr;
  526. if (expected != nullptr) {
  527. const TupleElement& expected_element =
  528. expected->GetTupleValue().elements[i];
  529. if (expected_element.name != field.name) {
  530. FATAL_COMPILATION_ERROR(tuple.LineNumber())
  531. << "field names do not match, expected "
  532. << expected_element.name << " but got " << field.name;
  533. }
  534. expected_field_type = expected_element.value;
  535. }
  536. auto field_result = TypeCheckPattern(field.pattern, new_types, values,
  537. expected_field_type);
  538. new_types = field_result.types;
  539. new_fields.push_back(
  540. TuplePattern::Field(field.name, field_result.pattern));
  541. field_types.push_back({.name = field.name, .value = field_result.type});
  542. }
  543. auto new_tuple = new TuplePattern(tuple.LineNumber(), new_fields);
  544. auto tuple_t = Value::MakeTupleValue(std::move(field_types));
  545. return {.pattern = new_tuple, .type = tuple_t, .types = new_types};
  546. }
  547. case Pattern::Kind::AlternativePattern: {
  548. const auto& alternative = cast<AlternativePattern>(*p);
  549. const Value* choice_type = InterpExp(values, alternative.ChoiceType());
  550. if (choice_type->tag() != ValKind::ChoiceType) {
  551. FATAL_COMPILATION_ERROR(alternative.LineNumber())
  552. << "alternative pattern does not name a choice type.";
  553. }
  554. if (expected != nullptr) {
  555. ExpectType(alternative.LineNumber(), "alternative pattern", expected,
  556. choice_type);
  557. }
  558. const Value* parameter_types =
  559. FindInVarValues(alternative.AlternativeName(),
  560. choice_type->GetChoiceType().alternatives);
  561. if (parameter_types == nullptr) {
  562. FATAL_COMPILATION_ERROR(alternative.LineNumber())
  563. << "'" << alternative.AlternativeName()
  564. << "' is not an alternative of " << choice_type;
  565. }
  566. TCPattern arg_results = TypeCheckPattern(alternative.Arguments(), types,
  567. values, parameter_types);
  568. return {.pattern = new AlternativePattern(
  569. alternative.LineNumber(),
  570. ReifyType(choice_type, alternative.LineNumber()),
  571. alternative.AlternativeName(),
  572. cast<TuplePattern>(arg_results.pattern)),
  573. .type = choice_type,
  574. .types = arg_results.types};
  575. }
  576. case Pattern::Kind::ExpressionPattern: {
  577. TCExpression result =
  578. TypeCheckExp(cast<ExpressionPattern>(p)->Expression(), types, values);
  579. return {.pattern = new ExpressionPattern(result.exp),
  580. .type = result.type,
  581. .types = result.types};
  582. }
  583. }
  584. }
  585. auto TypecheckCase(const Value* expected, const Pattern* pat,
  586. const Statement* body, TypeEnv types, Env values,
  587. const Value*& ret_type)
  588. -> std::pair<const Pattern*, const Statement*> {
  589. auto pat_res = TypeCheckPattern(pat, types, values, expected);
  590. auto res = TypeCheckStmt(body, pat_res.types, values, ret_type);
  591. return std::make_pair(pat, res.stmt);
  592. }
  593. // The TypeCheckStmt function performs semantic analysis on a statement.
  594. // It returns a new version of the statement and a new type environment.
  595. //
  596. // The ret_type parameter is used for analyzing return statements.
  597. // It is the declared return type of the enclosing function definition.
  598. // If the return type is "auto", then the return type is inferred from
  599. // the first return statement.
  600. auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
  601. const Value*& ret_type) -> TCStatement {
  602. if (!s) {
  603. return TCStatement(s, types);
  604. }
  605. switch (s->tag()) {
  606. case StatementKind::Match: {
  607. auto res = TypeCheckExp(s->GetMatch().exp, types, values);
  608. auto res_type = res.type;
  609. auto new_clauses =
  610. new std::list<std::pair<const Pattern*, const Statement*>>();
  611. for (auto& clause : *s->GetMatch().clauses) {
  612. new_clauses->push_back(TypecheckCase(
  613. res_type, clause.first, clause.second, types, values, ret_type));
  614. }
  615. const Statement* new_s =
  616. Statement::MakeMatch(s->line_num, res.exp, new_clauses);
  617. return TCStatement(new_s, types);
  618. }
  619. case StatementKind::While: {
  620. auto cnd_res = TypeCheckExp(s->GetWhile().cond, types, values);
  621. ExpectType(s->line_num, "condition of `while`", Value::MakeBoolType(),
  622. cnd_res.type);
  623. auto body_res =
  624. TypeCheckStmt(s->GetWhile().body, types, values, ret_type);
  625. auto new_s =
  626. Statement::MakeWhile(s->line_num, cnd_res.exp, body_res.stmt);
  627. return TCStatement(new_s, types);
  628. }
  629. case StatementKind::Break:
  630. case StatementKind::Continue:
  631. return TCStatement(s, types);
  632. case StatementKind::Block: {
  633. auto stmt_res =
  634. TypeCheckStmt(s->GetBlock().stmt, types, values, ret_type);
  635. return TCStatement(Statement::MakeBlock(s->line_num, stmt_res.stmt),
  636. types);
  637. }
  638. case StatementKind::VariableDefinition: {
  639. auto res = TypeCheckExp(s->GetVariableDefinition().init, types, values);
  640. const Value* rhs_ty = res.type;
  641. auto lhs_res = TypeCheckPattern(s->GetVariableDefinition().pat, types,
  642. values, rhs_ty);
  643. const Statement* new_s = Statement::MakeVariableDefinition(
  644. s->line_num, s->GetVariableDefinition().pat, res.exp);
  645. return TCStatement(new_s, lhs_res.types);
  646. }
  647. case StatementKind::Sequence: {
  648. auto stmt_res =
  649. TypeCheckStmt(s->GetSequence().stmt, types, values, ret_type);
  650. auto types2 = stmt_res.types;
  651. auto next_res =
  652. TypeCheckStmt(s->GetSequence().next, types2, values, ret_type);
  653. auto types3 = next_res.types;
  654. return TCStatement(
  655. Statement::MakeSequence(s->line_num, stmt_res.stmt, next_res.stmt),
  656. types3);
  657. }
  658. case StatementKind::Assign: {
  659. auto rhs_res = TypeCheckExp(s->GetAssign().rhs, types, values);
  660. auto rhs_t = rhs_res.type;
  661. auto lhs_res = TypeCheckExp(s->GetAssign().lhs, types, values);
  662. auto lhs_t = lhs_res.type;
  663. ExpectType(s->line_num, "assign", lhs_t, rhs_t);
  664. auto new_s = Statement::MakeAssign(s->line_num, lhs_res.exp, rhs_res.exp);
  665. return TCStatement(new_s, lhs_res.types);
  666. }
  667. case StatementKind::ExpressionStatement: {
  668. auto res = TypeCheckExp(s->GetExpressionStatement().exp, types, values);
  669. auto new_s = Statement::MakeExpressionStatement(s->line_num, res.exp);
  670. return TCStatement(new_s, types);
  671. }
  672. case StatementKind::If: {
  673. auto cnd_res = TypeCheckExp(s->GetIf().cond, types, values);
  674. ExpectType(s->line_num, "condition of `if`", Value::MakeBoolType(),
  675. cnd_res.type);
  676. auto thn_res =
  677. TypeCheckStmt(s->GetIf().then_stmt, types, values, ret_type);
  678. auto els_res =
  679. TypeCheckStmt(s->GetIf().else_stmt, types, values, ret_type);
  680. auto new_s = Statement::MakeIf(s->line_num, cnd_res.exp, thn_res.stmt,
  681. els_res.stmt);
  682. return TCStatement(new_s, types);
  683. }
  684. case StatementKind::Return: {
  685. auto res = TypeCheckExp(s->GetReturn().exp, types, values);
  686. if (ret_type->tag() == ValKind::AutoType) {
  687. // The following infers the return type from the first 'return'
  688. // statement. This will get more difficult with subtyping, when we
  689. // should infer the least-upper bound of all the 'return' statements.
  690. ret_type = res.type;
  691. } else {
  692. ExpectType(s->line_num, "return", ret_type, res.type);
  693. }
  694. return TCStatement(Statement::MakeReturn(s->line_num, res.exp), types);
  695. }
  696. case StatementKind::Continuation: {
  697. TCStatement body_result =
  698. TypeCheckStmt(s->GetContinuation().body, types, values, ret_type);
  699. const Statement* new_continuation = Statement::MakeContinuation(
  700. s->line_num, s->GetContinuation().continuation_variable,
  701. body_result.stmt);
  702. types.Set(s->GetContinuation().continuation_variable,
  703. Value::MakeContinuationType());
  704. return TCStatement(new_continuation, types);
  705. }
  706. case StatementKind::Run: {
  707. TCExpression argument_result =
  708. TypeCheckExp(s->GetRun().argument, types, values);
  709. ExpectType(s->line_num, "argument of `run`",
  710. Value::MakeContinuationType(), argument_result.type);
  711. const Statement* new_run =
  712. Statement::MakeRun(s->line_num, argument_result.exp);
  713. return TCStatement(new_run, types);
  714. }
  715. case StatementKind::Await: {
  716. // nothing to do here
  717. return TCStatement(s, types);
  718. }
  719. } // switch
  720. }
  721. auto CheckOrEnsureReturn(const Statement* stmt, bool void_return, int line_num)
  722. -> const Statement* {
  723. if (!stmt) {
  724. if (void_return) {
  725. return Statement::MakeReturn(line_num,
  726. Expression::MakeTupleLiteral(line_num, {}));
  727. } else {
  728. FATAL_COMPILATION_ERROR(line_num)
  729. << "control-flow reaches end of non-void function without a return";
  730. }
  731. }
  732. switch (stmt->tag()) {
  733. case StatementKind::Match: {
  734. auto new_clauses =
  735. new std::list<std::pair<const Pattern*, const Statement*>>();
  736. for (auto i = stmt->GetMatch().clauses->begin();
  737. i != stmt->GetMatch().clauses->end(); ++i) {
  738. auto s = CheckOrEnsureReturn(i->second, void_return, stmt->line_num);
  739. new_clauses->push_back(std::make_pair(i->first, s));
  740. }
  741. return Statement::MakeMatch(stmt->line_num, stmt->GetMatch().exp,
  742. new_clauses);
  743. }
  744. case StatementKind::Block:
  745. return Statement::MakeBlock(
  746. stmt->line_num, CheckOrEnsureReturn(stmt->GetBlock().stmt,
  747. void_return, stmt->line_num));
  748. case StatementKind::If:
  749. return Statement::MakeIf(
  750. stmt->line_num, stmt->GetIf().cond,
  751. CheckOrEnsureReturn(stmt->GetIf().then_stmt, void_return,
  752. stmt->line_num),
  753. CheckOrEnsureReturn(stmt->GetIf().else_stmt, void_return,
  754. stmt->line_num));
  755. case StatementKind::Return:
  756. return stmt;
  757. case StatementKind::Sequence:
  758. if (stmt->GetSequence().next) {
  759. return Statement::MakeSequence(
  760. stmt->line_num, stmt->GetSequence().stmt,
  761. CheckOrEnsureReturn(stmt->GetSequence().next, void_return,
  762. stmt->line_num));
  763. } else {
  764. return CheckOrEnsureReturn(stmt->GetSequence().stmt, void_return,
  765. stmt->line_num);
  766. }
  767. case StatementKind::Continuation:
  768. case StatementKind::Run:
  769. case StatementKind::Await:
  770. return stmt;
  771. case StatementKind::Assign:
  772. case StatementKind::ExpressionStatement:
  773. case StatementKind::While:
  774. case StatementKind::Break:
  775. case StatementKind::Continue:
  776. case StatementKind::VariableDefinition:
  777. if (void_return) {
  778. return Statement::MakeSequence(
  779. stmt->line_num, stmt,
  780. Statement::MakeReturn(stmt->line_num, Expression::MakeTupleLiteral(
  781. stmt->line_num, {})));
  782. } else {
  783. FATAL_COMPILATION_ERROR(stmt->line_num)
  784. << "control-flow reaches end of non-void function without a return";
  785. }
  786. }
  787. }
  788. // TODO: factor common parts of TypeCheckFunDef and TypeOfFunDef into
  789. // a function.
  790. // TODO: Add checking to function definitions to ensure that
  791. // all deduced type parameters will be deduced.
  792. auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types, Env values)
  793. -> struct FunctionDefinition* {
  794. // Bring the deduced parameters into scope
  795. for (const auto& deduced : f->deduced_parameters) {
  796. // auto t = InterpExp(values, deduced.type);
  797. Address a =
  798. state->heap.AllocateValue(Value::MakeVariableType(deduced.name));
  799. values.Set(deduced.name, a);
  800. }
  801. // Type check the parameter pattern
  802. auto param_res = TypeCheckPattern(f->param_pattern, types, values, nullptr);
  803. // Evaluate the return type expression
  804. auto return_type = InterpPattern(values, f->return_type);
  805. if (f->name == "main") {
  806. ExpectType(f->line_num, "return type of `main`", Value::MakeIntType(),
  807. return_type);
  808. // TODO: Check that main doesn't have any parameters.
  809. }
  810. auto res = TypeCheckStmt(f->body, param_res.types, values, return_type);
  811. bool void_return = TypeEqual(return_type, Value::MakeUnitTypeVal());
  812. auto body = CheckOrEnsureReturn(res.stmt, void_return, f->line_num);
  813. return new FunctionDefinition(
  814. f->line_num, f->name, f->deduced_parameters, f->param_pattern,
  815. new ExpressionPattern(ReifyType(return_type, f->line_num)), body);
  816. }
  817. auto TypeOfFunDef(TypeEnv types, Env values, const FunctionDefinition* fun_def)
  818. -> const Value* {
  819. // Bring the deduced parameters into scope
  820. for (const auto& deduced : fun_def->deduced_parameters) {
  821. // auto t = InterpExp(values, deduced.type);
  822. Address a =
  823. state->heap.AllocateValue(Value::MakeVariableType(deduced.name));
  824. values.Set(deduced.name, a);
  825. }
  826. // Type check the parameter pattern
  827. auto param_res =
  828. TypeCheckPattern(fun_def->param_pattern, types, values, nullptr);
  829. // Evaluate the return type expression
  830. auto ret = InterpPattern(values, fun_def->return_type);
  831. if (ret->tag() == ValKind::AutoType) {
  832. auto f = TypeCheckFunDef(fun_def, types, values);
  833. ret = InterpPattern(values, f->return_type);
  834. }
  835. return Value::MakeFunctionType(fun_def->deduced_parameters, param_res.type,
  836. ret);
  837. }
  838. auto TypeOfStructDef(const StructDefinition* sd, TypeEnv /*types*/, Env ct_top)
  839. -> const Value* {
  840. VarValues fields;
  841. VarValues methods;
  842. for (const Member* m : sd->members) {
  843. switch (m->tag()) {
  844. case MemberKind::FieldMember: {
  845. const BindingPattern* binding = m->GetFieldMember().binding;
  846. if (!binding->Name().has_value()) {
  847. FATAL_COMPILATION_ERROR(binding->LineNumber())
  848. << "Struct members must have names";
  849. }
  850. const Expression* type_expression =
  851. dyn_cast<ExpressionPattern>(binding->Type())->Expression();
  852. if (type_expression == nullptr) {
  853. FATAL_COMPILATION_ERROR(binding->LineNumber())
  854. << "Struct members must have explicit types";
  855. }
  856. auto type = InterpExp(ct_top, type_expression);
  857. fields.push_back(std::make_pair(*binding->Name(), type));
  858. break;
  859. }
  860. }
  861. }
  862. return Value::MakeStructType(sd->name, std::move(fields), std::move(methods));
  863. }
  864. static auto GetName(const Declaration& d) -> const std::string& {
  865. switch (d.tag()) {
  866. case DeclarationKind::FunctionDeclaration:
  867. return d.GetFunctionDeclaration().definition.name;
  868. case DeclarationKind::StructDeclaration:
  869. return d.GetStructDeclaration().definition.name;
  870. case DeclarationKind::ChoiceDeclaration:
  871. return d.GetChoiceDeclaration().name;
  872. case DeclarationKind::VariableDeclaration: {
  873. const BindingPattern* binding = d.GetVariableDeclaration().binding;
  874. if (!binding->Name().has_value()) {
  875. FATAL_COMPILATION_ERROR(binding->LineNumber())
  876. << "Top-level variable declarations must have names";
  877. }
  878. return *binding->Name();
  879. }
  880. }
  881. }
  882. auto MakeTypeChecked(const Declaration& d, const TypeEnv& types,
  883. const Env& values) -> Declaration {
  884. switch (d.tag()) {
  885. case DeclarationKind::FunctionDeclaration:
  886. return Declaration::MakeFunctionDeclaration(*TypeCheckFunDef(
  887. &d.GetFunctionDeclaration().definition, types, values));
  888. case DeclarationKind::StructDeclaration: {
  889. const StructDefinition& struct_def = d.GetStructDeclaration().definition;
  890. std::list<Member*> fields;
  891. for (Member* m : struct_def.members) {
  892. switch (m->tag()) {
  893. case MemberKind::FieldMember:
  894. // TODO: Interpret the type expression and store the result.
  895. fields.push_back(m);
  896. break;
  897. }
  898. }
  899. return Declaration::MakeStructDeclaration(
  900. struct_def.line_num, struct_def.name, std::move(fields));
  901. }
  902. case DeclarationKind::ChoiceDeclaration:
  903. // TODO
  904. return d;
  905. case DeclarationKind::VariableDeclaration: {
  906. const auto& var = d.GetVariableDeclaration();
  907. // Signals a type error if the initializing expression does not have
  908. // the declared type of the variable, otherwise returns this
  909. // declaration with annotated types.
  910. TCExpression type_checked_initializer =
  911. TypeCheckExp(var.initializer, types, values);
  912. const Expression* type =
  913. dyn_cast<ExpressionPattern>(var.binding->Type())->Expression();
  914. if (type == nullptr) {
  915. // TODO: consider adding support for `auto`
  916. FATAL_COMPILATION_ERROR(var.source_location)
  917. << "Type of a top-level variable must be an expression.";
  918. }
  919. const Value* declared_type = InterpExp(values, type);
  920. ExpectType(var.source_location, "initializer of variable", declared_type,
  921. type_checked_initializer.type);
  922. return d;
  923. }
  924. }
  925. }
  926. static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
  927. switch (d.tag()) {
  928. case DeclarationKind::FunctionDeclaration: {
  929. const FunctionDefinition& func_def =
  930. d.GetFunctionDeclaration().definition;
  931. auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
  932. tops->types.Set(func_def.name, t);
  933. InitEnv(d, &tops->values);
  934. break;
  935. }
  936. case DeclarationKind::StructDeclaration: {
  937. const StructDefinition& struct_def = d.GetStructDeclaration().definition;
  938. auto st = TypeOfStructDef(&struct_def, tops->types, tops->values);
  939. Address a = state->heap.AllocateValue(st);
  940. tops->values.Set(struct_def.name, a); // Is this obsolete?
  941. std::vector<TupleElement> field_types;
  942. for (const auto& [field_name, field_value] : st->GetStructType().fields) {
  943. field_types.push_back({.name = field_name, .value = field_value});
  944. }
  945. auto fun_ty = Value::MakeFunctionType(
  946. {}, Value::MakeTupleValue(std::move(field_types)), st);
  947. tops->types.Set(struct_def.name, fun_ty);
  948. break;
  949. }
  950. case DeclarationKind::ChoiceDeclaration: {
  951. const auto& choice = d.GetChoiceDeclaration();
  952. VarValues alts;
  953. for (const auto& [name, signature] : choice.alternatives) {
  954. auto t = InterpExp(tops->values, signature);
  955. alts.push_back(std::make_pair(name, t));
  956. }
  957. auto ct = Value::MakeChoiceType(choice.name, std::move(alts));
  958. Address a = state->heap.AllocateValue(ct);
  959. tops->values.Set(choice.name, a); // Is this obsolete?
  960. tops->types.Set(choice.name, ct);
  961. break;
  962. }
  963. case DeclarationKind::VariableDeclaration: {
  964. const auto& var = d.GetVariableDeclaration();
  965. // Associate the variable name with it's declared type in the
  966. // compile-time symbol table.
  967. const Expression* type =
  968. cast<ExpressionPattern>(var.binding->Type())->Expression();
  969. const Value* declared_type = InterpExp(tops->values, type);
  970. tops->types.Set(*var.binding->Name(), declared_type);
  971. break;
  972. }
  973. }
  974. }
  975. auto TopLevel(std::list<Declaration>* fs) -> TypeCheckContext {
  976. TypeCheckContext tops;
  977. bool found_main = false;
  978. for (auto const& d : *fs) {
  979. if (GetName(d) == "main") {
  980. found_main = true;
  981. }
  982. TopLevel(d, &tops);
  983. }
  984. if (found_main == false) {
  985. FATAL_COMPILATION_ERROR_NO_LINE()
  986. << "program must contain a function named `main`";
  987. }
  988. return tops;
  989. }
  990. } // namespace Carbon