Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context parameter to some database functions #26055

Merged
merged 5 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions models/activities/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ func (a *Action) GetIssueInfos() []string {
}

// GetIssueTitle returns the title of first issue associated
// with the action.
// with the action. This function will be invoked in template so keep db.DefaultContext here
func (a *Action) GetIssueTitle() string {
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
issue, err := issues_model.GetIssueByIndex(db.DefaultContext, a.RepoID, index)
lunny marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
log.Error("GetIssueByIndex: %v", err)
return "500 when get issue"
Expand All @@ -404,9 +404,9 @@ func (a *Action) GetIssueTitle() string {

// GetIssueContent returns the content of first issue associated with
// this action.
func (a *Action) GetIssueContent() string {
func (a *Action) GetIssueContent(ctx context.Context) string {
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index)
if err != nil {
log.Error("GetIssueByIndex: %v", err)
return "500 when get issue"
Expand Down
54 changes: 27 additions & 27 deletions models/activities/repo_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ type ActivityStats struct {
func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) {
stats := &ActivityStats{Code: &git.CodeActivityStats{}}
if releases {
if err := stats.FillReleases(repo.ID, timeFrom); err != nil {
if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillReleases: %w", err)
}
}
if prs {
if err := stats.FillPullRequests(repo.ID, timeFrom); err != nil {
if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillPullRequests: %w", err)
}
}
if issues {
if err := stats.FillIssues(repo.ID, timeFrom); err != nil {
if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillIssues: %w", err)
}
}
if err := stats.FillUnresolvedIssues(repo.ID, timeFrom, issues, prs); err != nil {
if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil {
return nil, fmt.Errorf("FillUnresolvedIssues: %w", err)
}
if code {
Expand Down Expand Up @@ -205,41 +205,41 @@ func (stats *ActivityStats) PublishedReleaseCount() int {
}

// FillPullRequests returns pull request information for activity page
func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) error {
func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64

// Merged pull requests
sess := pullRequestsForActivityStatement(repoID, fromTime, true)
sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
sess.OrderBy("pull_request.merged_unix DESC")
stats.MergedPRs = make(issues_model.PullRequestList, 0)
if err = sess.Find(&stats.MergedPRs); err != nil {
return err
}
if err = stats.MergedPRs.LoadAttributes(); err != nil {
if err = stats.MergedPRs.LoadAttributes(ctx); err != nil {
return err
}

// Merged pull request authors
sess = pullRequestsForActivityStatement(repoID, fromTime, true)
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
return err
}
stats.MergedPRAuthorCount = count

// Opened pull requests
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
sess.OrderBy("issue.created_unix ASC")
stats.OpenedPRs = make(issues_model.PullRequestList, 0)
if err = sess.Find(&stats.OpenedPRs); err != nil {
return err
}
if err = stats.OpenedPRs.LoadAttributes(); err != nil {
if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil {
return err
}

// Opened pull request authors
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
return err
}
Expand All @@ -248,8 +248,8 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e
return nil
}

func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session {
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", repoID).
func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session {
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID).
Join("INNER", "issue", "pull_request.issue_id = issue.id")

if merged {
Expand All @@ -264,35 +264,35 @@ func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged b
}

// FillIssues returns issue information for activity page
func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64

// Closed issues
sess := issuesForActivityStatement(repoID, fromTime, true, false)
sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false)
sess.OrderBy("issue.closed_unix DESC")
stats.ClosedIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.ClosedIssues); err != nil {
return err
}

// Closed issue authors
sess = issuesForActivityStatement(repoID, fromTime, true, false)
sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
return err
}
stats.ClosedIssueAuthorCount = count

// New issues
sess = issuesForActivityStatement(repoID, fromTime, false, false)
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
sess.OrderBy("issue.created_unix ASC")
stats.OpenedIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.OpenedIssues); err != nil {
return err
}

// Opened issue authors
sess = issuesForActivityStatement(repoID, fromTime, false, false)
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
return err
}
Expand All @@ -302,12 +302,12 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
}

// FillUnresolvedIssues returns unresolved issue and pull request information for activity page
func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Time, issues, prs bool) error {
func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error {
// Check if we need to select anything
if !issues && !prs {
return nil
}
sess := issuesForActivityStatement(repoID, fromTime, false, true)
sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true)
if !issues || !prs {
sess.And("issue.is_pull = ?", prs)
}
Expand All @@ -316,8 +316,8 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim
return sess.Find(&stats.UnresolvedIssues)
}

func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
sess := db.GetEngine(db.DefaultContext).Where("issue.repo_id = ?", repoID).
func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
And("issue.is_closed = ?", closed)

if !unresolved {
Expand All @@ -336,20 +336,20 @@ func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unreso
}

// FillReleases returns release information for activity page
func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error {
func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64

// Published releases list
sess := releasesForActivityStatement(repoID, fromTime)
sess := releasesForActivityStatement(ctx, repoID, fromTime)
sess.OrderBy("release.created_unix DESC")
stats.PublishedReleases = make([]*repo_model.Release, 0)
if err = sess.Find(&stats.PublishedReleases); err != nil {
return err
}

// Published releases authors
sess = releasesForActivityStatement(repoID, fromTime)
sess = releasesForActivityStatement(ctx, repoID, fromTime)
if _, err = sess.Select("count(distinct release.publisher_id) as `count`").Table("release").Get(&count); err != nil {
return err
}
Expand All @@ -358,8 +358,8 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error
return nil
}

func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session {
return db.GetEngine(db.DefaultContext).Where("release.repo_id = ?", repoID).
func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
return db.GetEngine(ctx).Where("release.repo_id = ?", repoID).
And("release.is_draft = ?", false).
And("release.created_unix >= ?", fromTime.Unix())
}
11 changes: 3 additions & 8 deletions models/issues/comment_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error {
return nil
}

// loadAttributes loads all attributes
func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
// LoadAttributes loads attributes of the comments, except for attachments and
// comments
func (comments CommentList) LoadAttributes(ctx context.Context) (err error) {
if err = comments.LoadPosters(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) {

return comments.loadDependentIssues(ctx)
}

// LoadAttributes loads attributes of the comments, except for attachments and
// comments
func (comments CommentList) LoadAttributes() error {
return comments.loadAttributes(db.DefaultContext)
}
14 changes: 7 additions & 7 deletions models/issues/issue.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) {
return err
}

if err = issue.Comments.loadAttributes(ctx); err != nil {
if err = issue.Comments.LoadAttributes(ctx); err != nil {
return err
}
if issue.IsTimetrackerEnabled(ctx) {
Expand Down Expand Up @@ -502,15 +502,15 @@ func (issue *Issue) GetLastEventLabelFake() string {
}

// GetIssueByIndex returns raw issue without loading attributes by index in a repository.
func GetIssueByIndex(repoID, index int64) (*Issue, error) {
func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
if index < 1 {
return nil, ErrIssueNotExist{}
}
issue := &Issue{
RepoID: repoID,
Index: index,
}
has, err := db.GetEngine(db.DefaultContext).Get(issue)
has, err := db.GetEngine(ctx).Get(issue)
if err != nil {
return nil, err
} else if !has {
Expand All @@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
}

// GetIssueWithAttrsByIndex returns issue by index in a repository.
func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) {
issue, err := GetIssueByIndex(repoID, index)
func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
issue, err := GetIssueByIndex(ctx, repoID, index)
if err != nil {
return nil, err
}
return issue, issue.LoadAttributes(db.DefaultContext)
return issue, issue.LoadAttributes(ctx)
}

// GetIssueByID returns an issue by given ID.
Expand Down Expand Up @@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue,
return nil, err
}

err = IssueList(issues).LoadAttributes()
err = IssueList(issues).LoadAttributes(ctx)
if err != nil {
return nil, err
}
Expand Down
8 changes: 1 addition & 7 deletions models/issues/issue_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) {
}

// loadAttributes loads all attributes, expect for attachments and comments
func (issues IssueList) loadAttributes(ctx context.Context) error {
func (issues IssueList) LoadAttributes(ctx context.Context) error {
if _, err := issues.LoadRepositories(ctx); err != nil {
return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err)
}
Expand Down Expand Up @@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error {
return nil
}

// LoadAttributes loads attributes of the issues, except for attachments and
// comments
func (issues IssueList) LoadAttributes() error {
return issues.loadAttributes(db.DefaultContext)
}

// LoadComments loads comments
func (issues IssueList) LoadComments(ctx context.Context) error {
return issues.loadComments(ctx, builder.NewCond())
Expand Down
2 changes: 1 addition & 1 deletion models/issues/issue_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) {
unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}),
}

assert.NoError(t, issueList.LoadAttributes())
assert.NoError(t, issueList.LoadAttributes(db.DefaultContext))
for _, issue := range issueList {
assert.EqualValues(t, issue.RepoID, issue.Repo.ID)
for _, label := range issue.Labels {
Expand Down
2 changes: 1 addition & 1 deletion models/issues/issue_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) {
return nil, fmt.Errorf("unable to query Issues: %w", err)
}

if err := issues.LoadAttributes(); err != nil {
if err := issues.LoadAttributes(ctx); err != nil {
return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err)
}

Expand Down
23 changes: 9 additions & 14 deletions models/issues/pull_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor
}

// GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged
func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
sess := db.GetEngine(db.DefaultContext).
sess := db.GetEngine(ctx).
Join("INNER", "issue", "issue.id = pull_request.issue_id").
Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub)
return prs, sess.Find(&prs)
}

// CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch
func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool {
func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool {
if p.CanWrite(unit.TypeCode) {
return true
}
Expand All @@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *
return false
}

prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch)
prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch)
if err != nil {
return false
}

for _, pr := range prs {
if pr.AllowMaintainerEdit {
err = pr.LoadBaseRepo(db.DefaultContext)
err = pr.LoadBaseRepo(ctx)
if err != nil {
continue
}
prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user)
prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user)
if err != nil {
continue
}
Expand All @@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch

// GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged
// by given base information (repo and branch).
func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
return prs, db.GetEngine(db.DefaultContext).
return prs, db.GetEngine(ctx).
Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
repoID, branch, false, false).
OrderBy("issue.updated_unix DESC").
Expand Down Expand Up @@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
// PullRequestList defines a list of pull requests
type PullRequestList []*PullRequest

func (prs PullRequestList) loadAttributes(ctx context.Context) error {
func (prs PullRequestList) LoadAttributes(ctx context.Context) error {
if len(prs) == 0 {
return nil
}
Expand Down Expand Up @@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 {
}
return issueIDs
}

// LoadAttributes load all the prs attributes
func (prs PullRequestList) LoadAttributes() error {
return prs.loadAttributes(db.DefaultContext)
}
Loading