autoupdate_testdata_base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. #!/usr/bin/env python3
  2. """Updates the CHECK: lines in tests with an AUTOUPDATE line."""
  3. __copyright__ = """
  4. Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  5. Exceptions. See /LICENSE for license information.
  6. SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  7. """
  8. from abc import ABC, abstractmethod
  9. import argparse
  10. from concurrent import futures
  11. import os
  12. from pathlib import Path
  13. import re
  14. import subprocess
  15. from typing import (
  16. Any,
  17. Dict,
  18. List,
  19. Match,
  20. NamedTuple,
  21. Optional,
  22. Pattern,
  23. Set,
  24. Tuple,
  25. )
  26. # A prefix followed by a command to run for autoupdating checked output.
  27. AUTOUPDATE_MARKER = "// AUTOUPDATE"
  28. # Indicates no autoupdate is requested.
  29. NOAUTOUPDATE_MARKER = "// NOAUTOUPDATE"
  30. # Supported tools.
  31. TOOLS = {
  32. "carbon": "//toolchain/driver:carbon",
  33. "explorer": "//explorer:explorer",
  34. }
  35. class ParsedArgs(NamedTuple):
  36. autoupdate_args: List[str]
  37. build_mode: str
  38. extra_check_replacements: List[Tuple[Pattern, Pattern, str]]
  39. line_number_delta_prefix: str
  40. line_number_pattern: Pattern
  41. lit_run: List[str]
  42. testdata: str
  43. tests: List[Path]
  44. tool: str
  45. def parse_args() -> ParsedArgs:
  46. """Parses command-line arguments and flags."""
  47. parser = argparse.ArgumentParser(description=__doc__)
  48. parser.add_argument("tests", nargs="*")
  49. parser.add_argument(
  50. "--autoupdate_arg",
  51. metavar="COMMAND",
  52. default=[],
  53. action="append",
  54. help="Optional arguments to pass to the autoupdate command.",
  55. )
  56. parser.add_argument(
  57. "--build_mode",
  58. metavar="MODE",
  59. default="opt",
  60. help="The build mode to use. Defaults to opt for faster execution.",
  61. )
  62. parser.add_argument(
  63. "--extra_check_replacement",
  64. nargs=3,
  65. metavar=("MATCHING", "BEFORE", "AFTER"),
  66. default=[],
  67. action="append",
  68. help="On a CHECK line with MATCHING, does a regex replacement of "
  69. "BEFORE with AFTER.",
  70. )
  71. parser.add_argument(
  72. "--line_number_delta_prefix",
  73. metavar="PREFIX",
  74. default="",
  75. help="An optional prefix to add before the [[@LINE+delta]] marker.",
  76. )
  77. parser.add_argument(
  78. "--line_number_pattern",
  79. metavar="PATTERN",
  80. default=r"(?P<prefix>/(?P<filename>\w+\.carbon):)"
  81. r"(?P<line>\d+)(?P<suffix>(?:\D|$))",
  82. help="A regular expression which matches line numbers to update as its "
  83. "only group. Capture groups 'prefix', 'line', and 'suffix' are "
  84. "required for structure. The 'filename' capture group is optional and "
  85. "should be provided when lines may belong to different files.",
  86. )
  87. parser.add_argument(
  88. "--lit_run",
  89. metavar="COMMAND",
  90. default=[],
  91. required=False,
  92. action="append",
  93. help="RUN lines to set.",
  94. )
  95. parser.add_argument(
  96. "--testdata",
  97. metavar="PATH",
  98. required=True,
  99. help="The path to the testdata to update, relative to the workspace "
  100. "root.",
  101. )
  102. parser.add_argument(
  103. "--tool",
  104. metavar="TOOL",
  105. required=True,
  106. choices=TOOLS.keys(),
  107. help="The tool being tested.",
  108. )
  109. parsed_args = parser.parse_args()
  110. extra_check_replacements = [
  111. (re.compile(line_matcher), re.compile(before), after)
  112. for line_matcher, before, after in parsed_args.extra_check_replacement
  113. ]
  114. return ParsedArgs(
  115. autoupdate_args=parsed_args.autoupdate_arg,
  116. build_mode=parsed_args.build_mode,
  117. extra_check_replacements=extra_check_replacements,
  118. line_number_delta_prefix=parsed_args.line_number_delta_prefix,
  119. line_number_pattern=re.compile(parsed_args.line_number_pattern),
  120. lit_run=parsed_args.lit_run,
  121. testdata=parsed_args.testdata,
  122. tests=[Path(test).resolve() for test in parsed_args.tests],
  123. tool=parsed_args.tool,
  124. )
  125. def get_tests(testdata: str) -> Set[Path]:
  126. """Get the list of tests from the filesystem."""
  127. tests = set()
  128. for root, _, files in os.walk(testdata):
  129. for f in files:
  130. if f in {"lit.cfg.py", "BUILD", "README.md"}:
  131. # Ignore the lit config.
  132. continue
  133. if os.path.splitext(f)[1] == ".carbon":
  134. tests.add(Path(root).joinpath(f))
  135. else:
  136. exit(f"Unrecognized file type in testdata: {f}")
  137. return tests
  138. class Line(ABC):
  139. """A line that may appear in the resulting test file."""
  140. @abstractmethod
  141. def format(
  142. self, *, output_line_number: int, line_number_remap: Dict[int, int]
  143. ) -> str:
  144. raise NotImplementedError
  145. class OriginalLine(Line):
  146. """A line that was copied from the original test file."""
  147. def __init__(self, line_number: int, text: str) -> None:
  148. self.line_number = line_number
  149. self.text = text
  150. def format(self, **kwargs: Any) -> str:
  151. return self.text
  152. class RunLine(Line):
  153. """A RUN line."""
  154. def __init__(self, text: str) -> None:
  155. self.text = text
  156. def format(self, **kwargs: Any) -> str:
  157. return self.text
  158. class CheckLine(Line):
  159. """A `// CHECK:` line generated from the test output.
  160. If there's a line number, it'll be fixed up after we've figured out which
  161. lines to include in the resulting test file and in what order, because
  162. their contents depend on where an original input line appears in the output.
  163. """
  164. def __init__(
  165. self,
  166. test: str,
  167. out_line: str,
  168. line_number_delta_prefix: str,
  169. line_number_pattern: Pattern,
  170. ) -> None:
  171. super().__init__()
  172. self.filename = Path(test).name
  173. self.indent = ""
  174. self.out_line = out_line.rstrip()
  175. self.line_number_delta_prefix = line_number_delta_prefix
  176. self.line_number_pattern = line_number_pattern
  177. # If any match is specific to this file, use the first matched line for
  178. # the location of the CHECK comment.
  179. self.line_in_file = None
  180. for match in line_number_pattern.finditer(self.out_line):
  181. if self._matches_filename(match):
  182. self.line_in_file = int(match.group("line")) - 1
  183. break
  184. def format(
  185. self, *, output_line_number: int, line_number_remap: Dict[int, int]
  186. ) -> str:
  187. assert self.out_line
  188. result = self.out_line
  189. while True:
  190. match = self.line_number_pattern.search(result)
  191. if not match:
  192. break
  193. if self._matches_filename(match):
  194. line_number = int(match.group("line")) - 1
  195. delta = line_number_remap[line_number] - output_line_number
  196. # We use `:+d` here to produce `LINE-n` or `LINE+n` as
  197. # appropriate.
  198. result = self.line_number_pattern.sub(
  199. rf"\g<prefix>{self.line_number_delta_prefix}"
  200. rf"[[@LINE{delta:+d}]]\g<suffix>",
  201. result,
  202. count=1,
  203. )
  204. else:
  205. result = self.line_number_pattern.sub(
  206. r"\g<prefix>{{.*}}\g<suffix>",
  207. result,
  208. count=1,
  209. )
  210. return f"{self.indent}// CHECK:{result}\n"
  211. def _matches_filename(self, match: Match) -> bool:
  212. return (
  213. "filename" not in match.groupdict()
  214. or match.group("filename") == self.filename
  215. )
  216. def find_autoupdate(test: str, orig_lines: List[str]) -> Optional[int]:
  217. """Figures out whether autoupdate should occur.
  218. For AUTOUPDATE, returns the line. For NOAUTOUPDATE, returns None.
  219. """
  220. found = 0
  221. result = None
  222. for line_number, line in enumerate(orig_lines):
  223. if line.startswith(AUTOUPDATE_MARKER):
  224. found += 1
  225. result = line_number
  226. elif line.startswith(NOAUTOUPDATE_MARKER):
  227. found += 1
  228. if found == 0:
  229. raise ValueError(
  230. f"{test} must have either '{AUTOUPDATE_MARKER}' or "
  231. f"'{NOAUTOUPDATE_MARKER}'"
  232. )
  233. elif found > 1:
  234. raise ValueError(
  235. f"{test} must have only one of '{AUTOUPDATE_MARKER}' or "
  236. f"'{NOAUTOUPDATE_MARKER}'"
  237. )
  238. return result
  239. def replace_all(s: str, replacements: List[Tuple[str, str]]) -> str:
  240. """Runs multiple replacements on a string."""
  241. for before, after in replacements:
  242. s = s.replace(before, after)
  243. return s
  244. def label_output(label: str, output: str) -> List[str]:
  245. """Merges output with labels.
  246. This mirrors label_output in lit_test/merge_output.py and should
  247. be kept in sync. They're separate in order to avoid a subprocess or import
  248. complexity.
  249. """
  250. result = []
  251. if output:
  252. for line in output.splitlines():
  253. result.append(" ".join(filter(None, (label, line))))
  254. return result
  255. def get_matchable_test_output(
  256. autoupdate_args: List[str],
  257. for_lit: bool,
  258. extra_check_replacements: List[Tuple[Pattern, Pattern, str]],
  259. tool: str,
  260. bazel_runfiles: Pattern,
  261. llvm_symbolizer: str,
  262. test: str,
  263. ) -> List[str]:
  264. """Runs the autoupdate command and returns the output lines."""
  265. # Run the autoupdate command to generate output.
  266. # (`bazel run` would serialize)
  267. autoupdate_cmd = TOOLS[tool].replace("//", "./bazel-bin/").replace(":", "/")
  268. p = subprocess.run(
  269. [autoupdate_cmd] + autoupdate_args + [test],
  270. env={"LLVM_SYMBOLIZER_PATH": llvm_symbolizer},
  271. stdout=subprocess.PIPE,
  272. stderr=subprocess.PIPE,
  273. encoding="utf-8",
  274. )
  275. out_lines = label_output("STDOUT:", p.stdout)
  276. out_lines.extend(label_output("STDERR:", p.stderr))
  277. for i, line in enumerate(out_lines):
  278. # Escape things that mirror FileCheck special characters.
  279. line = line.replace("{{", "{{[{][{]}}")
  280. line = line.replace("[[", "{{[[][[]}}")
  281. if for_lit:
  282. # `lit` uses full paths to the test file, so use a regex to ignore
  283. # paths when used.
  284. line = line.replace(test, f"{{{{.*}}}}/{test}")
  285. line = bazel_runfiles.sub("{{.*}}/", line)
  286. else:
  287. # When not using `lit`, the runfiles path is removed.
  288. line = bazel_runfiles.sub("", line)
  289. for line_matcher, before, after in extra_check_replacements:
  290. if line_matcher.match(line):
  291. line = before.sub(after, line)
  292. out_lines[i] = line
  293. return out_lines
  294. def is_replaced(line: str) -> bool:
  295. """Returns true if autoupdate should replace the line."""
  296. line = line.lstrip()
  297. return line.startswith("// CHECK") or line.startswith("// RUN:")
  298. def merge_lines(
  299. line_number_delta_prefix: str,
  300. line_number_pattern: Pattern,
  301. lit_run: List[str],
  302. test: str,
  303. autoupdate_line_number: int,
  304. raw_orig_lines: List[str],
  305. out_lines: List[str],
  306. ) -> List[Line]:
  307. """Merges the original output and new lines."""
  308. orig_lines = [
  309. OriginalLine(i, line)
  310. for i, line in enumerate(raw_orig_lines)
  311. if not is_replaced(line)
  312. ]
  313. check_lines = [
  314. CheckLine(test, out_line, line_number_delta_prefix, line_number_pattern)
  315. for out_line in out_lines
  316. ]
  317. result_lines: List[Line] = []
  318. # CHECK lines must go after AUTOUPDATE.
  319. while orig_lines and orig_lines[0].line_number <= autoupdate_line_number:
  320. result_lines.append(orig_lines.pop(0))
  321. for line in lit_run:
  322. run_not = ""
  323. if Path(test).name.startswith("fail_"):
  324. run_not = "%{not} "
  325. result_lines.append(RunLine(f"// RUN: {run_not}{line}\n"))
  326. # Interleave the original lines and the CHECK: lines.
  327. while orig_lines and check_lines:
  328. # Original lines go first when the CHECK line is known and later.
  329. if (
  330. check_lines[0].line_in_file is not None
  331. and check_lines[0].line_in_file > orig_lines[0].line_number
  332. ):
  333. result_lines.append(orig_lines.pop(0))
  334. else:
  335. check_line = check_lines.pop(0)
  336. # Indent to match the next original line.
  337. check_line.indent = re.findall("^ *", orig_lines[0].text)[0]
  338. result_lines.append(check_line)
  339. # One list is non-empty; append remaining lines from both to catch it.
  340. result_lines.extend(orig_lines)
  341. result_lines.extend(check_lines)
  342. return result_lines
  343. def update_check(
  344. parsed_args: ParsedArgs,
  345. bazel_runfiles: Pattern,
  346. llvm_symbolizer: str,
  347. test: Path,
  348. ) -> bool:
  349. """Updates the CHECK: lines for `test` by running the tool.
  350. Returns true if a change was made.
  351. """
  352. with test.open() as f:
  353. orig_lines = f.readlines()
  354. # Make sure we're supposed to autoupdate.
  355. autoupdate_line = find_autoupdate(str(test), orig_lines)
  356. if autoupdate_line is None:
  357. return False
  358. # Determine the merged output lines.
  359. out_lines = get_matchable_test_output(
  360. parsed_args.autoupdate_args,
  361. bool(parsed_args.lit_run),
  362. parsed_args.extra_check_replacements,
  363. parsed_args.tool,
  364. bazel_runfiles,
  365. llvm_symbolizer,
  366. str(test),
  367. )
  368. result_lines = merge_lines(
  369. parsed_args.line_number_delta_prefix,
  370. parsed_args.line_number_pattern,
  371. parsed_args.lit_run,
  372. str(test),
  373. autoupdate_line,
  374. orig_lines,
  375. out_lines,
  376. )
  377. # Calculate the remap for original lines.
  378. line_number_remap = dict(
  379. [
  380. (line.line_number, i)
  381. for i, line in enumerate(result_lines)
  382. if isinstance(line, OriginalLine)
  383. ]
  384. )
  385. # If the last line of the original output was a CHECK, replace it with an
  386. # empty line.
  387. if orig_lines[-1].lstrip().startswith("// CHECK"):
  388. line_number_remap[len(orig_lines) - 1] = len(result_lines) - 1
  389. # Generate contents for any lines that depend on line numbers.
  390. formatted_result_lines = [
  391. line.format(output_line_number=i, line_number_remap=line_number_remap)
  392. for i, line in enumerate(result_lines)
  393. ]
  394. # If nothing's changed, we're done.
  395. if formatted_result_lines == orig_lines:
  396. return False
  397. # Interleave the new CHECK: lines with the tested content.
  398. with test.open("w") as f:
  399. f.writelines(formatted_result_lines)
  400. return True
  401. def update_checks(
  402. parsed_args: ParsedArgs,
  403. bazel_runfiles: Pattern,
  404. llvm_symbolizer: str,
  405. tests: Set[Path],
  406. ) -> None:
  407. """Updates CHECK: lines in lit tests."""
  408. def map_helper(test: Path) -> bool:
  409. try:
  410. updated = update_check(
  411. parsed_args, bazel_runfiles, llvm_symbolizer, test
  412. )
  413. except Exception as e:
  414. raise ValueError(f"Failed to update {test}") from e
  415. print(".", end="", flush=True)
  416. return updated
  417. print(f"Updating {len(tests)} lit test(s)...")
  418. with futures.ThreadPoolExecutor() as exec:
  419. # list() iterates in order to immediately propagate exceptions.
  420. results = list(exec.map(map_helper, tests))
  421. # Each update call indicates progress with a dot without a newline, so put a
  422. # newline to wrap.
  423. print(f"\nUpdated {results.count(True)} lit test(s).")
  424. def main() -> None:
  425. # Parse arguments relative to the working directory.
  426. parsed_args = parse_args()
  427. # Remaining script logic should be relative to the repository root.
  428. root = Path(__file__).parent.parent.parent
  429. os.chdir(root)
  430. if parsed_args.tests:
  431. tests = {test.relative_to(root) for test in parsed_args.tests}
  432. else:
  433. print(
  434. "HINT: run `lit_autoupdate.py f1 f2 ...` to update specific tests"
  435. )
  436. tests = get_tests(parsed_args.testdata)
  437. # Build inputs.
  438. print(f"Building {parsed_args.tool}...")
  439. subprocess.check_call(
  440. [
  441. "bazel",
  442. "build",
  443. "-c",
  444. parsed_args.build_mode,
  445. TOOLS[parsed_args.tool],
  446. ]
  447. )
  448. bazel_bin_dir = subprocess.check_output(
  449. ["bazel", "info", "-c", parsed_args.build_mode, "bazel-bin"],
  450. encoding="utf-8",
  451. ).strip()
  452. bazel_runfiles = re.compile(
  453. r"{0}/.*\.runfiles/carbon/".format(re.escape(bazel_bin_dir))
  454. )
  455. # Grab the symbolizer.
  456. clang_var_content = Path(
  457. "bazel-execroot/external/bazel_cc_toolchain/"
  458. "clang_detected_variables.bzl"
  459. ).read_text()
  460. llvm_symbolizer = re.search(
  461. '(?m)^llvm_symbolizer = "(.*)"$', clang_var_content
  462. )
  463. assert llvm_symbolizer is not None
  464. # Run updates.
  465. update_checks(parsed_args, bazel_runfiles, llvm_symbolizer[1], tests)
  466. if __name__ == "__main__":
  467. main()