Skip to content

Commit

Permalink
labeler: pass Github Action event data to script
Browse files Browse the repository at this point in the history
This way, the script can take action based on specific events when that
data is available.
  • Loading branch information
gotmax23 committed Jul 31, 2023
1 parent 47a86ba commit 1f1252d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ jobs:
- name: "Run the issue labeler"
if: "github.event.issue || inputs.type == 'issue'"
env:
event_json: "${{ toJSON(github.event) }}"
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run:
./venv/bin/python hacking/pr_labeler/label.py issue ${{ github.event.issue.number || inputs.number }}
- name: "Run the PR labeler"
if: "github.event.pull_request || inputs.type == 'pr'"
env:
event_json: "${{ toJSON(github.event) }}"
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run:
./venv/bin/python hacking/pr_labeler/label.py pr ${{ github.event.number || inputs.number }}
42 changes: 39 additions & 3 deletions hacking/pr_labeler/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from __future__ import annotations

import dataclasses
import json
import os
from collections.abc import Collection
from contextlib import suppress
from functools import cached_property
from pathlib import Path
from typing import Union
from typing import Any, Union

import github
import github.Auth
Expand Down Expand Up @@ -45,16 +47,30 @@ def get_repo(authed: bool = True) -> tuple[github.Github, github.Repository.Repo
return gclient, repo


def get_event_info() -> dict[str, Any]:
event_json = os.environ.get("event_json")
if not event_json:
return {}
with suppress(json.JSONDecodeError):
return json.loads(event_json)
return {}


@dataclasses.dataclass()
class LabelerCtx:
client: github.Github
repo: github.Repository.Repository
dry_run: bool
event_info: dict[str, Any]

@property
def member(self) -> IssueOrPr:
raise NotImplementedError

@property
def event_member(self) -> dict[str, Any]:
raise NotImplementedError

@cached_property
def previously_labeled(self) -> frozenset[str]:
labels: set[str] = set()
Expand All @@ -78,6 +94,10 @@ class IssueLabelerCtx(LabelerCtx):
def member(self) -> IssueOrPr:
return self.issue

@property
def event_member(self) -> dict[str, Any]:
return self.event_info.get("issue", {})


@dataclasses.dataclass()
class PRLabelerCtx(LabelerCtx):
Expand All @@ -87,6 +107,10 @@ class PRLabelerCtx(LabelerCtx):
def member(self) -> IssueOrPr:
return self.pr

@property
def event_member(self) -> dict[str, Any]:
return self.event_info.get("pull_request", {})


def create_comment(ctx: IssueOrPrCtx, body: str) -> None:
if ctx.dry_run:
Expand Down Expand Up @@ -167,7 +191,13 @@ def process_pr(
authed = True
gclient, repo = get_repo(authed=authed)
pr = repo.get_pull(pr_number)
ctx = PRLabelerCtx(client=gclient, repo=repo, pr=pr, dry_run=dry_run)
ctx = PRLabelerCtx(
client=gclient,
repo=repo,
pr=pr,
dry_run=dry_run,
event_info=get_event_info(),
)
if pr.state != "open":
log(ctx, "Refusing to process closed ticket")
return
Expand All @@ -187,7 +217,13 @@ def process_issue(
authed = True
gclient, repo = get_repo(authed=authed)
issue = repo.get_issue(issue_number)
ctx = IssueLabelerCtx(client=gclient, repo=repo, issue=issue, dry_run=dry_run)
ctx = IssueLabelerCtx(
client=gclient,
repo=repo,
issue=issue,
dry_run=dry_run,
event_info=get_event_info(),
)
if issue.state != "open":
log(ctx, "Refusing to process closed ticket")
return
Expand Down

0 comments on commit 1f1252d

Please sign in to comment.