  1. #!/usr/bin/env python3
  2. """Check if a PR depends on other open PRs based on shared commits.
  3. Usage examples:
  4. # Check a specific PR in dry-run mode:
  5. GITHUB_ACCESS_TOKEN=$(gh auth token) \
  6. python3 github_tools/check_dependent_pr.py --pr-number <PR_NUMBER> --dry-run
  7. # Scan all dependent PRs in dry-run mode:
  8. GITHUB_ACCESS_TOKEN=$(gh auth token) \
  9. python3 github_tools/check_dependent_pr.py --scan --dry-run
  10. """
  11. __copyright__ = """
  12. Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  13. Exceptions. See /LICENSE for license information.
  14. SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  15. """
  16. import argparse
  17. import datetime
  18. import importlib.util
  19. import json
  20. import re
  21. import os
  22. import sys
  23. from typing import Any, Optional
  24. # Do some extra work to support direct runs.
  25. try:
  26. from github_tools import github_helpers
  27. except ImportError:
  28. github_helpers_spec = importlib.util.spec_from_file_location(
  29. "github_helpers",
  30. os.path.join(os.path.dirname(__file__), "github_helpers.py"),
  31. )
  32. assert github_helpers_spec is not None
  33. github_helpers = importlib.util.module_from_spec(github_helpers_spec)
  34. github_helpers_spec.loader.exec_module(github_helpers) # type: ignore
  35. # Queries
  36. _QUERY_OPEN_PRS = """
  37. {
  38. repository(owner: "carbon-language", name: "carbon-lang") {
  39. pullRequests(states: OPEN, first: 100%(cursor)s) {
  40. nodes {
  41. number
  42. commits(first: 100) {
  43. nodes {
  44. commit {
  45. oid
  46. }
  47. }
  48. }
  49. }
  50. %(pagination)s
  51. }
  52. }
  53. }
  54. """
  55. _QUERY_DEPENDENT_PRS = """
  56. {
  57. repository(owner: "carbon-language", name: "carbon-lang") {
  58. pullRequests(states: OPEN, labels: ["dependent"], first: 100%(cursor)s) {
  59. nodes {
  60. number
  61. }
  62. %(pagination)s
  63. }
  64. }
  65. }
  66. """
  67. _QUERY_PR_DETAILS = """
  68. query GetPrDetails($prNumber: Int!) {
  69. repository(owner: "carbon-language", name: "carbon-lang") {
  70. pullRequest(number: $prNumber) {
  71. id
  72. headRefOid
  73. labels(first: 100) {
  74. nodes {
  75. name
  76. id
  77. }
  78. }
  79. commits(first: 100) {
  80. nodes {
  81. commit {
  82. oid
  83. }
  84. }
  85. }
  86. comments(first: 100) {
  87. nodes {
  88. id
  89. body
  90. isMinimized
  91. }
  92. }
  93. }
  94. }
  95. }
  96. """
  97. _QUERY_LABEL = """
  98. {
  99. repository(owner: "carbon-language", name: "carbon-lang") {
  100. label(name: "dependent") {
  101. id
  102. }
  103. }
  104. }
  105. """
  106. _QUERY_MAX_MERGED_PR = """
  107. {
  108. repository(owner: "carbon-language", name: "carbon-lang") {
  109. pullRequests(states: MERGED, first: 1) {
  110. nodes {
  111. number
  112. }
  113. }
  114. }
  115. }
  116. """
  117. _MUTATION_ADD_LABEL = """
  118. mutation AddLabel($labelableId: ID!, $labelIds: [ID!]!) {
  119. addLabelsToLabelable(
  120. input: {labelableId: $labelableId, labelIds: $labelIds}
  121. ) {
  122. clientMutationId
  123. }
  124. }
  125. """
  126. _MUTATION_REMOVE_LABEL = """
  127. mutation RemoveLabel($labelableId: ID!, $labelIds: [ID!]!) {
  128. removeLabelsFromLabelable(
  129. input: {labelableId: $labelableId, labelIds: $labelIds}
  130. ) {
  131. clientMutationId
  132. }
  133. }
  134. """
  135. _MUTATION_UPDATE_COMMENT = """
  136. mutation UpdateComment($id: ID!, $body: String!) {
  137. updateIssueComment(input: {id: $id, body: $body}) {
  138. clientMutationId
  139. }
  140. }
  141. """
  142. _MUTATION_ADD_COMMENT = """
  143. mutation AddComment($subjectId: ID!, $body: String!) {
  144. addComment(input: {subjectId: $subjectId, body: $body}) {
  145. clientMutationId
  146. }
  147. }
  148. """
  149. def _print_err(*args: Any, **kwargs: Any) -> None:
  150. """Prints to stderr."""
  151. kwargs["file"] = sys.stderr
  152. print(*args, **kwargs)
  153. def _parse_pr_number(x: Any) -> Optional[int]:
  154. """Parses x into a positive integer if possible."""
  155. if isinstance(x, int):
  156. return x if x > 0 else None
  157. if isinstance(x, str) and x.isdigit():
  158. val = int(x)
  159. return val if val > 0 else None
  160. return None
  161. def _parse_and_validate_state(
  162. json_str: str,
  163. open_pr_numbers: set[int],
  164. max_merged_pr: int = 10000,
  165. pr_number: int = 0,
  166. ) -> tuple[list[int], list[int], Optional[str]]:
  167. """Parses and validates the state from a JSON string."""
  168. parsed_open: list[int] = []
  169. parsed_merged: list[int] = []
  170. first_commit: Optional[str] = None
  171. raw_state = json.loads(json_str)
  172. if not isinstance(raw_state, dict):
  173. raise ValueError(f"PR #{pr_number}: Parsed JSON is not a dictionary.")
  174. for x in raw_state.get("open", []):
  175. val = _parse_pr_number(x)
  176. if val is None:
  177. raise ValueError(
  178. f"PR #{pr_number}: Invalid PR number format in 'open': {x}"
  179. )
  180. elif val not in open_pr_numbers and val > max_merged_pr:
  181. raise ValueError(
  182. f"PR #{pr_number}: Rejecting PR #{val} from 'open' because "
  183. "it is not an open PR and exceeds maximum merged PR "
  184. f"#{max_merged_pr}."
  185. )
  186. else:
  187. parsed_open.append(val)
  188. for x in raw_state.get("merged", []):
  189. val = _parse_pr_number(x)
  190. if val is None:
  191. raise ValueError(
  192. f"PR #{pr_number}: Invalid PR number format in 'merged': {x}"
  193. )
  194. elif val in open_pr_numbers:
  195. raise ValueError(
  196. f"PR #{pr_number}: Rejecting PR #{val} from 'merged' "
  197. "because it is actually open."
  198. )
  199. elif val > max_merged_pr:
  200. raise ValueError(
  201. f"PR #{pr_number}: Rejecting PR #{val} from 'merged' "
  202. f"because it exceeds maximum merged PR #{max_merged_pr}."
  203. )
  204. else:
  205. parsed_merged.append(val)
  206. if "first_commit" in raw_state:
  207. fc = raw_state["first_commit"]
  208. if isinstance(fc, str) and re.fullmatch(r"[0-9a-fA-F]{40}", fc):
  209. first_commit = fc
  210. else:
  211. raise ValueError(
  212. f"PR #{pr_number}: Invalid commit OID format in "
  213. f"'first_commit': {fc}"
  214. )
  215. return parsed_open, parsed_merged, first_commit
  216. def _process_pr(
  217. client: github_helpers.Client,
  218. pr_number: int,
  219. pr_to_commits: dict[int, list[str]],
  220. open_pr_numbers: set[int],
  221. label_id: str,
  222. dry_run: bool,
  223. scanning: bool = False,
  224. max_merged_pr: int = 10000,
  225. ) -> None:
  226. """Processes a single PR to check for dependencies and update comments."""
  227. current_res = client.execute(
  228. _QUERY_PR_DETAILS, variable_values={"prNumber": pr_number}
  229. )
  230. pr_node = current_res["repository"]["pullRequest"]
  231. if not pr_node:
  232. _print_err(f"PR #{pr_number} not found.")
  233. return
  234. pr_id = pr_node["id"]
  235. commits = pr_node["commits"]["nodes"]
  236. comments = pr_node["comments"]["nodes"]
  237. labels = pr_node["labels"]["nodes"]
  238. open_deps: list[int] = []
  239. if len(commits) <= 1:
  240. _print_err(
  241. f"PR #{pr_number} has 1 or fewer commits, skipping overlap check."
  242. )
  243. current_oids = [c["commit"]["oid"] for c in commits]
  244. else:
  245. current_oids = [c["commit"]["oid"] for c in commits]
  246. # Dependency Logic: Overlap and Sequence
  247. #
  248. # We consider PR B dependent on PR A if:
  249. # 1. The dependency PR A was created before PR B (A.number < B.number).
  250. # 2. There is a non-empty overlap of commits between PR A and PR B.
  251. # 3. PR B has at least one commit not present in PR A.
  252. #
  253. # Why this works:
  254. # - Ensures the dependency direction reflects the creation sequence.
  255. # - Handles minor fixes or differences by only requiring overlap, not
  256. # strict subset inclusion.
  257. # - Avoids circular dependencies via the sequence check.
  258. current_oids_set = set(current_oids)
  259. for other_pr_num, other_oids in pr_to_commits.items():
  260. if other_pr_num >= pr_number:
  261. continue
  262. other_oids_set = set(other_oids)
  263. if not (other_oids_set & current_oids_set):
  264. continue
  265. if not (current_oids_set - other_oids_set):
  266. continue
  267. open_deps.append(other_pr_num)
  268. # Parse existing comment
  269. marker_prefix = "<!-- check_dependent_pr "
  270. existing_comment_id = None
  271. parsed_open_deps: list[int] = []
  272. parsed_merged_deps: list[int] = []
  273. previous_first_commit: Optional[str] = None
  274. matching_comment = None
  275. for comment in comments:
  276. # If a marker comment is hidden (minimized), we ignore it and treat
  277. # the PR as if it never had that comment.
  278. if marker_prefix in comment["body"] and not comment.get("isMinimized"):
  279. matching_comment = comment
  280. break
  281. if matching_comment:
  282. existing_comment_id = matching_comment["id"]
  283. body = matching_comment["body"]
  284. start = body.find(marker_prefix) + len(marker_prefix)
  285. end = body.find(" -->", start)
  286. if end != -1:
  287. parsed_open_deps, parsed_merged_deps, previous_first_commit = (
  288. _parse_and_validate_state(
  289. body[start:end], open_pr_numbers, max_merged_pr, pr_number
  290. )
  291. )
  292. if not open_deps and not existing_comment_id:
  293. return
  294. # Keep tracking previously identified dependencies if they are still open,
  295. # even if they no longer pass the subset check (e.g. they got new commits).
  296. for pr in parsed_open_deps:
  297. if pr in open_pr_numbers and pr not in open_deps:
  298. open_deps.append(pr)
  299. # Identify newly merged PRs
  300. newly_merged_deps = []
  301. for pr in parsed_open_deps:
  302. if pr not in open_deps and pr not in open_pr_numbers:
  303. newly_merged_deps.append(pr)
  304. merged_deps = list(set(parsed_merged_deps + newly_merged_deps))
  305. first_independent_commit_oid = None
  306. if open_deps:
  307. dependent_oids = set()
  308. for d in open_deps:
  309. dependent_oids.update(pr_to_commits[d])
  310. # previous_first_commit already assigned from comment state.
  311. if previous_first_commit and previous_first_commit in current_oids:
  312. start_idx = current_oids.index(previous_first_commit)
  313. else:
  314. start_idx = 0
  315. # Assumes `current_oids` is in chronological order (oldest first).
  316. # This guarantees we find the first independent commit to start the
  317. # review.
  318. for oid in current_oids[start_idx:]:
  319. if oid not in dependent_oids:
  320. first_independent_commit_oid = oid
  321. break
  322. if (
  323. open_deps == parsed_open_deps
  324. and merged_deps == parsed_merged_deps
  325. and first_independent_commit_oid == previous_first_commit
  326. ):
  327. return
  328. # Construct new comment
  329. timestamp = datetime.datetime.now(datetime.timezone.utc).strftime(
  330. "%Y-%m-%d %H:%M:%S UTC"
  331. )
  332. new_state: dict[str, Any] = {
  333. "open": open_deps,
  334. "merged": merged_deps,
  335. "first_commit": first_independent_commit_oid,
  336. }
  337. state_json = json.dumps(new_state)
  338. comment_body = f"{marker_prefix}{state_json} -->\n"
  339. if open_deps:
  340. pr_list_str = ", ".join([f"#{num}" for num in open_deps])
  341. if first_independent_commit_oid:
  342. short_hash = first_independent_commit_oid[:8]
  343. first_commit_linked = (
  344. f"[{short_hash}]({pr_number}/commits/{short_hash})"
  345. )
  346. comment_body += (
  347. f"Depends on {pr_list_str}, start review at "
  348. f"{first_commit_linked}"
  349. )
  350. else:
  351. comment_body += (
  352. f"Depends on {pr_list_str}, unable to identify starting review "
  353. f"commit from simple analysis"
  354. )
  355. else:
  356. comment_body += "All dependent PRs are merged."
  357. if merged_deps:
  358. merged_str = ", ".join([f"#{num}" for num in sorted(merged_deps)])
  359. comment_body += f"\n\nMerged dependent PRs: {merged_str}"
  360. comment_body += f"\n\n(Last updated: {timestamp})"
  361. _print_err(f"PR #{pr_number}: Updating comment. New body:\n{comment_body}")
  362. # Apply mutations
  363. has_dependent_label = any(label["name"] == "dependent" for label in labels)
  364. if open_deps and not has_dependent_label and not scanning:
  365. if dry_run:
  366. _print_err(
  367. f"[Dry-run] Would add 'dependent' label to PR #{pr_number}"
  368. )
  369. else:
  370. client.execute(
  371. _MUTATION_ADD_LABEL,
  372. variable_values={"labelableId": pr_id, "labelIds": [label_id]},
  373. )
  374. elif not open_deps and has_dependent_label:
  375. if dry_run:
  376. _print_err(
  377. f"[Dry-run] Would remove 'dependent' label from PR #{pr_number}"
  378. )
  379. else:
  380. client.execute(
  381. _MUTATION_REMOVE_LABEL,
  382. variable_values={"labelableId": pr_id, "labelIds": [label_id]},
  383. )
  384. if existing_comment_id:
  385. if dry_run:
  386. _print_err(f"[Dry-run] Would update comment {existing_comment_id}")
  387. else:
  388. client.execute(
  389. _MUTATION_UPDATE_COMMENT,
  390. variable_values={
  391. "id": existing_comment_id,
  392. "body": comment_body,
  393. },
  394. )
  395. else:
  396. if scanning:
  397. _print_err(
  398. f"PR #{pr_number}: Skipping new comment creation in scan mode."
  399. )
  400. return
  401. if dry_run:
  402. _print_err(f"[Dry-run] Would add comment to PR #{pr_number}")
  403. else:
  404. client.execute(
  405. _MUTATION_ADD_COMMENT,
  406. variable_values={"subjectId": pr_id, "body": comment_body},
  407. )
  408. def _parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
  409. """Parses command-line arguments."""
  410. parser = argparse.ArgumentParser(
  411. description=__doc__,
  412. formatter_class=argparse.RawDescriptionHelpFormatter,
  413. )
  414. group = parser.add_mutually_exclusive_group(required=True)
  415. group.add_argument(
  416. "--pr-number",
  417. type=int,
  418. help="The pull request number to check.",
  419. )
  420. group.add_argument(
  421. "--scan",
  422. action="store_true",
  423. help="Scan all open PRs with 'dependent' label and update them.",
  424. )
  425. parser.add_argument(
  426. "--dry-run",
  427. action="store_true",
  428. help="Print mutations without updating GitHub",
  429. )
  430. github_helpers.add_access_token_arg(parser, "repo")
  431. return parser.parse_args(args=args)
  432. def main() -> None:
  433. parsed_args = _parse_args()
  434. client = github_helpers.Client(parsed_args)
  435. _print_err("Loading open PRs ...", end="", flush=True)
  436. pr_to_commits: dict[int, list[str]] = {}
  437. open_pr_numbers: set[int] = set()
  438. for node in client.execute_and_paginate(
  439. _QUERY_OPEN_PRS, ("repository", "pullRequests")
  440. ):
  441. _print_err(".", end="", flush=True)
  442. other_pr_num = node["number"]
  443. open_pr_numbers.add(other_pr_num)
  444. pr_to_commits[other_pr_num] = [
  445. c["commit"]["oid"] for c in node["commits"]["nodes"]
  446. ]
  447. _print_err()
  448. label_res = client.execute(_QUERY_LABEL)
  449. label_id = label_res["repository"]["label"]["id"]
  450. merged_res = client.execute(_QUERY_MAX_MERGED_PR)
  451. merged_nodes = merged_res["repository"]["pullRequests"]["nodes"]
  452. max_merged_pr = merged_nodes[0]["number"] if merged_nodes else 0
  453. if parsed_args.pr_number:
  454. _process_pr(
  455. client,
  456. parsed_args.pr_number,
  457. pr_to_commits,
  458. open_pr_numbers,
  459. label_id,
  460. parsed_args.dry_run,
  461. max_merged_pr=max_merged_pr,
  462. )
  463. elif parsed_args.scan:
  464. for node in client.execute_and_paginate(
  465. _QUERY_DEPENDENT_PRS, ("repository", "pullRequests")
  466. ):
  467. _process_pr(
  468. client,
  469. node["number"],
  470. pr_to_commits,
  471. open_pr_numbers,
  472. label_id,
  473. parsed_args.dry_run,
  474. scanning=True,
  475. max_merged_pr=max_merged_pr,
  476. )
  477. if __name__ == "__main__":
  478. main()