Skip to content

Commit

Permalink
More refactoring of db.DefaultContext (#27083)
Browse files Browse the repository at this point in the history
Next step of #27065
  • Loading branch information
JakobDev authored Sep 15, 2023
1 parent f8a1094 commit c548dde
Show file tree
Hide file tree
Showing 83 changed files with 336 additions and 320 deletions.
2 changes: 1 addition & 1 deletion cmd/admin_user_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func runCreateUser(c *cli.Context) error {
UID: u.ID,
}

if err := auth_model.NewAccessToken(t); err != nil {
if err := auth_model.NewAccessToken(ctx, t); err != nil {
return err
}

Expand Down
4 changes: 2 additions & 2 deletions cmd/admin_user_generate_access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func runGenerateAccessToken(c *cli.Context) error {
UID: user.ID,
}

exist, err := auth_model.AccessTokenByNameExists(t)
exist, err := auth_model.AccessTokenByNameExists(ctx, t)
if err != nil {
return err
}
Expand All @@ -79,7 +79,7 @@ func runGenerateAccessToken(c *cli.Context) error {
t.Scope = accessTokenScope

// create the token
if err := auth_model.NewAccessToken(t); err != nil {
if err := auth_model.NewAccessToken(ctx, t); err != nil {
return err
}

Expand Down
31 changes: 16 additions & 15 deletions models/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package auth

import (
"context"
"crypto/subtle"
"encoding/hex"
"fmt"
Expand Down Expand Up @@ -95,7 +96,7 @@ func init() {
}

// NewAccessToken creates new access token.
func NewAccessToken(t *AccessToken) error {
func NewAccessToken(ctx context.Context, t *AccessToken) error {
salt, err := util.CryptoRandomString(10)
if err != nil {
return err
Expand All @@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error {
t.Token = hex.EncodeToString(token)
t.TokenHash = HashToken(t.Token, t.TokenSalt)
t.TokenLastEight = t.Token[len(t.Token)-8:]
_, err = db.GetEngine(db.DefaultContext).Insert(t)
_, err = db.GetEngine(ctx).Insert(t)
return err
}

Expand Down Expand Up @@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 {
}

// GetAccessTokenBySHA returns access token by given token value
func GetAccessTokenBySHA(token string) (*AccessToken, error) {
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
if token == "" {
return nil, ErrAccessTokenEmpty{}
}
Expand All @@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
TokenLastEight: lastEight,
}
// Re-get the token from the db in case it has been deleted in the intervening period
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken)
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
if err != nil {
return nil, err
}
Expand All @@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}

var tokens []AccessToken
err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
if err != nil {
return nil, err
} else if len(tokens) == 0 {
Expand All @@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}

// AccessTokenByNameExists checks if a token name has been used already by a user.
func AccessTokenByNameExists(token *AccessToken) (bool, error) {
return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
}

// ListAccessTokensOptions contain filter options
Expand All @@ -201,8 +202,8 @@ type ListAccessTokensOptions struct {
}

// ListAccessTokens returns a list of access tokens belongs to given user.
func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) {
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)

if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
Expand All @@ -222,23 +223,23 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
}

// UpdateAccessToken updates information of access token.
func UpdateAccessToken(t *AccessToken) error {
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}

// CountAccessTokens count access tokens belongs to given user by options
func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) {
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
}
return sess.Count(&AccessToken{})
}

// DeleteAccessTokenByID deletes access token by given ID.
func DeleteAccessTokenByID(id, userID int64) error {
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
UID: userID,
})
if err != nil {
Expand Down
35 changes: 18 additions & 17 deletions models/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
Expand All @@ -18,15 +19,15 @@ func TestNewAccessToken(t *testing.T) {
UID: 3,
Name: "Token C",
}
assert.NoError(t, auth_model.NewAccessToken(token))
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)

invalidToken := &auth_model.AccessToken{
ID: token.ID, // duplicate
UID: 2,
Name: "Token F",
}
assert.Error(t, auth_model.NewAccessToken(invalidToken))
assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
}

func TestAccessTokenByNameExists(t *testing.T) {
Expand All @@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) {
}

// Check to make sure it doesn't exists already
exist, err := auth_model.AccessTokenByNameExists(token)
exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.False(t, exist)

// Save it to the database
assert.NoError(t, auth_model.NewAccessToken(token))
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)

// This token must be found by name in the DB now
exist, err = auth_model.AccessTokenByNameExists(token)
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.True(t, exist)

Expand All @@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) {

// Name matches but different user ID, this shouldn't exists in the
// database
exist, err = auth_model.AccessTokenByNameExists(user4Token)
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
assert.NoError(t, err)
assert.False(t, exist)
}

func TestGetAccessTokenBySHA(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.Equal(t, "Token A", token.Name)
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
assert.Equal(t, "e4efbf36", token.TokenLastEight)

_, err = auth_model.GetAccessTokenBySHA("notahash")
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))

_, err = auth_model.GetAccessTokenBySHA("")
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
}

func TestListAccessTokens(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1})
tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
assert.NoError(t, err)
if assert.Len(t, tokens, 2) {
assert.Equal(t, int64(1), tokens[0].UID)
Expand All @@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) {
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
}

tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2})
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
assert.NoError(t, err)
if assert.Len(t, tokens, 1) {
assert.Equal(t, int64(2), tokens[0].UID)
assert.Equal(t, "Token A", tokens[0].Name)
}

tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100})
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
assert.NoError(t, err)
assert.Empty(t, tokens)
}

func TestUpdateAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
token.Name = "Token Z"

assert.NoError(t, auth_model.UpdateAccessToken(token))
assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
}

func TestDeleteAccessTokenByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())

token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)

assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1))
assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
unittest.AssertNotExistsBean(t, token)

err = auth_model.DeleteAccessTokenByID(100, 100)
err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
}
21 changes: 11 additions & 10 deletions models/auth/twofactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package auth

import (
"context"
"crypto/md5"
"crypto/subtle"
"encoding/base32"
Expand Down Expand Up @@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
}

// NewTwoFactor creates a new two-factor authentication token.
func NewTwoFactor(t *TwoFactor) error {
_, err := db.GetEngine(db.DefaultContext).Insert(t)
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}

// UpdateTwoFactor updates a two-factor authentication token.
func UpdateTwoFactor(t *TwoFactor) error {
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}

// GetTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
twofa := &TwoFactor{}
has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
if err != nil {
return nil, err
} else if !has {
Expand All @@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {

// HasTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func HasTwoFactorByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
}

// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
func DeleteTwoFactorByID(id, userID int64) error {
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
UID: userID,
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions models/issues/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,12 @@ func (c *Comment) LoadPoster(ctx context.Context) (err error) {
}

// AfterDelete is invoked from XORM after the object is deleted.
func (c *Comment) AfterDelete() {
func (c *Comment) AfterDelete(ctx context.Context) {
if c.ID <= 0 {
return
}

_, err := repo_model.DeleteAttachmentsByComment(c.ID, true)
_, err := repo_model.DeleteAttachmentsByComment(ctx, c.ID, true)
if err != nil {
log.Info("Could not delete files for comment %d on issue #%d: %s", c.ID, c.IssueID, err)
}
Expand Down
14 changes: 7 additions & 7 deletions models/issues/pull_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ type PullRequestsOptions struct {
MilestoneID int64
}

func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", baseRepoID)
func listPullRequestStatement(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", baseRepoID)

sess.Join("INNER", "issue", "pull_request.issue_id = issue.id")
switch opts.State {
Expand Down Expand Up @@ -115,21 +115,21 @@ func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch
}

// GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status.
func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) {
func GetPullRequestIDsByCheckStatus(ctx context.Context, status PullRequestStatus) ([]int64, error) {
prs := make([]int64, 0, 10)
return prs, db.GetEngine(db.DefaultContext).Table("pull_request").
return prs, db.GetEngine(ctx).Table("pull_request").
Where("status=?", status).
Cols("pull_request.id").
Find(&prs)
}

// PullRequests returns all pull requests for a base Repo by the given conditions
func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
func PullRequests(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
if opts.Page <= 0 {
opts.Page = 1
}

countSession, err := listPullRequestStatement(baseRepoID, opts)
countSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
if err != nil {
log.Error("listPullRequestStatement: %v", err)
return nil, 0, err
Expand All @@ -140,7 +140,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
return nil, maxResults, err
}

findSession, err := listPullRequestStatement(baseRepoID, opts)
findSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
applySorts(findSession, opts.SortType, 0)
if err != nil {
log.Error("listPullRequestStatement: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions models/issues/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestPullRequest_LoadHeadRepo(t *testing.T) {

func TestPullRequestsNewest(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
ListOptions: db.ListOptions{
Page: 1,
},
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestLoadRequestedReviewers(t *testing.T) {

func TestPullRequestsOldest(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
ListOptions: db.ListOptions{
Page: 1,
},
Expand Down
Loading

0 comments on commit c548dde

Please sign in to comment.