lit_autoupdate_base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. #!/usr/bin/env python3
  2. """Updates the CHECK: lines in lit tests based on the AUTOUPDATE line."""
  3. __copyright__ = """
  4. Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  5. Exceptions. See /LICENSE for license information.
  6. SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  7. """
  8. from abc import ABC, abstractmethod
  9. import argparse
  10. from concurrent import futures
  11. import os
  12. from pathlib import Path
  13. import re
  14. import subprocess
  15. from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set, Tuple
  16. # A prefix followed by a command to run for autoupdating checked output.
  17. AUTOUPDATE_MARKER = "// AUTOUPDATE"
  18. # Indicates no autoupdate is requested.
  19. NOAUTOUPDATE_MARKER = "// NOAUTOUPDATE"
  20. # Standard replacements normally done in lit.cfg.py.
  21. MERGE_OUTPUT = "./bazel-bin/bazel/testing/merge_output"
  22. class Tool(NamedTuple):
  23. build_target: str
  24. autoupdate_cmd: List[str]
  25. tools = {
  26. "carbon": Tool(
  27. "//toolchain/driver:carbon",
  28. [MERGE_OUTPUT, "./bazel-bin/toolchain/driver/carbon"],
  29. ),
  30. "explorer": Tool(
  31. "//explorer",
  32. [MERGE_OUTPUT, "./bazel-bin/explorer/explorer"],
  33. ),
  34. }
  35. class ParsedArgs(NamedTuple):
  36. autoupdate_args: List[str]
  37. build_mode: str
  38. extra_check_replacements: List[Tuple[Pattern, Pattern, str]]
  39. line_number_format: str
  40. line_number_pattern: Pattern
  41. lit_run: List[str]
  42. testdata: str
  43. tests: List[Path]
  44. tool: str
  45. def parse_args() -> ParsedArgs:
  46. """Parses command-line arguments and flags."""
  47. parser = argparse.ArgumentParser(description=__doc__)
  48. parser.add_argument("tests", nargs="*")
  49. parser.add_argument(
  50. "--autoupdate_arg",
  51. metavar="COMMAND",
  52. default=[],
  53. action="append",
  54. help="Optional arguments to pass to the autoupdate command.",
  55. )
  56. parser.add_argument(
  57. "--build_mode",
  58. metavar="MODE",
  59. default="opt",
  60. help="The build mode to use. Defaults to opt for faster execution.",
  61. )
  62. parser.add_argument(
  63. "--extra_check_replacement",
  64. nargs=3,
  65. metavar=("MATCHING", "BEFORE", "AFTER"),
  66. default=[],
  67. action="append",
  68. help="On a CHECK line with MATCHING, does a regex replacement of "
  69. "BEFORE with AFTER.",
  70. )
  71. parser.add_argument(
  72. "--line_number_format",
  73. metavar="FORMAT",
  74. default="[[@LINE%(delta)s]]",
  75. help="An optional format string for line number delta replacements.",
  76. )
  77. parser.add_argument(
  78. "--line_number_pattern",
  79. metavar="PATTERN",
  80. required=True,
  81. help="A regular expression which matches line numbers to update as its "
  82. "only group.",
  83. )
  84. parser.add_argument(
  85. "--lit_run",
  86. metavar="COMMAND",
  87. required=True,
  88. action="append",
  89. help="RUN lines to set.",
  90. )
  91. parser.add_argument(
  92. "--testdata",
  93. metavar="PATH",
  94. required=True,
  95. help="The path to the testdata to update, relative to the workspace "
  96. "root.",
  97. )
  98. parser.add_argument(
  99. "--tool",
  100. metavar="TOOL",
  101. required=True,
  102. choices=tools.keys(),
  103. help="The tool being tested.",
  104. )
  105. parsed_args = parser.parse_args()
  106. extra_check_replacements = [
  107. (re.compile(line_matcher), re.compile(before), after)
  108. for line_matcher, before, after in parsed_args.extra_check_replacement
  109. ]
  110. return ParsedArgs(
  111. autoupdate_args=parsed_args.autoupdate_arg,
  112. build_mode=parsed_args.build_mode,
  113. extra_check_replacements=extra_check_replacements,
  114. line_number_format=parsed_args.line_number_format,
  115. line_number_pattern=re.compile(parsed_args.line_number_pattern),
  116. lit_run=parsed_args.lit_run,
  117. testdata=parsed_args.testdata,
  118. tests=[Path(test).resolve() for test in parsed_args.tests],
  119. tool=parsed_args.tool,
  120. )
  121. def get_tests(testdata: str) -> Set[Path]:
  122. """Get the list of tests from the filesystem."""
  123. tests = set()
  124. for root, _, files in os.walk(testdata):
  125. for f in files:
  126. if f in {"lit.cfg.py", "BUILD"}:
  127. # Ignore the lit config.
  128. continue
  129. if os.path.splitext(f)[1] == ".carbon":
  130. tests.add(Path(root).joinpath(f))
  131. else:
  132. exit(f"Unrecognized file type in testdata: {f}")
  133. return tests
  134. class Line(ABC):
  135. """A line that may appear in the resulting test file."""
  136. @abstractmethod
  137. def format(
  138. self, *, output_line_number: int, line_number_remap: Dict[int, int]
  139. ) -> str:
  140. raise NotImplementedError
  141. class OriginalLine(Line):
  142. """A line that was copied from the original test file."""
  143. def __init__(self, line_number: int, text: str) -> None:
  144. self.line_number = line_number
  145. self.text = text
  146. def format(self, **kwargs: Any) -> str:
  147. return self.text
  148. class RunLine(Line):
  149. """A RUN line."""
  150. def __init__(self, text: str) -> None:
  151. self.text = text
  152. def format(self, **kwargs: Any) -> str:
  153. return self.text
  154. class CheckLine(Line):
  155. """A `// CHECK:` line generated from the test output.
  156. If there's a line number, it'll be fixed up after we've figured out which
  157. lines to include in the resulting test file and in what order, because
  158. their contents depend on where an original input line appears in the output.
  159. """
  160. def __init__(
  161. self,
  162. out_line: str,
  163. line_number_format: str,
  164. line_number_pattern: Pattern,
  165. ) -> None:
  166. super().__init__()
  167. self.indent = ""
  168. self.out_line = out_line.rstrip()
  169. self.line_number_format = line_number_format
  170. self.line_number_pattern = line_number_pattern
  171. self.line_numbers = [
  172. int(n) - 1 for n in line_number_pattern.findall(self.out_line)
  173. ]
  174. def format(
  175. self, *, output_line_number: int, line_number_remap: Dict[int, int]
  176. ) -> str:
  177. if not self.out_line:
  178. return f"{self.indent}// CHECK-EMPTY:\n"
  179. result = self.out_line
  180. for line_number in self.line_numbers:
  181. delta = line_number_remap[line_number] - output_line_number
  182. # We use `:+d` here to produce `LINE-n` or `LINE+n` as appropriate.
  183. result = self.line_number_pattern.sub(
  184. self.line_number_format % {"delta": f"{delta:+d}"},
  185. result,
  186. count=1,
  187. )
  188. return f"{self.indent}// CHECK:{result}\n"
  189. def find_autoupdate(test: str, orig_lines: List[str]) -> Optional[int]:
  190. """Figures out whether autoupdate should occur.
  191. For AUTOUPDATE, returns the line. For NOAUTOUPDATE, returns None.
  192. """
  193. found = 0
  194. result = None
  195. for line_number, line in enumerate(orig_lines):
  196. if line.startswith(AUTOUPDATE_MARKER):
  197. found += 1
  198. result = line_number
  199. elif line.startswith(NOAUTOUPDATE_MARKER):
  200. found += 1
  201. if found == 0:
  202. raise ValueError(
  203. f"{test} must have either '{AUTOUPDATE_MARKER}' or "
  204. f"'{NOAUTOUPDATE_MARKER}'"
  205. )
  206. elif found > 1:
  207. raise ValueError(
  208. f"{test} must have only one of '{AUTOUPDATE_MARKER}' or "
  209. f"'{NOAUTOUPDATE_MARKER}'"
  210. )
  211. return result
  212. def replace_all(s: str, replacements: List[Tuple[str, str]]) -> str:
  213. """Runs multiple replacements on a string."""
  214. for before, after in replacements:
  215. s = s.replace(before, after)
  216. return s
  217. def get_matchable_test_output(
  218. autoupdate_args: List[str],
  219. extra_check_replacements: List[Tuple[Pattern, Pattern, str]],
  220. tool: str,
  221. test: str,
  222. ) -> List[str]:
  223. """Runs the autoupdate command and returns the output lines."""
  224. # Run the autoupdate command to generate output.
  225. # (`bazel run` would serialize)
  226. p = subprocess.run(
  227. tools[tool].autoupdate_cmd + autoupdate_args + [test],
  228. stdout=subprocess.PIPE,
  229. stderr=subprocess.STDOUT,
  230. )
  231. out = p.stdout.decode("utf-8")
  232. # `lit` uses full paths to the test file, so use a regex to ignore paths
  233. # when used.
  234. out = replace_all(
  235. out,
  236. [
  237. ("{{", "{{[{][{]}}"),
  238. ("[[", "{{[[][[]}}"),
  239. # TODO: Maybe revisit and see if lit can be convinced to give a
  240. # root-relative path.
  241. (test, f"{{{{.*}}}}/{test}"),
  242. ],
  243. )
  244. out_lines = out.splitlines()
  245. for i, line in enumerate(out_lines):
  246. for line_matcher, before, after in extra_check_replacements:
  247. if line_matcher.match(line):
  248. out_lines[i] = before.sub(after, line)
  249. return out_lines
  250. def is_replaced(line: str) -> bool:
  251. """Returns true if autoupdate should replace the line."""
  252. line = line.lstrip()
  253. return line.startswith("// CHECK") or line.startswith("// RUN:")
  254. def merge_lines(
  255. line_number_format: str,
  256. line_number_pattern: Pattern,
  257. lit_run: List[str],
  258. test: str,
  259. autoupdate_line_number: int,
  260. raw_orig_lines: List[str],
  261. out_lines: List[str],
  262. ) -> List[Line]:
  263. """Merges the original output and new lines."""
  264. orig_lines = [
  265. OriginalLine(i, line)
  266. for i, line in enumerate(raw_orig_lines)
  267. if not is_replaced(line)
  268. ]
  269. check_lines = [
  270. CheckLine(out_line, line_number_format, line_number_pattern)
  271. for out_line in out_lines
  272. ]
  273. result_lines: List[Line] = []
  274. # CHECK lines must go after AUTOUPDATE.
  275. while orig_lines and orig_lines[0].line_number <= autoupdate_line_number:
  276. result_lines.append(orig_lines.pop(0))
  277. for line in lit_run:
  278. run_not = ""
  279. if Path(test).name.startswith("fail_"):
  280. run_not = "%{not} "
  281. result_lines.append(RunLine(f"// RUN: {run_not}{line}\n"))
  282. # Interleave the original lines and the CHECK: lines.
  283. while orig_lines and check_lines:
  284. # Original lines go first when the CHECK line is known and later.
  285. if (
  286. check_lines[0].line_numbers
  287. and check_lines[0].line_numbers[0] > orig_lines[0].line_number
  288. ):
  289. result_lines.append(orig_lines.pop(0))
  290. else:
  291. check_line = check_lines.pop(0)
  292. # Indent to match the next original line.
  293. check_line.indent = re.findall("^ *", orig_lines[0].text)[0]
  294. result_lines.append(check_line)
  295. # One list is non-empty; append remaining lines from both to catch it.
  296. result_lines.extend(orig_lines)
  297. result_lines.extend(check_lines)
  298. return result_lines
  299. def update_check(parsed_args: ParsedArgs, test: Path) -> bool:
  300. """Updates the CHECK: lines for `test` by running the tool.
  301. Returns true if a change was made.
  302. """
  303. with test.open() as f:
  304. orig_lines = f.readlines()
  305. # Make sure we're supposed to autoupdate.
  306. autoupdate_line = find_autoupdate(str(test), orig_lines)
  307. if autoupdate_line is None:
  308. return False
  309. # Determine the merged output lines.
  310. out_lines = get_matchable_test_output(
  311. parsed_args.autoupdate_args,
  312. parsed_args.extra_check_replacements,
  313. parsed_args.tool,
  314. str(test),
  315. )
  316. result_lines = merge_lines(
  317. parsed_args.line_number_format,
  318. parsed_args.line_number_pattern,
  319. parsed_args.lit_run,
  320. str(test),
  321. autoupdate_line,
  322. orig_lines,
  323. out_lines,
  324. )
  325. # Calculate the remap for original lines.
  326. line_number_remap = dict(
  327. [
  328. (line.line_number, i)
  329. for i, line in enumerate(result_lines)
  330. if isinstance(line, OriginalLine)
  331. ]
  332. )
  333. # If the last line of the original output was a CHECK, replace it with an
  334. # empty line.
  335. if orig_lines[-1].lstrip().startswith("// CHECK"):
  336. line_number_remap[len(orig_lines) - 1] = len(result_lines) - 1
  337. # Generate contents for any lines that depend on line numbers.
  338. formatted_result_lines = [
  339. line.format(output_line_number=i, line_number_remap=line_number_remap)
  340. for i, line in enumerate(result_lines)
  341. ]
  342. # If nothing's changed, we're done.
  343. if formatted_result_lines == orig_lines:
  344. return False
  345. # Interleave the new CHECK: lines with the tested content.
  346. with test.open("w") as f:
  347. f.writelines(formatted_result_lines)
  348. return True
  349. def update_checks(parsed_args: ParsedArgs, tests: Set[Path]) -> None:
  350. """Updates CHECK: lines in lit tests."""
  351. def map_helper(test: Path) -> bool:
  352. try:
  353. updated = update_check(parsed_args, test)
  354. except Exception as e:
  355. raise ValueError(f"Failed to update {test}") from e
  356. print(".", end="", flush=True)
  357. return updated
  358. print(f"Updating {len(tests)} lit test(s)...")
  359. with futures.ThreadPoolExecutor() as exec:
  360. # list() iterates in order to immediately propagate exceptions.
  361. results = list(exec.map(map_helper, tests))
  362. # Each update call indicates progress with a dot without a newline, so put a
  363. # newline to wrap.
  364. print(f"\nUpdated {results.count(True)} lit test(s).")
  365. def main() -> None:
  366. # Parse arguments relative to the working directory.
  367. parsed_args = parse_args()
  368. # Remaining script logic should be relative to the repository root.
  369. os.chdir(Path(__file__).parent.parent.parent)
  370. if parsed_args.tests:
  371. tests = set(parsed_args.tests)
  372. else:
  373. print("HINT: run `update_checks.py f1 f2 ...` to update specific tests")
  374. tests = get_tests(parsed_args.testdata)
  375. # Build inputs.
  376. print(f"Building {parsed_args.tool}...")
  377. subprocess.check_call(
  378. [
  379. "bazel",
  380. "build",
  381. "-c",
  382. parsed_args.build_mode,
  383. "//bazel/testing:merge_output",
  384. tools[parsed_args.tool].build_target,
  385. ]
  386. )
  387. # Run updates.
  388. update_checks(parsed_args, tests)
  389. if __name__ == "__main__":
  390. main()