Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --label argument to label new PRs #65

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ghstack/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def main() -> None:
subparser.add_argument(
'--draft', action='store_true',
help='Create the pull request in draft mode (only if it has not already been created)')
subparser.add_argument(
'--label', action='append', default=[],
help='Add this label to all pull requests in the stack '
'(multiple --label arguments can be given)')

unlink = subparsers.add_parser('unlink')
unlink.add_argument('COMMITS', nargs='*')
Expand Down Expand Up @@ -108,6 +112,7 @@ def main() -> None:
force=args.force,
no_skip=args.no_skip,
draft=args.draft,
labels=args.label,
github_url=conf.github_url,
remote_name=conf.remote_name,
)
Expand Down
64 changes: 63 additions & 1 deletion ghstack/github_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import os.path
import re
from dataclasses import dataclass # Oof! Python 3.7 only!!
from itertools import islice
from typing import Any, Dict, List, NewType, Optional, Sequence, cast

import graphql
from sortedcontainers import SortedKeyList # type: ignore[import]
from typing_extensions import TypedDict

import ghstack.github
Expand Down Expand Up @@ -35,6 +37,10 @@
'maintainer_can_modify': bool,
})

AddLabelsInput = TypedDict('AddLabelsInput', {
'labels': List[str],
})

CreatePullRequestPayload = TypedDict('CreatePullRequestPayload', {
'number': int,
})
Expand Down Expand Up @@ -199,14 +205,31 @@ def repository(self, info: GraphQLResolveInfo) -> Repository:


@dataclass
class PullRequest(Node):
class Label(Node):
name: str


@dataclass
class PageInfo:
endCursor: Optional[str]


@dataclass
class LabelConnection:
nodes: Optional[List[Optional[Label]]]
pageInfo: PageInfo


@dataclass
class PullRequest(Node): # type: ignore[no-any-unimported]
baseRef: Optional[Ref]
baseRefName: str
body: str
closed: bool
headRef: Optional[Ref]
headRefName: str
# headRepository: Optional[Repository]
_labels: SortedKeyList # type: ignore[no-any-unimported]
# maintainerCanModify: bool
number: GitHubNumber
_repository: GraphQLId # cycle breaker
Expand All @@ -217,6 +240,25 @@ class PullRequest(Node):
def repository(self, info: GraphQLResolveInfo) -> Repository:
return github_state(info).repositories[self._repository]

def labels(self, info: GraphQLResolveInfo,
after: Optional[str] = None, before: Optional[str] = None,
first: Optional[int] = None, last: Optional[int] = None,
) -> Optional[LabelConnection]:
if first is None:
# the real API also supports `last`, but we do not
raise RuntimeError(
"You must provide a `first` value"
"to properly paginate the `labels` connection."
)
nodes = list(islice(
self._labels.irange_key(after, inclusive=(False, True)),
first,
))
return LabelConnection(
nodes=nodes,
pageInfo=PageInfo(endCursor=nodes[-1].name if nodes else None),
)


@dataclass
class PullRequestConnection:
Expand Down Expand Up @@ -311,6 +353,7 @@ def _create_pull(self, owner: str, name: str,
baseRefName=input['base'],
headRef=headRef,
headRefName=input['head'],
_labels=SortedKeyList(key=lambda label: label.name),
title=input['title'],
body=input['body'],
)
Expand Down Expand Up @@ -347,12 +390,31 @@ def _set_default_branch(self, owner: str, name: str,
repo = state.repository(owner, name)
repo.defaultBranchRef = repo._make_ref(state, input['default_branch'])

# NB: This technically does have a payload, but we don't
# use it so I didn't bother constructing it.
def _add_labels(self, owner: str, name: str, number: GitHubNumber,
input: AddLabelsInput) -> None:
state = self.state
repo = state.repository(owner, name)
pr = state.pull_request(repo, number)
labels = pr._labels
for name in input['labels']:
# https://stackoverflow.com/a/3114640
if not any(True for _ in labels.irange_key(name, name)):
labels.add(Label(id=state.next_id(), name=name))

def rest(self, method: str, path: str, **kwargs: Any) -> Any:
if method == 'post':
m = re.match(r'^repos/([^/]+)/([^/]+)/pulls$', path)
if m:
return self._create_pull(m.group(1), m.group(2),
cast(CreatePullRequestInput, kwargs))
m = re.match(r'^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels$', path)
if m:
owner, name, number = m.groups()
return self._add_labels(
owner, name, GitHubNumber(int(number)),
cast(AddLabelsInput, kwargs))
elif method == 'patch':
m = re.match(r'^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$', path)
if m:
Expand Down
12 changes: 12 additions & 0 deletions ghstack/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def main(*,
force: bool = False,
no_skip: bool = False,
draft: bool = False,
labels: List[str],
github_url: str,
remote_name: str
) -> List[Optional[DiffMeta]]:
Expand Down Expand Up @@ -185,6 +186,7 @@ def main(*,
force=force,
no_skip=no_skip,
draft=draft,
labels=labels,
stack=list(reversed(stack)),
github_url=github_url,
remote_name=remote_name)
Expand Down Expand Up @@ -296,6 +298,9 @@ class Submitter(object):
# Create the PR in draft mode if it is going to be created (and not updated).
draft: bool

# Add these labels to all PRs in the stack
labels: List[str]

# Github url (normally github.com)
github_url: str

Expand All @@ -321,6 +326,7 @@ def __init__(
force: bool,
no_skip: bool,
draft: bool,
labels: List[str],
github_url: str,
remote_name: str):
self.github = github
Expand All @@ -344,6 +350,7 @@ def __init__(
self.force = force
self.no_skip = no_skip
self.draft = draft
self.labels = labels
self.github_url = github_url
self.remote_name = remote_name

Expand Down Expand Up @@ -940,6 +947,11 @@ def push_updates(self, *, import_help: bool = True) -> None: # noqa: C901
number=s.number),
body=RE_STACK.sub(self._format_stack(i), s.body),
title=s.title)
if len(self.labels) > 0:
self.github.post(
f"repos/{self.repo_owner}/{self.repo_name}/issues/{s.number}/labels",
labels=self.labels,
)
else:
logging.info(
"# Skipping closed https://{github_url}/{owner}/{repo}/pull/{number}"
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ hypothesis = "^6"
isort = "^5"
mypy = "^0.800"
pytest = "^6"
sortedcontainers = "^2"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
51 changes: 50 additions & 1 deletion test_ghstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def substituteRev(self, rev: str, substitute: str) -> None:
def gh(self, msg: str = 'Update',
update_fields: bool = False,
short: bool = False,
no_skip: bool = False) -> List[Optional[ghstack.submit.DiffMeta]]:
no_skip: bool = False,
labels: Optional[List[str]] = None
) -> List[Optional[ghstack.submit.DiffMeta]]:
return ghstack.submit.main(
msg=msg,
username='ezyang',
Expand All @@ -109,6 +111,7 @@ def gh(self, msg: str = 'Update',
repo_name='pytorch',
short=short,
no_skip=no_skip,
labels=labels or [],
github_url='github.com',
remote_name='origin')

Expand Down Expand Up @@ -2005,6 +2008,52 @@ def test_default_branch_change(self) -> None:
rUP1 Commit 1
rINI0 Initial commit''')

# ------------------------------------------------------------------------- #

def test_labels(self) -> None:
def assert_labels(pr: int, expected: List[str]) -> None:
raw = self.github.graphql("""
query ($pr: Int!, $first: Int!) {
repository(owner: "pytorch", name: "pytorch") {
pullRequest(number: $pr) {
labels(first: $first) {
nodes {
name
}
}
}
}
}
""", pr=pr, first=len(expected) + 1)
nodes = raw['data']['repository']['pullRequest']['labels']['nodes']
actual = [label['name'] for label in nodes]
self.assertEqual(actual, expected)

self.writeFileAndAdd('file1.txt', 'A')
self.sh.git('commit', '-m', 'Commit 1')
self.sh.test_tick()
self.gh()

assert_labels(500, [])

self.writeFileAndAdd('file2.txt', 'B')
self.sh.git('commit', '-m', 'Commit 2')
self.sh.test_tick()
self.gh(labels=['foo', 'bar'])

# alphabetical order
assert_labels(500, ['bar', 'foo'])
assert_labels(501, ['bar', 'foo'])

self.writeFileAndAdd('file3.txt', 'C')
self.sh.git('commit', '-m', 'Commit 3')
self.sh.test_tick()
self.gh(labels=['foo', 'baz'])

assert_labels(500, ['bar', 'baz', 'foo'])
assert_labels(501, ['bar', 'baz', 'foo'])
assert_labels(502, ['baz', 'foo'])


if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
Expand Down