check_dependent_pr.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
  1. #!/usr/bin/env python3
  2. """Check if a PR depends on other open PRs based on shared commits.
  3. Usage examples:
  4. # Check a specific PR in dry-run mode:
  5. GITHUB_ACCESS_TOKEN=$(gh auth token) \
  6. python3 github_tools/check_dependent_pr.py --pr-number <PR_NUMBER> --dry-run
  7. # Scan all dependent PRs in dry-run mode:
  8. GITHUB_ACCESS_TOKEN=$(gh auth token) \
  9. python3 github_tools/check_dependent_pr.py --scan --dry-run
  10. """
  11. __copyright__ = """
  12. Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  13. Exceptions. See /LICENSE for license information.
  14. SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  15. """
  16. import argparse
  17. import datetime
  18. import importlib.util
  19. import json
  20. import re
  21. import os
  22. import sys
  23. import requests
  24. from typing import Any, Optional
# Do some extra work to support direct runs.
try:
    from github_tools import github_helpers
except ImportError:
    # Running as a standalone script (not as part of the `github_tools`
    # package): load github_helpers.py from the directory containing this
    # file, without requiring the package root to be on sys.path.
    github_helpers_spec = importlib.util.spec_from_file_location(
        "github_helpers",
        os.path.join(os.path.dirname(__file__), "github_helpers.py"),
    )
    assert github_helpers_spec is not None
    github_helpers = importlib.util.module_from_spec(github_helpers_spec)
    github_helpers_spec.loader.exec_module(github_helpers)  # type: ignore
# Queries
#
# The %(cursor)s and %(pagination)s placeholders in the paginated queries
# are substituted by github_helpers' pagination support
# (execute_and_paginate).

# All open PRs together with the OIDs of their (first 100) commits.
_QUERY_OPEN_PRS = """
{
  repository(owner: "carbon-language", name: "carbon-lang") {
    pullRequests(states: OPEN, first: 100%(cursor)s) {
      nodes {
        number
        commits(first: 100) {
          nodes {
            commit {
              oid
            }
          }
        }
      }
      %(pagination)s
    }
  }
}
"""

# Open PRs currently carrying the 'dependent' label (used by --scan).
_QUERY_DEPENDENT_PRS = """
{
  repository(owner: "carbon-language", name: "carbon-lang") {
    pullRequests(states: OPEN, labels: ["dependent"], first: 100%(cursor)s) {
      nodes {
        number
      }
      %(pagination)s
    }
  }
}
"""

# Full details for a single PR: node id, head commit, labels, commit OIDs,
# and comments (including whether each comment is minimized/hidden).
_QUERY_PR_DETAILS = """
query GetPrDetails($prNumber: Int!) {
  repository(owner: "carbon-language", name: "carbon-lang") {
    pullRequest(number: $prNumber) {
      id
      headRefOid
      labels(first: 100) {
        nodes {
          name
          id
        }
      }
      commits(first: 100) {
        nodes {
          commit {
            oid
          }
        }
      }
      comments(first: 100) {
        nodes {
          id
          body
          isMinimized
        }
      }
    }
  }
}
"""

# The node ID of the 'dependent' label, needed for the label mutations.
_QUERY_LABEL = """
{
  repository(owner: "carbon-language", name: "carbon-lang") {
    label(name: "dependent") {
      id
    }
  }
}
"""

# The most recently created merged PR; its number is used as an upper bound
# when validating PR numbers parsed out of comment state.
_QUERY_MAX_MERGED_PR = """
{
  repository(owner: "carbon-language", name: "carbon-lang") {
    pullRequests(
      states: MERGED
      orderBy: {field: CREATED_AT, direction: DESC}
      first: 1
    ) {
      nodes {
        number
      }
    }
  }
}
"""

# Mutation: add label(s) to a labelable (a PR here).
_MUTATION_ADD_LABEL = """
mutation AddLabel($labelableId: ID!, $labelIds: [ID!]!) {
  addLabelsToLabelable(
    input: {labelableId: $labelableId, labelIds: $labelIds}
  ) {
    clientMutationId
  }
}
"""

# Mutation: remove label(s) from a labelable (a PR here).
_MUTATION_REMOVE_LABEL = """
mutation RemoveLabel($labelableId: ID!, $labelIds: [ID!]!) {
  removeLabelsFromLabelable(
    input: {labelableId: $labelableId, labelIds: $labelIds}
  ) {
    clientMutationId
  }
}
"""

# Mutation: edit an existing issue/PR comment in place.
_MUTATION_UPDATE_COMMENT = """
mutation UpdateComment($id: ID!, $body: String!) {
  updateIssueComment(input: {id: $id, body: $body}) {
    clientMutationId
  }
}
"""

# Mutation: add a new comment to an issue/PR.
_MUTATION_ADD_COMMENT = """
mutation AddComment($subjectId: ID!, $body: String!) {
  addComment(input: {subjectId: $subjectId, body: $body}) {
    clientMutationId
  }
}
"""
  154. def _print_err(*args: Any, **kwargs: Any) -> None:
  155. """Prints to stderr."""
  156. kwargs["file"] = sys.stderr
  157. print(*args, **kwargs)
  158. def _parse_pr_number(x: Any) -> Optional[int]:
  159. """Parses x into a positive integer if possible."""
  160. if isinstance(x, int):
  161. return x if x > 0 else None
  162. if isinstance(x, str) and x.isdigit():
  163. val = int(x)
  164. return val if val > 0 else None
  165. return None
  166. def _parse_and_validate_state(
  167. json_str: str,
  168. open_pr_numbers: set[int],
  169. max_merged_pr: int = 10000,
  170. pr_number: int = 0,
  171. ) -> tuple[list[int], list[int], Optional[str]]:
  172. """Parses and validates the state from a JSON string."""
  173. parsed_open: list[int] = []
  174. parsed_merged: list[int] = []
  175. first_commit: Optional[str] = None
  176. raw_state = json.loads(json_str)
  177. if not isinstance(raw_state, dict):
  178. raise ValueError(f"PR #{pr_number}: Parsed JSON is not a dictionary.")
  179. for x in raw_state.get("open", []):
  180. val = _parse_pr_number(x)
  181. if val is None:
  182. raise ValueError(
  183. f"PR #{pr_number}: Invalid PR number format in 'open': {x}"
  184. )
  185. elif val not in open_pr_numbers and val > max_merged_pr:
  186. raise ValueError(
  187. f"PR #{pr_number}: Rejecting PR #{val} from 'open' because "
  188. "it is not an open PR and exceeds maximum merged PR "
  189. f"#{max_merged_pr}."
  190. )
  191. else:
  192. parsed_open.append(val)
  193. for x in raw_state.get("merged", []):
  194. val = _parse_pr_number(x)
  195. if val is None:
  196. raise ValueError(
  197. f"PR #{pr_number}: Invalid PR number format in 'merged': {x}"
  198. )
  199. elif val in open_pr_numbers:
  200. raise ValueError(
  201. f"PR #{pr_number}: Rejecting PR #{val} from 'merged' "
  202. "because it is actually open."
  203. )
  204. elif val > max_merged_pr:
  205. raise ValueError(
  206. f"PR #{pr_number}: Rejecting PR #{val} from 'merged' "
  207. f"because it exceeds maximum merged PR #{max_merged_pr}."
  208. )
  209. else:
  210. parsed_merged.append(val)
  211. if "first_commit" in raw_state:
  212. fc = raw_state["first_commit"]
  213. if isinstance(fc, str) and re.fullmatch(r"[0-9a-fA-F]{40}", fc):
  214. first_commit = fc
  215. else:
  216. raise ValueError(
  217. f"PR #{pr_number}: Invalid commit OID format in "
  218. f"'first_commit': {fc}"
  219. )
  220. return parsed_open, parsed_merged, first_commit
  221. def _set_commit_status(
  222. sha: str,
  223. state: str,
  224. description: str,
  225. token: str,
  226. dry_run: bool,
  227. ) -> None:
  228. """Sets the commit status via the GitHub REST API."""
  229. url = (
  230. "https://api.github.com/repos/carbon-language/carbon-lang/"
  231. f"statuses/{sha}"
  232. )
  233. headers = {
  234. "Authorization": f"bearer {token}",
  235. "Accept": "application/vnd.github.v3+json",
  236. }
  237. payload = {
  238. "state": state,
  239. "description": description,
  240. "context": "PR dependencies check",
  241. }
  242. if dry_run:
  243. _print_err(
  244. f"[Dry-run] Would set commit status on {sha[:8]} to {state} "
  245. f"({description})"
  246. )
  247. return
  248. try:
  249. response = requests.post(url, headers=headers, json=payload)
  250. response.raise_for_status()
  251. _print_err(f"Set commit status on {sha[:8]} to {state}")
  252. except Exception as e:
  253. _print_err(f"Error setting commit status on {sha[:8]}: {e}")
def _process_pr(
    client: github_helpers.Client,
    pr_number: int,
    pr_to_commits: dict[int, list[str]],
    open_pr_numbers: set[int],
    label_id: str,
    token: str,
    dry_run: bool = False,
    scanning: bool = False,
    max_merged_pr: int = 10000,
) -> None:
    """Processes a single PR to check for dependencies and update comments.

    Args:
        client: GraphQL client used for all queries and mutations.
        pr_number: The PR to process.
        pr_to_commits: Map of every open PR number to its commit OIDs.
        open_pr_numbers: The set of all currently open PR numbers.
        label_id: Node ID of the 'dependent' label.
        token: GitHub access token, used for the REST commit-status call.
        dry_run: If true, print mutations instead of applying them.
        scanning: If true (--scan mode), never add the label or create a
            new comment; only update existing state.
        max_merged_pr: Upper bound used when validating PR numbers parsed
            from a prior state comment.
    """
    current_res = client.execute(
        _QUERY_PR_DETAILS, variable_values={"prNumber": pr_number}
    )
    pr_node = current_res["repository"]["pullRequest"]
    if not pr_node:
        _print_err(f"PR #{pr_number} not found.")
        return
    pr_id = pr_node["id"]
    commits = pr_node["commits"]["nodes"]
    comments = pr_node["comments"]["nodes"]
    labels = pr_node["labels"]["nodes"]
    open_deps: list[int] = []
    if len(commits) <= 1:
        # A single-commit PR cannot share commits with an earlier PR and
        # still have its own work, so skip the overlap computation.
        _print_err(
            f"PR #{pr_number} has 1 or fewer commits, skipping overlap check."
        )
        current_oids = [c["commit"]["oid"] for c in commits]
    else:
        current_oids = [c["commit"]["oid"] for c in commits]
        # Dependency Logic: Overlap and Sequence
        #
        # We consider PR B dependent on PR A if:
        # 1. The dependency PR A was created before PR B (A.number < B.number).
        # 2. There is a non-empty overlap of commits between PR A and PR B.
        # 3. PR B has at least one commit not present in PR A.
        #
        # Why this works:
        # - Ensures the dependency direction reflects the creation sequence.
        # - Handles minor fixes or differences by only requiring overlap, not
        #   strict subset inclusion.
        # - Avoids circular dependencies via the sequence check.
        current_oids_set = set(current_oids)
        for other_pr_num, other_oids in pr_to_commits.items():
            if other_pr_num >= pr_number:
                continue
            other_oids_set = set(other_oids)
            if not (other_oids_set & current_oids_set):
                continue
            if not (current_oids_set - other_oids_set):
                continue
            open_deps.append(other_pr_num)
    # Parse existing comment: prior state is embedded in an HTML comment
    # marker so it round-trips through the PR comment body.
    marker_prefix = "<!-- check_dependent_pr "
    existing_comment_id = None
    parsed_open_deps: list[int] = []
    parsed_merged_deps: list[int] = []
    previous_first_commit: Optional[str] = None
    matching_comment = None
    for comment in comments:
        # If a marker comment is hidden (minimized), we ignore it and treat
        # the PR as if it never had that comment.
        if marker_prefix in comment["body"] and not comment.get("isMinimized"):
            matching_comment = comment
            break
    if matching_comment:
        existing_comment_id = matching_comment["id"]
        body = matching_comment["body"]
        start = body.find(marker_prefix) + len(marker_prefix)
        end = body.find(" -->", start)
        if end != -1:
            parsed_open_deps, parsed_merged_deps, previous_first_commit = (
                _parse_and_validate_state(
                    body[start:end], open_pr_numbers, max_merged_pr, pr_number
                )
            )
    # Keep tracking previously identified dependencies if they are still open,
    # even if they no longer pass the subset check (e.g. they got new commits).
    for pr in parsed_open_deps:
        if pr in open_pr_numbers and pr not in open_deps:
            open_deps.append(pr)
    # Identify newly merged PRs: previously-open deps that are no longer
    # open and no longer detected as dependencies.
    newly_merged_deps = []
    for pr in parsed_open_deps:
        if pr not in open_deps and pr not in open_pr_numbers:
            newly_merged_deps.append(pr)
    merged_deps = list(set(parsed_merged_deps + newly_merged_deps))
    # Report the dependency state on the head commit as a commit status.
    if open_deps:
        state = "pending"
        pr_list_str = ", ".join([f"#{num}" for num in open_deps])
        description = f"This PR has open dependencies: {pr_list_str}"
    else:
        state = "success"
        description = "This PR has no open dependencies"
    _set_commit_status(
        pr_node["headRefOid"], state, description, token, dry_run
    )
    # Find the first commit that belongs to this PR alone, so reviewers
    # know where to start.
    first_independent_commit_oid = None
    if open_deps:
        dependent_oids = set()
        for d in open_deps:
            dependent_oids.update(pr_to_commits[d])
        # previous_first_commit already assigned from comment state; resume
        # the search there to keep the anchor stable across runs.
        if previous_first_commit and previous_first_commit in current_oids:
            start_idx = current_oids.index(previous_first_commit)
        else:
            start_idx = 0
        # Assumes `current_oids` is in chronological order (oldest first).
        # This guarantees we find the first independent commit to start the
        # review.
        for oid in current_oids[start_idx:]:
            if oid not in dependent_oids:
                first_independent_commit_oid = oid
                break
        # Non-linear history check: dependency commits appearing after the
        # first independent commit mean the "changes since" link also shows
        # dependent changes.
        any_later_dependent_oids = False
        if first_independent_commit_oid:
            idx = current_oids.index(first_independent_commit_oid)
            any_later_dependent_oids = any(
                oid in dependent_oids for oid in current_oids[idx + 1 :]
            )
    # Nothing changed since the last run: avoid churning the comment.
    if (
        open_deps == parsed_open_deps
        and merged_deps == parsed_merged_deps
        and first_independent_commit_oid == previous_first_commit
    ):
        return
    # Construct new comment
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime(
        "%Y-%m-%d %H:%M:%S UTC"
    )
    new_state: dict[str, Any] = {
        "open": open_deps,
        "merged": merged_deps,
        "first_commit": first_independent_commit_oid,
    }
    state_json = json.dumps(new_state)
    comment_body = f"{marker_prefix}{state_json} -->\n"
    if open_deps:
        pr_list_str = ", ".join([f"#{num}" for num in open_deps])
        if first_independent_commit_oid:
            changes_url = (
                "https://github.com/carbon-language/carbon-lang/pull/"
                f"{pr_number}/changes/{first_independent_commit_oid}..HEAD"
            )
            comment_body += (
                f"Depends on {pr_list_str}, start review with "
                f"[these changes]({changes_url})"
            )
            if any_later_dependent_oids:
                comment_body += (
                    "\n\n> [!WARNING]\n"
                    "> Also contains changes from dependent PRs due to "
                    "non-linear history."
                )
        else:
            comment_body += (
                f"Depends on {pr_list_str}, unable to identify starting review "
                f"commit from simple analysis"
            )
    else:
        comment_body += "All dependent PRs are merged."
    if merged_deps:
        merged_str = ", ".join([f"#{num}" for num in sorted(merged_deps)])
        comment_body += f"\n\nMerged dependent PRs: {merged_str}"
    comment_body += f"\n\n(Last updated: {timestamp})"
    _print_err(f"PR #{pr_number}: Updating comment. New body:\n{comment_body}")
    # Apply mutations
    has_dependent_label = any(label["name"] == "dependent" for label in labels)
    # Labels are only added outside scan mode; removal is always allowed so
    # stale labels get cleaned up.
    if open_deps and not has_dependent_label and not scanning:
        if dry_run:
            _print_err(
                f"[Dry-run] Would add 'dependent' label to PR #{pr_number}"
            )
        else:
            client.execute(
                _MUTATION_ADD_LABEL,
                variable_values={"labelableId": pr_id, "labelIds": [label_id]},
            )
    elif not open_deps and has_dependent_label:
        if dry_run:
            _print_err(
                f"[Dry-run] Would remove 'dependent' label from PR #{pr_number}"
            )
        else:
            client.execute(
                _MUTATION_REMOVE_LABEL,
                variable_values={"labelableId": pr_id, "labelIds": [label_id]},
            )
    # Update the existing marker comment in place, or create one (only
    # outside scan mode).
    if existing_comment_id:
        if dry_run:
            _print_err(f"[Dry-run] Would update comment {existing_comment_id}")
        else:
            client.execute(
                _MUTATION_UPDATE_COMMENT,
                variable_values={
                    "id": existing_comment_id,
                    "body": comment_body,
                },
            )
    else:
        if scanning:
            _print_err(
                f"PR #{pr_number}: Skipping new comment creation in scan mode."
            )
            return
        if dry_run:
            _print_err(f"[Dry-run] Would add comment to PR #{pr_number}")
        else:
            client.execute(
                _MUTATION_ADD_COMMENT,
                variable_values={"subjectId": pr_id, "body": comment_body},
            )
  467. def _parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
  468. """Parses command-line arguments."""
  469. parser = argparse.ArgumentParser(
  470. description=__doc__,
  471. formatter_class=argparse.RawDescriptionHelpFormatter,
  472. )
  473. group = parser.add_mutually_exclusive_group(required=True)
  474. group.add_argument(
  475. "--pr-number",
  476. type=int,
  477. help="The pull request number to check.",
  478. )
  479. group.add_argument(
  480. "--scan",
  481. action="store_true",
  482. help="Scan all open PRs with 'dependent' label and update them.",
  483. )
  484. parser.add_argument(
  485. "--dry-run",
  486. action="store_true",
  487. help="Print mutations without updating GitHub",
  488. )
  489. github_helpers.add_access_token_arg(parser, "repo")
  490. return parser.parse_args(args=args)
  491. def main() -> None:
  492. parsed_args = _parse_args()
  493. client = github_helpers.Client(parsed_args)
  494. _print_err("Loading open PRs ...", end="", flush=True)
  495. pr_to_commits: dict[int, list[str]] = {}
  496. open_pr_numbers: set[int] = set()
  497. for node in client.execute_and_paginate(
  498. _QUERY_OPEN_PRS, ("repository", "pullRequests")
  499. ):
  500. _print_err(".", end="", flush=True)
  501. other_pr_num = node["number"]
  502. open_pr_numbers.add(other_pr_num)
  503. pr_to_commits[other_pr_num] = [
  504. c["commit"]["oid"] for c in node["commits"]["nodes"]
  505. ]
  506. _print_err()
  507. label_res = client.execute(_QUERY_LABEL)
  508. label_id = label_res["repository"]["label"]["id"]
  509. merged_res = client.execute(_QUERY_MAX_MERGED_PR)
  510. merged_nodes = merged_res["repository"]["pullRequests"]["nodes"]
  511. max_merged_pr = merged_nodes[0]["number"] if merged_nodes else 0
  512. if parsed_args.pr_number:
  513. _process_pr(
  514. client,
  515. parsed_args.pr_number,
  516. pr_to_commits,
  517. open_pr_numbers,
  518. label_id,
  519. parsed_args.access_token,
  520. dry_run=parsed_args.dry_run,
  521. max_merged_pr=max_merged_pr,
  522. )
  523. elif parsed_args.scan:
  524. for node in client.execute_and_paginate(
  525. _QUERY_DEPENDENT_PRS, ("repository", "pullRequests")
  526. ):
  527. _process_pr(
  528. client,
  529. node["number"],
  530. pr_to_commits,
  531. open_pr_numbers,
  532. label_id,
  533. parsed_args.access_token,
  534. dry_run=parsed_args.dry_run,
  535. scanning=True,
  536. max_merged_pr=max_merged_pr,
  537. )
  538. if __name__ == "__main__":
  539. main()