diff --git a/internal/actions/pr.go b/internal/actions/pr.go index 288a62de..000ca984 100644 --- a/internal/actions/pr.go +++ b/internal/actions/pr.go @@ -81,22 +81,7 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o ) } - // TODO: - // It would be nice to be able to auto-detect that a PR has been - // opened for a given PR without using av. We might need to do this - // when creating PRs for a whole stack (e.g., when running `av pr` - // on stack branch 3, we should make sure PRs exist for 1 and 2). branchMeta, ok := meta.ReadBranch(repo, currentBranch) - if ok && branchMeta.PullRequest != nil && !opts.Force { - _, _ = fmt.Fprint(os.Stderr, - " - ", color.RedString("ERROR: "), - "branch ", colors.UserInput(currentBranch), - " already has an associated pull request: ", - colors.UserInput(branchMeta.PullRequest.Permalink), - "\n", - ) - return nil, errors.New("this branch already has an associated pull request") - } // figure this out based on whether or not we're on a stacked branch var prBaseBranch string @@ -161,17 +146,13 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o opts.Body = firstCommit.Body } - pull, err := client.CreatePullRequest(ctx, githubv4.CreatePullRequestInput{ - RepositoryID: githubv4.ID(repoMeta.ID), - BaseRefName: githubv4.String(prBaseBranch), - HeadRefName: githubv4.String(currentBranch), - Title: githubv4.String(opts.Title), - Body: gh.Ptr(githubv4.String(opts.Body)), - Draft: gh.Ptr(githubv4.Boolean(opts.Draft)), + pull, didCreatePR, err := getOrCreatePR(ctx, client, repoMeta, getOrCreatePROpts{ + baseRefName: prBaseBranch, + headRefName: currentBranch, + title: opts.Title, + body: opts.Body, + draft: opts.Draft, }) - if err != nil { - return nil, err - } branchMeta.PullRequest = &meta.PullRequest{ Number: pull.Number, @@ -193,8 +174,14 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o return nil, errors.WrapIf(err, "adding avbeta-stackedprs label") } + var action string + if didCreatePR { + action = "created" + } else { + action = "fetched existing" + } _, _ = fmt.Fprint(os.Stderr, - " - created pull request for branch ", colors.UserInput(currentBranch), + " - ", action, " pull request for branch ", colors.UserInput(currentBranch), " (into branch ", colors.UserInput(prBaseBranch), "): ", colors.UserInput(pull.Permalink), "\n", @@ -213,3 +200,44 @@ func CreatePullRequest(ctx context.Context, repo *git.Repo, client *gh.Client, o return pull, nil } + +type getOrCreatePROpts struct { + baseRefName string + headRefName string + title string + body string + draft bool +} + +// getOrCreatePR returns the pull request for the given input, creating a new +// pull request if one doesn't exist. It returns the pull request, a boolean +// indicating whether or not the pull request was created, and an error if one +// occurred. +func getOrCreatePR(ctx context.Context, client *gh.Client, repoMeta meta.Repository, opts getOrCreatePROpts) (*gh.PullRequest, bool, error) { + existing, err := client.GetPullRequests(ctx, gh.GetPullRequestsInput{ + Owner: repoMeta.Owner, + Repo: repoMeta.Name, + HeadRefName: opts.headRefName, + BaseRefName: opts.baseRefName, + States: []githubv4.PullRequestState{githubv4.PullRequestStateOpen}, + }) + if err != nil { + return nil, false, errors.WrapIf(err, "querying existing pull requests") + } + if len(existing.PullRequests) > 0 { + return &existing.PullRequests[0], false, nil + } + + pull, err := client.CreatePullRequest(ctx, githubv4.CreatePullRequestInput{ + RepositoryID: githubv4.ID(repoMeta.ID), + BaseRefName: githubv4.String(opts.baseRefName), + HeadRefName: githubv4.String(opts.headRefName), + Title: githubv4.String(opts.title), + Body: gh.Ptr(githubv4.String(opts.body)), + Draft: gh.Ptr(githubv4.Boolean(opts.draft)), + }) + if err != nil { + return nil, false, errors.WrapIf(err, "opening pull request") + } + return pull, true, nil +} diff --git a/internal/gh/client.go b/internal/gh/client.go index e4858c95..166b905e 100644 --- a/internal/gh/client.go +++ b/internal/gh/client.go @@ -134,25 +134,3 @@ func (c *Client) restPost(ctx context.Context, endpoint string, body interface{} } return nil } - -// Ptr returns a pointer to the argument. -// It's a convenience function to make working with the API easier: since Go -// disallows pointers-to-literals, and optional input fields are expressed as -// pointers, this function can be used to easily set optional fields to non-nil -// primitives. -// For example, githubv4.CreatePullRequestInput{Draft: Ptr(true)} -func Ptr[T any](v T) *T { - return &v -} - -// nullable returns a pointer to the argument if it's not the zero value, -// otherwise it returns nil. -// This is useful to translate between Golang-style "unset is zero" and GraphQL -// which distinguishes between unset (null) and zero values. -func nullable[T comparable](v T) *T { - var zero T - if v == zero { - return nil - } - return &v -} diff --git a/internal/gh/common.go b/internal/gh/common.go new file mode 100644 index 00000000..f88d88cd --- /dev/null +++ b/internal/gh/common.go @@ -0,0 +1,32 @@ +package gh + +// PageInfo contains information about the current/previous/next page of results +// when using paginated APIs. +type PageInfo struct { + EndCursor string + HasNextPage bool + HasPreviousPage bool + StartCursor string +} + +// Ptr returns a pointer to the argument. +// It's a convenience function to make working with the API easier: since Go +// disallows pointers-to-literals, and optional input fields are expressed as +// pointers, this function can be used to easily set optional fields to non-nil +// primitives. +// For example, githubv4.CreatePullRequestInput{Draft: Ptr(true)} +func Ptr[T any](v T) *T { + return &v +} + +// nullable returns a pointer to the argument if it's not the zero value, +// otherwise it returns nil. +// This is useful to translate between Golang-style "unset is zero" and GraphQL +// which distinguishes between unset (null) and zero values. +func nullable[T comparable](v T) *T { + var zero T + if v == zero { + return nil + } + return &v +} diff --git a/internal/gh/pullrequest.go b/internal/gh/pullrequest.go index fb26ce47..05bdc1e3 100644 --- a/internal/gh/pullrequest.go +++ b/internal/gh/pullrequest.go @@ -59,6 +59,52 @@ func (c *Client) PullRequest(ctx context.Context, opts PullRequestOpts) (*PullRe return &query.Repository.PullRequest, nil } +type GetPullRequestsInput struct { + // REQUIRED + Owner string + Repo string + // OPTIONAL + HeadRefName string + BaseRefName string + States []githubv4.PullRequestState + First int64 + After string +} + +type GetPullRequestsPage struct { + PageInfo + PullRequests []PullRequest +} + +func (c *Client) GetPullRequests(ctx context.Context, input GetPullRequestsInput) (*GetPullRequestsPage, error) { + if input.First == 0 { + input.First = 50 + } + var query struct { + Repository struct { + PullRequests struct { + Nodes []PullRequest + PageInfo PageInfo + } `graphql:"pullRequests(states: $states, headRefName: $headRefName, baseRefName: $baseRefName, first: $first, after: $after)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + if err := c.query(ctx, &query, map[string]interface{}{ + "owner": githubv4.String(input.Owner), + "repo": githubv4.String(input.Repo), + "headRefName": nullable(githubv4.String(input.HeadRefName)), + "baseRefName": nullable(githubv4.String(input.BaseRefName)), + "states": input.States, + "first": githubv4.Int(input.First), + "after": nullable(githubv4.String(input.After)), + }); err != nil { + return nil, errors.Wrap(err, "failed to query pull requests") + } + return &GetPullRequestsPage{ + PageInfo: query.Repository.PullRequests.PageInfo, + PullRequests: query.Repository.PullRequests.Nodes, + }, nil +} + func (c *Client) CreatePullRequest(ctx context.Context, input githubv4.CreatePullRequestInput) (*PullRequest, error) { var mutation struct { CreatePullRequest struct { @@ -110,13 +156,6 @@ type RepoPullRequestOpts struct { States []githubv4.PullRequestState } -type PageInfo struct { - EndCursor string - HasNextPage bool - HasPreviousPage bool - StartCursor string -} - type RepoPullRequestsResponse struct { PageInfo TotalCount int64