diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 5cab3a2c915..ada7ef776cd 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -48,12 +48,14 @@ jobs: - name: "Run the issue labeler" if: "github.event.issue || inputs.type == 'issue'" env: + event_json: "${{ toJSON(github.event) }}" GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: ./venv/bin/python hacking/pr_labeler/label.py issue ${{ github.event.issue.number || inputs.number }} - name: "Run the PR labeler" if: "github.event.pull_request || inputs.type == 'pr'" env: + event_json: "${{ toJSON(github.event) }}" GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: ./venv/bin/python hacking/pr_labeler/label.py pr ${{ github.event.number || inputs.number }} diff --git a/hacking/pr_labeler/label.py b/hacking/pr_labeler/label.py index 752d1d82347..8356717b0b3 100644 --- a/hacking/pr_labeler/label.py +++ b/hacking/pr_labeler/label.py @@ -4,11 +4,13 @@ from __future__ import annotations import dataclasses +import json import os from collections.abc import Collection +from contextlib import suppress from functools import cached_property from pathlib import Path -from typing import Union +from typing import Any, Union import github import github.Auth @@ -45,16 +47,30 @@ def get_repo(authed: bool = True) -> tuple[github.Github, github.Repository.Repo return gclient, repo +def get_event_info() -> dict[str, Any]: + event_json = os.environ.get("event_json") + if not event_json: + return {} + with suppress(json.JSONDecodeError): + return json.loads(event_json) + return {} + + @dataclasses.dataclass() class LabelerCtx: client: github.Github repo: github.Repository.Repository dry_run: bool + event_info: dict[str, Any] @property def member(self) -> IssueOrPr: raise NotImplementedError + @property + def event_member(self) -> dict[str, Any]: + raise NotImplementedError + @cached_property def previously_labeled(self) -> frozenset[str]: labels: set[str] = set() @@ -78,6 +94,10 @@ class IssueLabelerCtx(LabelerCtx): def member(self) -> IssueOrPr: return self.issue + @property + def event_member(self) -> dict[str, Any]: + return self.event_info.get("issue", {}) + @dataclasses.dataclass() class PRLabelerCtx(LabelerCtx): @@ -87,6 +107,10 @@ class PRLabelerCtx(LabelerCtx): def member(self) -> IssueOrPr: return self.pr + @property + def event_member(self) -> dict[str, Any]: + return self.event_info.get("pull_request", {}) + def create_comment(ctx: IssueOrPrCtx, body: str) -> None: if ctx.dry_run: @@ -167,7 +191,13 @@ def process_pr( authed = True gclient, repo = get_repo(authed=authed) pr = repo.get_pull(pr_number) - ctx = PRLabelerCtx(client=gclient, repo=repo, pr=pr, dry_run=dry_run) + ctx = PRLabelerCtx( + client=gclient, + repo=repo, + pr=pr, + dry_run=dry_run, + event_info=get_event_info(), + ) if pr.state != "open": log(ctx, "Refusing to process closed ticket") return @@ -187,7 +217,13 @@ def process_issue( authed = True gclient, repo = get_repo(authed=authed) issue = repo.get_issue(issue_number) - ctx = IssueLabelerCtx(client=gclient, repo=repo, issue=issue, dry_run=dry_run) + ctx = IssueLabelerCtx( + client=gclient, + repo=repo, + issue=issue, + dry_run=dry_run, + event_info=get_event_info(), + ) if issue.state != "open": log(ctx, "Refusing to process closed ticket") return