Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --label argument to label new PRs #65

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ghstack/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def main() -> None:
subparser.add_argument(
'--draft', action='store_true',
help='Create the pull request in draft mode (only if it has not already been created)')
subparser.add_argument(
'--label', action='append', default=[],
help='Add this label to any newly created pull requests '
'(multiple --label arguments can be given)')

unlink = subparsers.add_parser('unlink')
unlink.add_argument('COMMITS', nargs='*')
Expand Down Expand Up @@ -108,6 +112,7 @@ def main() -> None:
force=args.force,
no_skip=args.no_skip,
draft=args.draft,
labels=args.label,
github_url=conf.github_url,
remote_name=conf.remote_name,
)
Expand Down
47 changes: 47 additions & 0 deletions ghstack/github_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,29 @@
'maintainer_can_modify': bool,
})

AddLabelsInput = TypedDict('AddLabelsInput', {
'labels': List[str],
})

CreatePullRequestPayload = TypedDict('CreatePullRequestPayload', {
'number': int,
})

# omitting many of these fields because we don't use them
Label = TypedDict('Label', {
# 'id': int,
# 'node_id': str,
# 'url': str,
'name': str,
# 'description': str,
# 'color': str,
# 'default': bool,
})

AddLabelsPayload = List[Label]

ListLabelsPayload = List[Label]


# The "database" for our mock instance
class GitHubState:
Expand Down Expand Up @@ -213,6 +232,7 @@ class PullRequest(Node):
# state: PullRequestState
title: str
url: str
labels: List[Label]

def repository(self, info: GraphQLResolveInfo) -> Repository:
return github_state(info).repositories[self._repository]
Expand Down Expand Up @@ -313,6 +333,7 @@ def _create_pull(self, owner: str, name: str,
headRefName=input['head'],
title=input['title'],
body=input['body'],
labels=[],
)
# TODO: compute files changed
state.pull_requests[id] = pr
Expand Down Expand Up @@ -347,12 +368,33 @@ def _set_default_branch(self, owner: str, name: str,
repo = state.repository(owner, name)
repo.defaultBranchRef = repo._make_ref(state, input['default_branch'])

def _add_labels(self, owner: str, name: str, number: GitHubNumber,
input: AddLabelsInput) -> AddLabelsPayload:
state = self.state
repo = state.repository(owner, name)
pr = state.pull_request(repo, number)
pr.labels += [{'name': label} for label in input['labels']]
samestep marked this conversation as resolved.
Show resolved Hide resolved
return pr.labels

def _list_labels(self, owner: str, name: str, number: GitHubNumber) -> ListLabelsPayload:
state = self.state
repo = state.repository(owner, name)
pr = state.pull_request(repo, number)
return pr.labels

def rest(self, method: str, path: str, **kwargs: Any) -> Any:
labels_re = r'^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels$'
if method == 'post':
m = re.match(r'^repos/([^/]+)/([^/]+)/pulls$', path)
if m:
return self._create_pull(m.group(1), m.group(2),
cast(CreatePullRequestInput, kwargs))
m = re.match(labels_re, path)
if m:
owner, name, number = m.groups()
return self._add_labels(
owner, name, GitHubNumber(int(number)),
cast(AddLabelsInput, kwargs))
elif method == 'patch':
m = re.match(r'^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$', path)
if m:
Expand All @@ -365,6 +407,11 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
return self._set_default_branch(
owner, name,
cast(SetDefaultBranchInput, kwargs))
elif method == 'get':
m = re.match(labels_re, path)
if m:
owner, name, number = m.groups()
return self._list_labels(owner, name, GitHubNumber(int(number)))
raise NotImplementedError(
"FakeGitHubEndpoint REST {} {} not implemented"
.format(method.upper(), path)
Expand Down
12 changes: 12 additions & 0 deletions ghstack/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def main(*,
force: bool = False,
no_skip: bool = False,
draft: bool = False,
labels: List[str],
github_url: str,
remote_name: str
) -> List[Optional[DiffMeta]]:
Expand Down Expand Up @@ -185,6 +186,7 @@ def main(*,
force=force,
no_skip=no_skip,
draft=draft,
labels=labels,
stack=list(reversed(stack)),
github_url=github_url,
remote_name=remote_name)
Expand Down Expand Up @@ -296,6 +298,9 @@ class Submitter(object):
# Create the PR in draft mode if it is going to be created (and not updated).
draft: bool

# Add these labels to newly created PRs
labels: List[str]

# Github url (normally github.com)
github_url: str

Expand All @@ -321,6 +326,7 @@ def __init__(
force: bool,
no_skip: bool,
draft: bool,
labels: List[str],
github_url: str,
remote_name: str):
self.github = github
Expand All @@ -344,6 +350,7 @@ def __init__(
self.force = force
self.no_skip = no_skip
self.draft = draft
self.labels = labels
self.github_url = github_url
self.remote_name = remote_name

Expand Down Expand Up @@ -608,6 +615,11 @@ def process_new_commit(self, commit: ghstack.diff.Diff) -> None:
draft=self.draft,
)
number = r['number']
if len(self.labels) > 0:
self.github.post(
f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/labels",
labels=self.labels,
)

logging.info("Opened PR #{}".format(number))

Expand Down
34 changes: 33 additions & 1 deletion test_ghstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def substituteRev(self, rev: str, substitute: str) -> None:
def gh(self, msg: str = 'Update',
update_fields: bool = False,
short: bool = False,
no_skip: bool = False) -> List[Optional[ghstack.submit.DiffMeta]]:
no_skip: bool = False,
labels: List[str] = []) -> List[Optional[ghstack.submit.DiffMeta]]:
samestep marked this conversation as resolved.
Show resolved Hide resolved
return ghstack.submit.main(
msg=msg,
username='ezyang',
Expand All @@ -109,6 +110,7 @@ def gh(self, msg: str = 'Update',
repo_name='pytorch',
short=short,
no_skip=no_skip,
labels=labels,
github_url='github.com',
remote_name='origin')

Expand Down Expand Up @@ -2005,6 +2007,36 @@ def test_default_branch_change(self) -> None:
rUP1 Commit 1
rINI0 Initial commit''')

# ------------------------------------------------------------------------- #

def test_labels(self) -> None:
# first commit
self.writeFileAndAdd('file1.txt', 'A')
self.sh.git('commit', '-m', 'Commit 1')
self.sh.test_tick()
# ghstack
self.gh()
# second commit
self.writeFileAndAdd('file2.txt', 'B')
self.sh.git('commit', '-m', 'Commit 2')
self.sh.test_tick()
# third commit
self.writeFileAndAdd('file3.txt', 'C')
self.sh.git('commit', '-m', 'Commit 3')
self.sh.test_tick()
# ghstack with labels
self.gh(labels=['foo', 'bar'])

def get_labels(n: int) -> List[str]:
labels = self.github.get(f'repos/pytorch/pytorch/issues/{n}/labels')
return [label['name'] for label in labels]

# was already created before second ghstack run
self.assertEqual(get_labels(500), [])
samestep marked this conversation as resolved.
Show resolved Hide resolved
# included in the second ghstack run
self.assertEqual(get_labels(501), ['foo', 'bar'])
self.assertEqual(get_labels(502), ['foo', 'bar'])


if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
Expand Down