From 47905d2d19f703094403392d360f1a9c9ac15168 Mon Sep 17 00:00:00 2001 From: Sebastian Rittau Date: Tue, 25 Jul 2023 17:40:11 +0200 Subject: [PATCH] stubsabot: Add stubsabot label to PRs (#10507) --- scripts/stubsabot.py | 68 +++++++++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/scripts/stubsabot.py b/scripts/stubsabot.py index 050abfd89..f4ff3dd20 100644 --- a/scripts/stubsabot.py +++ b/scripts/stubsabot.py @@ -18,6 +18,7 @@ import urllib.parse import zipfile from collections.abc import Iterator, Mapping, Sequence from dataclasses import dataclass +from http import HTTPStatus from pathlib import Path from typing import Annotated, Any, ClassVar, NamedTuple from typing_extensions import Self, TypeAlias @@ -29,6 +30,11 @@ import tomli import tomlkit from termcolor import colored +TYPESHED_OWNER = "python" +TYPESHED_API_URL = f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed" + +STUBSABOT_LABEL = "stubsabot" + class ActionLevel(enum.IntEnum): def __new__(cls, value: int, doc: str) -> Self: @@ -482,9 +488,6 @@ async def determine_action(stub_path: Path, session: aiohttp.ClientSession) -> U ) -TYPESHED_OWNER = "python" - - @functools.lru_cache() def get_origin_owner() -> str: output = subprocess.check_output(["git", "remote", "get-url", "origin"], text=True).strip() @@ -498,32 +501,51 @@ async def create_or_update_pull_request(*, title: str, body: str, branch_name: s fork_owner = get_origin_owner() async with session.post( - f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed/pulls", + f"{TYPESHED_API_URL}/pulls", json={"title": title, "body": body, "head": f"{fork_owner}:{branch_name}", "base": "main"}, headers=get_github_api_headers(), ) as response: resp_json = await response.json() - if response.status == 422 and any( + if response.status == HTTPStatus.CREATED: + pr_number = resp_json["number"] + assert isinstance(pr_number, int) + elif response.status == HTTPStatus.UNPROCESSABLE_ENTITY and any( "A pull request already exists" in e.get("message", "") for e in resp_json.get("errors", []) ): - # Find the existing PR - async with session.get( - f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed/pulls", - params={"state": "open", "head": f"{fork_owner}:{branch_name}", "base": "main"}, - headers=get_github_api_headers(), - ) as response: - response.raise_for_status() - resp_json = await response.json() - assert len(resp_json) >= 1 - pr_number = resp_json[0]["number"] - # Update the PR's title and body - async with session.patch( - f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed/pulls/{pr_number}", - json={"title": title, "body": body}, - headers=get_github_api_headers(), - ) as response: - response.raise_for_status() - return + pr_number = await update_existing_pull_request(title=title, body=body, branch_name=branch_name, session=session) + else: + response.raise_for_status() + raise AssertionError(f"Unexpected response: {response.status}") + await update_pull_request_label(pr_number=pr_number, session=session) + + +async def update_existing_pull_request(*, title: str, body: str, branch_name: str, session: aiohttp.ClientSession) -> int: + fork_owner = get_origin_owner() + + # Find the existing PR + async with session.get( + f"{TYPESHED_API_URL}/pulls", + params={"state": "open", "head": f"{fork_owner}:{branch_name}", "base": "main"}, + headers=get_github_api_headers(), + ) as response: + response.raise_for_status() + resp_json = await response.json() + assert len(resp_json) >= 1 + pr_number = resp_json[0]["number"] + assert isinstance(pr_number, int) + # Update the PR's title and body + async with session.patch( + f"{TYPESHED_API_URL}/pulls/{pr_number}", json={"title": title, "body": body}, headers=get_github_api_headers() + ) as response: + response.raise_for_status() + return pr_number + + +async def update_pull_request_label(*, pr_number: int, session: aiohttp.ClientSession) -> None: + # There is no pulls/.../labels endpoint, which is why we need to use the issues endpoint. + async with session.post( + f"{TYPESHED_API_URL}/issues/{pr_number}/labels", json={"labels": [STUBSABOT_LABEL]}, headers=get_github_api_headers() + ) as response: response.raise_for_status()