Skip to content

Commit

Permalink
Fetch PR info from GH before trying to open it (#38)
Browse files Browse the repository at this point in the history
This allows users to open PR's outside of `av` and then let av know about
the state of that PR. This is mostly useful in scenarios where users open
the first (root) PR using the GH UI, then realize that they want to stack
their work on top if it. Before this commit, that wasn't possible (without
closing the original PR and then re-creating it with `av pr create`).

When we implement `av stack submit`, this should also happen automatically
(for every PR in the stack), but that's a bridge we haven't quite reached.
  • Loading branch information
twavv authored Jul 7, 2022
1 parent b91d8ab commit e521018
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 55 deletions.
80 changes: 54 additions & 26 deletions internal/actions/pr.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,7 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o
)
}

// TODO:
// It would be nice to be able to auto-detect that a PR has been
// opened for a given PR without using av. We might need to do this
// when creating PRs for a whole stack (e.g., when running `av pr`
// on stack branch 3, we should make sure PRs exist for 1 and 2).
branchMeta, ok := meta.ReadBranch(repo, currentBranch)
if ok && branchMeta.PullRequest != nil && !opts.Force {
_, _ = fmt.Fprint(os.Stderr,
" - ", color.RedString("ERROR: "),
"branch ", colors.UserInput(currentBranch),
" already has an associated pull request: ",
colors.UserInput(branchMeta.PullRequest.Permalink),
"\n",
)
return nil, errors.New("this branch already has an associated pull request")
}

// figure this out based on whether or not we're on a stacked branch
var prBaseBranch string
Expand Down Expand Up @@ -161,17 +146,13 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o
opts.Body = firstCommit.Body
}

pull, err := client.CreatePullRequest(ctx, githubv4.CreatePullRequestInput{
RepositoryID: githubv4.ID(repoMeta.ID),
BaseRefName: githubv4.String(prBaseBranch),
HeadRefName: githubv4.String(currentBranch),
Title: githubv4.String(opts.Title),
Body: gh.Ptr(githubv4.String(opts.Body)),
Draft: gh.Ptr(githubv4.Boolean(opts.Draft)),
pull, didCreatePR, err := getOrCreatePR(ctx, client, repoMeta, getOrCreatePROpts{
baseRefName: prBaseBranch,
headRefName: currentBranch,
title: opts.Title,
body: opts.Body,
draft: opts.Draft,
})
if err != nil {
return nil, err
}

branchMeta.PullRequest = &meta.PullRequest{
Number: pull.Number,
Expand All @@ -193,8 +174,14 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o
return nil, errors.WrapIf(err, "adding avbeta-stackedprs label")
}

var action string
if didCreatePR {
action = "created"
} else {
action = "fetched existing"
}
_, _ = fmt.Fprint(os.Stderr,
" - created pull request for branch ", colors.UserInput(currentBranch),
" - ", action, " pull request for branch ", colors.UserInput(currentBranch),
" (into branch ", colors.UserInput(prBaseBranch), "): ",
colors.UserInput(pull.Permalink),
"\n",
Expand All @@ -213,3 +200,44 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o

return pull, nil
}

type getOrCreatePROpts struct {
baseRefName string
headRefName string
title string
body string
draft bool
}

// getOrCreatePR returns the pull request for the given input, creating a new
// pull request if one doesn't exist. It returns the pull request, a boolean
// indicating whether or not the pull request was created, and an error if one
// occurred.
func getOrCreatePR(ctx context.Context, client *gh.Client, repoMeta meta.Repository, opts getOrCreatePROpts) (*gh.PullRequest, bool, error) {
existing, err := client.GetPullRequests(ctx, gh.GetPullRequestsInput{
Owner: repoMeta.Owner,
Repo: repoMeta.Name,
HeadRefName: opts.headRefName,
BaseRefName: opts.baseRefName,
States: []githubv4.PullRequestState{githubv4.PullRequestStateOpen},
})
if err != nil {
return nil, false, errors.WrapIf(err, "querying existing pull requests")
}
if len(existing.PullRequests) > 0 {
return &existing.PullRequests[0], false, nil
}

pull, err := client.CreatePullRequest(ctx, githubv4.CreatePullRequestInput{
RepositoryID: githubv4.ID(repoMeta.ID),
BaseRefName: githubv4.String(opts.baseRefName),
HeadRefName: githubv4.String(opts.headRefName),
Title: githubv4.String(opts.title),
Body: gh.Ptr(githubv4.String(opts.body)),
Draft: gh.Ptr(githubv4.Boolean(opts.draft)),
})
if err != nil {
return nil, false, errors.WrapIf(err, "opening pull request")
}
return pull, true, nil
}
22 changes: 0 additions & 22 deletions internal/gh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,25 +134,3 @@ func (c *Client) restPost(ctx context.Context, endpoint string, body interface{}
}
return nil
}

// Ptr returns a pointer to the argument.
// It's a convenience function to make working with the API easier: since Go
// disallows pointers-to-literals, and optional input fields are expressed as
// pointers, this function can be used to easily set optional fields to non-nil
// primitives.
// For example, githubv4.CreatePullRequestInput{Draft: Ptr(true)}
func Ptr[T any](v T) *T {
return &v
}

// nullable returns a pointer to the argument if it's not the zero value,
// otherwise it returns nil.
// This is useful to translate between Golang-style "unset is zero" and GraphQL
// which distinguishes between unset (null) and zero values.
func nullable[T comparable](v T) *T {
var zero T
if v == zero {
return nil
}
return &v
}
32 changes: 32 additions & 0 deletions internal/gh/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package gh

// PageInfo contains information about the current/previous/next page of results
// when using paginated APIs.
type PageInfo struct {
EndCursor string
HasNextPage bool
HasPreviousPage bool
StartCursor string
}

// Ptr returns a pointer to the argument.
// It's a convenience function to make working with the API easier: since Go
// disallows pointers-to-literals, and optional input fields are expressed as
// pointers, this function can be used to easily set optional fields to non-nil
// primitives.
// For example, githubv4.CreatePullRequestInput{Draft: Ptr(true)}
func Ptr[T any](v T) *T {
return &v
}

// nullable returns a pointer to the argument if it's not the zero value,
// otherwise it returns nil.
// This is useful to translate between Golang-style "unset is zero" and GraphQL
// which distinguishes between unset (null) and zero values.
func nullable[T comparable](v T) *T {
var zero T
if v == zero {
return nil
}
return &v
}
53 changes: 46 additions & 7 deletions internal/gh/pullrequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ func (c *Client) PullRequest(ctx context.Context, opts PullRequestOpts) (*PullRe
return &query.Repository.PullRequest, nil
}

type GetPullRequestsInput struct {
// REQUIRED
Owner string
Repo string
// OPTIONAL
HeadRefName string
BaseRefName string
States []githubv4.PullRequestState
First int64
After string
}

type GetPullRequestsPage struct {
PageInfo
PullRequests []PullRequest
}

func (c *Client) GetPullRequests(ctx context.Context, input GetPullRequestsInput) (*GetPullRequestsPage, error) {
if input.First == 0 {
input.First = 50
}
var query struct {
Repository struct {
PullRequests struct {
Nodes []PullRequest
PageInfo PageInfo
} `graphql:"pullRequests(states: $states, headRefName: $headRefName, baseRefName: $baseRefName, first: $first, after: $after)"`
} `graphql:"repository(owner: $owner, name: $repo)"`
}
if err := c.query(ctx, &query, map[string]interface{}{
"owner": githubv4.String(input.Owner),
"repo": githubv4.String(input.Repo),
"headRefName": nullable(githubv4.String(input.HeadRefName)),
"baseRefName": nullable(githubv4.String(input.BaseRefName)),
"states": input.States,
"first": githubv4.Int(input.First),
"after": nullable(githubv4.String(input.After)),
}); err != nil {
return nil, errors.Wrap(err, "failed to query pull requests")
}
return &GetPullRequestsPage{
PageInfo: query.Repository.PullRequests.PageInfo,
PullRequests: query.Repository.PullRequests.Nodes,
}, nil
}

func (c *Client) CreatePullRequest(ctx context.Context, input githubv4.CreatePullRequestInput) (*PullRequest, error) {
var mutation struct {
CreatePullRequest struct {
Expand Down Expand Up @@ -110,13 +156,6 @@ type RepoPullRequestOpts struct {
States []githubv4.PullRequestState
}

type PageInfo struct {
EndCursor string
HasNextPage bool
HasPreviousPage bool
StartCursor string
}

type RepoPullRequestsResponse struct {
PageInfo
TotalCount int64
Expand Down

0 comments on commit e521018

Please sign in to comment.