parent
f8a1094406
commit
c548dde205
83 changed files with 336 additions and 320 deletions
|
@ -5,6 +5,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
@ -95,7 +96,7 @@ func init() {
|
|||
}
|
||||
|
||||
// NewAccessToken creates new access token.
|
||||
func NewAccessToken(t *AccessToken) error {
|
||||
func NewAccessToken(ctx context.Context, t *AccessToken) error {
|
||||
salt, err := util.CryptoRandomString(10)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error {
|
|||
t.Token = hex.EncodeToString(token)
|
||||
t.TokenHash = HashToken(t.Token, t.TokenSalt)
|
||||
t.TokenLastEight = t.Token[len(t.Token)-8:]
|
||||
_, err = db.GetEngine(db.DefaultContext).Insert(t)
|
||||
_, err = db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 {
|
|||
}
|
||||
|
||||
// GetAccessTokenBySHA returns access token by given token value
|
||||
func GetAccessTokenBySHA(token string) (*AccessToken, error) {
|
||||
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
|
||||
if token == "" {
|
||||
return nil, ErrAccessTokenEmpty{}
|
||||
}
|
||||
|
@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
|
|||
TokenLastEight: lastEight,
|
||||
}
|
||||
// Re-get the token from the db in case it has been deleted in the intervening period
|
||||
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken)
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
|
|||
}
|
||||
|
||||
var tokens []AccessToken
|
||||
err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
|
||||
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(tokens) == 0 {
|
||||
|
@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
|
|||
}
|
||||
|
||||
// AccessTokenByNameExists checks if a token name has been used already by a user.
|
||||
func AccessTokenByNameExists(token *AccessToken) (bool, error) {
|
||||
return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
|
||||
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
|
||||
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
|
||||
}
|
||||
|
||||
// ListAccessTokensOptions contain filter options
|
||||
|
@ -201,8 +202,8 @@ type ListAccessTokensOptions struct {
|
|||
}
|
||||
|
||||
// ListAccessTokens returns a list of access tokens belongs to given user.
|
||||
func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
|
||||
func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) {
|
||||
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
|
||||
|
||||
if len(opts.Name) != 0 {
|
||||
sess = sess.Where("name=?", opts.Name)
|
||||
|
@ -222,14 +223,14 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
|
|||
}
|
||||
|
||||
// UpdateAccessToken updates information of access token.
|
||||
func UpdateAccessToken(t *AccessToken) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
|
||||
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// CountAccessTokens count access tokens belongs to given user by options
|
||||
func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
|
||||
func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) {
|
||||
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
|
||||
if len(opts.Name) != 0 {
|
||||
sess = sess.Where("name=?", opts.Name)
|
||||
}
|
||||
|
@ -237,8 +238,8 @@ func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
|
|||
}
|
||||
|
||||
// DeleteAccessTokenByID deletes access token by given ID.
|
||||
func DeleteAccessTokenByID(id, userID int64) error {
|
||||
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{
|
||||
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
|
||||
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -18,7 +19,7 @@ func TestNewAccessToken(t *testing.T) {
|
|||
UID: 3,
|
||||
Name: "Token C",
|
||||
}
|
||||
assert.NoError(t, auth_model.NewAccessToken(token))
|
||||
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
|
||||
invalidToken := &auth_model.AccessToken{
|
||||
|
@ -26,7 +27,7 @@ func TestNewAccessToken(t *testing.T) {
|
|||
UID: 2,
|
||||
Name: "Token F",
|
||||
}
|
||||
assert.Error(t, auth_model.NewAccessToken(invalidToken))
|
||||
assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
|
||||
}
|
||||
|
||||
func TestAccessTokenByNameExists(t *testing.T) {
|
||||
|
@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check to make sure it doesn't exists already
|
||||
exist, err := auth_model.AccessTokenByNameExists(token)
|
||||
exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exist)
|
||||
|
||||
// Save it to the database
|
||||
assert.NoError(t, auth_model.NewAccessToken(token))
|
||||
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
|
||||
// This token must be found by name in the DB now
|
||||
exist, err = auth_model.AccessTokenByNameExists(token)
|
||||
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exist)
|
||||
|
||||
|
@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) {
|
|||
|
||||
// Name matches but different user ID, this shouldn't exists in the
|
||||
// database
|
||||
exist, err = auth_model.AccessTokenByNameExists(user4Token)
|
||||
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exist)
|
||||
}
|
||||
|
||||
func TestGetAccessTokenBySHA(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36")
|
||||
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), token.UID)
|
||||
assert.Equal(t, "Token A", token.Name)
|
||||
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
|
||||
assert.Equal(t, "e4efbf36", token.TokenLastEight)
|
||||
|
||||
_, err = auth_model.GetAccessTokenBySHA("notahash")
|
||||
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
|
||||
|
||||
_, err = auth_model.GetAccessTokenBySHA("")
|
||||
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
|
||||
}
|
||||
|
||||
func TestListAccessTokens(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1})
|
||||
tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, tokens, 2) {
|
||||
assert.Equal(t, int64(1), tokens[0].UID)
|
||||
|
@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) {
|
|||
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
|
||||
}
|
||||
|
||||
tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2})
|
||||
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, tokens, 1) {
|
||||
assert.Equal(t, int64(2), tokens[0].UID)
|
||||
assert.Equal(t, "Token A", tokens[0].Name)
|
||||
}
|
||||
|
||||
tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100})
|
||||
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, tokens)
|
||||
}
|
||||
|
||||
func TestUpdateAccessToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
assert.NoError(t, err)
|
||||
token.Name = "Token Z"
|
||||
|
||||
assert.NoError(t, auth_model.UpdateAccessToken(token))
|
||||
assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
}
|
||||
|
||||
func TestDeleteAccessTokenByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), token.UID)
|
||||
|
||||
assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1))
|
||||
assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
|
||||
unittest.AssertNotExistsBean(t, token)
|
||||
|
||||
err = auth_model.DeleteAccessTokenByID(100, 100)
|
||||
err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/subtle"
|
||||
"encoding/base32"
|
||||
|
@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
|
|||
}
|
||||
|
||||
// NewTwoFactor creates a new two-factor authentication token.
|
||||
func NewTwoFactor(t *TwoFactor) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Insert(t)
|
||||
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
|
||||
_, err := db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTwoFactor updates a two-factor authentication token.
|
||||
func UpdateTwoFactor(t *TwoFactor) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
|
||||
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
|
||||
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
|
||||
twofa := &TwoFactor{}
|
||||
has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
|
||||
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
|
@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
|
|||
|
||||
// HasTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func HasTwoFactorByUID(uid int64) (bool, error) {
|
||||
return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
|
||||
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
|
||||
}
|
||||
|
||||
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
|
||||
func DeleteTwoFactorByID(id, userID int64) error {
|
||||
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
|
||||
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
|
||||
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -359,12 +359,12 @@ func (c *Comment) LoadPoster(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// AfterDelete is invoked from XORM after the object is deleted.
|
||||
func (c *Comment) AfterDelete() {
|
||||
func (c *Comment) AfterDelete(ctx context.Context) {
|
||||
if c.ID <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := repo_model.DeleteAttachmentsByComment(c.ID, true)
|
||||
_, err := repo_model.DeleteAttachmentsByComment(ctx, c.ID, true)
|
||||
if err != nil {
|
||||
log.Info("Could not delete files for comment %d on issue #%d: %s", c.ID, c.IssueID, err)
|
||||
}
|
||||
|
|
|
@ -27,8 +27,8 @@ type PullRequestsOptions struct {
|
|||
MilestoneID int64
|
||||
}
|
||||
|
||||
func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", baseRepoID)
|
||||
func listPullRequestStatement(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
|
||||
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", baseRepoID)
|
||||
|
||||
sess.Join("INNER", "issue", "pull_request.issue_id = issue.id")
|
||||
switch opts.State {
|
||||
|
@ -115,21 +115,21 @@ func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch
|
|||
}
|
||||
|
||||
// GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status.
|
||||
func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) {
|
||||
func GetPullRequestIDsByCheckStatus(ctx context.Context, status PullRequestStatus) ([]int64, error) {
|
||||
prs := make([]int64, 0, 10)
|
||||
return prs, db.GetEngine(db.DefaultContext).Table("pull_request").
|
||||
return prs, db.GetEngine(ctx).Table("pull_request").
|
||||
Where("status=?", status).
|
||||
Cols("pull_request.id").
|
||||
Find(&prs)
|
||||
}
|
||||
|
||||
// PullRequests returns all pull requests for a base Repo by the given conditions
|
||||
func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
|
||||
func PullRequests(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
|
||||
if opts.Page <= 0 {
|
||||
opts.Page = 1
|
||||
}
|
||||
|
||||
countSession, err := listPullRequestStatement(baseRepoID, opts)
|
||||
countSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
|
||||
if err != nil {
|
||||
log.Error("listPullRequestStatement: %v", err)
|
||||
return nil, 0, err
|
||||
|
@ -140,7 +140,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
|
|||
return nil, maxResults, err
|
||||
}
|
||||
|
||||
findSession, err := listPullRequestStatement(baseRepoID, opts)
|
||||
findSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
|
||||
applySorts(findSession, opts.SortType, 0)
|
||||
if err != nil {
|
||||
log.Error("listPullRequestStatement: %v", err)
|
||||
|
|
|
@ -60,7 +60,7 @@ func TestPullRequest_LoadHeadRepo(t *testing.T) {
|
|||
|
||||
func TestPullRequestsNewest(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
|
||||
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
|
||||
ListOptions: db.ListOptions{
|
||||
Page: 1,
|
||||
},
|
||||
|
@ -107,7 +107,7 @@ func TestLoadRequestedReviewers(t *testing.T) {
|
|||
|
||||
func TestPullRequestsOldest(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
|
||||
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
|
||||
ListOptions: db.ListOptions{
|
||||
Page: 1,
|
||||
},
|
||||
|
|
|
@ -37,9 +37,9 @@ func init() {
|
|||
}
|
||||
|
||||
// IncreaseDownloadCount is update download count + 1
|
||||
func (a *Attachment) IncreaseDownloadCount() error {
|
||||
func (a *Attachment) IncreaseDownloadCount(ctx context.Context) error {
|
||||
// Update download count.
|
||||
if _, err := db.GetEngine(db.DefaultContext).Exec("UPDATE `attachment` SET download_count=download_count+1 WHERE id=?", a.ID); err != nil {
|
||||
if _, err := db.GetEngine(ctx).Exec("UPDATE `attachment` SET download_count=download_count+1 WHERE id=?", a.ID); err != nil {
|
||||
return fmt.Errorf("increase attachment count: %w", err)
|
||||
}
|
||||
|
||||
|
@ -164,8 +164,8 @@ func GetAttachmentByReleaseIDFileName(ctx context.Context, releaseID int64, file
|
|||
}
|
||||
|
||||
// DeleteAttachment deletes the given attachment and optionally the associated file.
|
||||
func DeleteAttachment(a *Attachment, remove bool) error {
|
||||
_, err := DeleteAttachments(db.DefaultContext, []*Attachment{a}, remove)
|
||||
func DeleteAttachment(ctx context.Context, a *Attachment, remove bool) error {
|
||||
_, err := DeleteAttachments(ctx, []*Attachment{a}, remove)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -196,23 +196,23 @@ func DeleteAttachments(ctx context.Context, attachments []*Attachment, remove bo
|
|||
}
|
||||
|
||||
// DeleteAttachmentsByIssue deletes all attachments associated with the given issue.
|
||||
func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) {
|
||||
attachments, err := GetAttachmentsByIssueID(db.DefaultContext, issueID)
|
||||
func DeleteAttachmentsByIssue(ctx context.Context, issueID int64, remove bool) (int, error) {
|
||||
attachments, err := GetAttachmentsByIssueID(ctx, issueID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return DeleteAttachments(db.DefaultContext, attachments, remove)
|
||||
return DeleteAttachments(ctx, attachments, remove)
|
||||
}
|
||||
|
||||
// DeleteAttachmentsByComment deletes all attachments associated with the given comment.
|
||||
func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) {
|
||||
attachments, err := GetAttachmentsByCommentID(db.DefaultContext, commentID)
|
||||
func DeleteAttachmentsByComment(ctx context.Context, commentID int64, remove bool) (int, error) {
|
||||
attachments, err := GetAttachmentsByCommentID(ctx, commentID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return DeleteAttachments(db.DefaultContext, attachments, remove)
|
||||
return DeleteAttachments(ctx, attachments, remove)
|
||||
}
|
||||
|
||||
// UpdateAttachmentByUUID Updates attachment via uuid
|
||||
|
|
|
@ -21,7 +21,7 @@ func TestIncreaseDownloadCount(t *testing.T) {
|
|||
assert.Equal(t, int64(0), attachment.DownloadCount)
|
||||
|
||||
// increase download count
|
||||
err = attachment.IncreaseDownloadCount()
|
||||
err = attachment.IncreaseDownloadCount(db.DefaultContext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
attachment, err = repo_model.GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
|
||||
|
@ -45,15 +45,15 @@ func TestGetByCommentOrIssueID(t *testing.T) {
|
|||
func TestDeleteAttachments(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
count, err := repo_model.DeleteAttachmentsByIssue(4, false)
|
||||
count, err := repo_model.DeleteAttachmentsByIssue(db.DefaultContext, 4, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
count, err = repo_model.DeleteAttachmentsByComment(2, false)
|
||||
count, err = repo_model.DeleteAttachmentsByComment(db.DefaultContext, 2, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
err = repo_model.DeleteAttachment(&repo_model.Attachment{ID: 8}, false)
|
||||
err = repo_model.DeleteAttachment(db.DefaultContext, &repo_model.Attachment{ID: 8}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
attachment, err := repo_model.GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18")
|
||||
|
|
|
@ -24,8 +24,8 @@ func init() {
|
|||
}
|
||||
|
||||
// StarRepo or unstar repository.
|
||||
func StarRepo(userID, repoID int64, star bool) error {
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
func StarRepo(ctx context.Context, userID, repoID int64, star bool) error {
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -72,8 +72,8 @@ func IsStaring(ctx context.Context, userID, repoID int64) bool {
|
|||
}
|
||||
|
||||
// GetStargazers returns the users that starred the repo.
|
||||
func GetStargazers(repo *Repository, opts db.ListOptions) ([]*user_model.User, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("star.repo_id = ?", repo.ID).
|
||||
func GetStargazers(ctx context.Context, repo *Repository, opts db.ListOptions) ([]*user_model.User, error) {
|
||||
sess := db.GetEngine(ctx).Where("star.repo_id = ?", repo.ID).
|
||||
Join("LEFT", "star", "`user`.id = star.uid")
|
||||
if opts.Page > 0 {
|
||||
sess = db.SetSessionPagination(sess, &opts)
|
||||
|
|
|
@ -18,11 +18,11 @@ func TestStarRepo(t *testing.T) {
|
|||
const userID = 2
|
||||
const repoID = 1
|
||||
unittest.AssertNotExistsBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
assert.NoError(t, repo_model.StarRepo(userID, repoID, true))
|
||||
assert.NoError(t, repo_model.StarRepo(db.DefaultContext, userID, repoID, true))
|
||||
unittest.AssertExistsAndLoadBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
assert.NoError(t, repo_model.StarRepo(userID, repoID, true))
|
||||
assert.NoError(t, repo_model.StarRepo(db.DefaultContext, userID, repoID, true))
|
||||
unittest.AssertExistsAndLoadBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
assert.NoError(t, repo_model.StarRepo(userID, repoID, false))
|
||||
assert.NoError(t, repo_model.StarRepo(db.DefaultContext, userID, repoID, false))
|
||||
unittest.AssertNotExistsBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ func TestRepository_GetStargazers(t *testing.T) {
|
|||
// repo with stargazers
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4})
|
||||
gazers, err := repo_model.GetStargazers(repo, db.ListOptions{Page: 0})
|
||||
gazers, err := repo_model.GetStargazers(db.DefaultContext, repo, db.ListOptions{Page: 0})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, gazers, 1) {
|
||||
assert.Equal(t, int64(2), gazers[0].ID)
|
||||
|
@ -47,7 +47,7 @@ func TestRepository_GetStargazers2(t *testing.T) {
|
|||
// repo with stargazers
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 3})
|
||||
gazers, err := repo_model.GetStargazers(repo, db.ListOptions{Page: 0})
|
||||
gazers, err := repo_model.GetStargazers(db.DefaultContext, repo, db.ListOptions{Page: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, gazers, 0)
|
||||
}
|
||||
|
@ -57,15 +57,15 @@ func TestClearRepoStars(t *testing.T) {
|
|||
const userID = 2
|
||||
const repoID = 1
|
||||
unittest.AssertNotExistsBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
assert.NoError(t, repo_model.StarRepo(userID, repoID, true))
|
||||
assert.NoError(t, repo_model.StarRepo(db.DefaultContext, userID, repoID, true))
|
||||
unittest.AssertExistsAndLoadBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
assert.NoError(t, repo_model.StarRepo(userID, repoID, false))
|
||||
assert.NoError(t, repo_model.StarRepo(db.DefaultContext, userID, repoID, false))
|
||||
unittest.AssertNotExistsBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
assert.NoError(t, repo_model.ClearRepoStars(db.DefaultContext, repoID))
|
||||
unittest.AssertNotExistsBean(t, &repo_model.Star{UID: userID, RepoID: repoID})
|
||||
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
||||
gazers, err := repo_model.GetStargazers(repo, db.ListOptions{Page: 0})
|
||||
gazers, err := repo_model.GetStargazers(db.DefaultContext, repo, db.ListOptions{Page: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, gazers, 0)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
|
@ -61,7 +62,7 @@ func (upload *Upload) LocalPath() string {
|
|||
}
|
||||
|
||||
// NewUpload creates a new upload object.
|
||||
func NewUpload(name string, buf []byte, file multipart.File) (_ *Upload, err error) {
|
||||
func NewUpload(ctx context.Context, name string, buf []byte, file multipart.File) (_ *Upload, err error) {
|
||||
upload := &Upload{
|
||||
UUID: gouuid.New().String(),
|
||||
Name: name,
|
||||
|
@ -84,7 +85,7 @@ func NewUpload(name string, buf []byte, file multipart.File) (_ *Upload, err err
|
|||
return nil, fmt.Errorf("Copy: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.GetEngine(db.DefaultContext).Insert(upload); err != nil {
|
||||
if _, err := db.GetEngine(ctx).Insert(upload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -92,9 +93,9 @@ func NewUpload(name string, buf []byte, file multipart.File) (_ *Upload, err err
|
|||
}
|
||||
|
||||
// GetUploadByUUID returns the Upload by UUID
|
||||
func GetUploadByUUID(uuid string) (*Upload, error) {
|
||||
func GetUploadByUUID(ctx context.Context, uuid string) (*Upload, error) {
|
||||
upload := &Upload{}
|
||||
has, err := db.GetEngine(db.DefaultContext).Where("uuid=?", uuid).Get(upload)
|
||||
has, err := db.GetEngine(ctx).Where("uuid=?", uuid).Get(upload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
|
@ -104,23 +105,23 @@ func GetUploadByUUID(uuid string) (*Upload, error) {
|
|||
}
|
||||
|
||||
// GetUploadsByUUIDs returns multiple uploads by UUIDS
|
||||
func GetUploadsByUUIDs(uuids []string) ([]*Upload, error) {
|
||||
func GetUploadsByUUIDs(ctx context.Context, uuids []string) ([]*Upload, error) {
|
||||
if len(uuids) == 0 {
|
||||
return []*Upload{}, nil
|
||||
}
|
||||
|
||||
// Silently drop invalid uuids.
|
||||
uploads := make([]*Upload, 0, len(uuids))
|
||||
return uploads, db.GetEngine(db.DefaultContext).In("uuid", uuids).Find(&uploads)
|
||||
return uploads, db.GetEngine(ctx).In("uuid", uuids).Find(&uploads)
|
||||
}
|
||||
|
||||
// DeleteUploads deletes multiple uploads
|
||||
func DeleteUploads(uploads ...*Upload) (err error) {
|
||||
func DeleteUploads(ctx context.Context, uploads ...*Upload) (err error) {
|
||||
if len(uploads) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -159,8 +160,8 @@ func DeleteUploads(uploads ...*Upload) (err error) {
|
|||
}
|
||||
|
||||
// DeleteUploadByUUID deletes a upload by UUID
|
||||
func DeleteUploadByUUID(uuid string) error {
|
||||
upload, err := GetUploadByUUID(uuid)
|
||||
func DeleteUploadByUUID(ctx context.Context, uuid string) error {
|
||||
upload, err := GetUploadByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
if IsErrUploadNotExist(err) {
|
||||
return nil
|
||||
|
@ -168,7 +169,7 @@ func DeleteUploadByUUID(uuid string) error {
|
|||
return fmt.Errorf("GetUploadByUUID: %w", err)
|
||||
}
|
||||
|
||||
if err := DeleteUploads(upload); err != nil {
|
||||
if err := DeleteUploads(ctx, upload); err != nil {
|
||||
return fmt.Errorf("DeleteUpload: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -59,8 +59,8 @@ func IsWatchMode(mode WatchMode) bool {
|
|||
}
|
||||
|
||||
// IsWatching checks if user has watched given repository.
|
||||
func IsWatching(userID, repoID int64) bool {
|
||||
watch, err := GetWatch(db.DefaultContext, userID, repoID)
|
||||
func IsWatching(ctx context.Context, userID, repoID int64) bool {
|
||||
watch, err := GetWatch(ctx, userID, repoID)
|
||||
return err == nil && IsWatchMode(watch.Mode)
|
||||
}
|
||||
|
||||
|
@ -155,8 +155,8 @@ func GetRepoWatchersIDs(ctx context.Context, repoID int64) ([]int64, error) {
|
|||
}
|
||||
|
||||
// GetRepoWatchers returns range of users watching given repository.
|
||||
func GetRepoWatchers(repoID int64, opts db.ListOptions) ([]*user_model.User, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("watch.repo_id=?", repoID).
|
||||
func GetRepoWatchers(ctx context.Context, repoID int64, opts db.ListOptions) ([]*user_model.User, error) {
|
||||
sess := db.GetEngine(ctx).Where("watch.repo_id=?", repoID).
|
||||
Join("LEFT", "watch", "`user`.id=`watch`.user_id").
|
||||
And("`watch`.mode<>?", WatchModeDont)
|
||||
if opts.Page > 0 {
|
||||
|
|
|
@ -17,13 +17,13 @@ import (
|
|||
func TestIsWatching(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
assert.True(t, repo_model.IsWatching(1, 1))
|
||||
assert.True(t, repo_model.IsWatching(4, 1))
|
||||
assert.True(t, repo_model.IsWatching(11, 1))
|
||||
assert.True(t, repo_model.IsWatching(db.DefaultContext, 1, 1))
|
||||
assert.True(t, repo_model.IsWatching(db.DefaultContext, 4, 1))
|
||||
assert.True(t, repo_model.IsWatching(db.DefaultContext, 11, 1))
|
||||
|
||||
assert.False(t, repo_model.IsWatching(1, 5))
|
||||
assert.False(t, repo_model.IsWatching(8, 1))
|
||||
assert.False(t, repo_model.IsWatching(unittest.NonexistentID, unittest.NonexistentID))
|
||||
assert.False(t, repo_model.IsWatching(db.DefaultContext, 1, 5))
|
||||
assert.False(t, repo_model.IsWatching(db.DefaultContext, 8, 1))
|
||||
assert.False(t, repo_model.IsWatching(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID))
|
||||
}
|
||||
|
||||
func TestGetWatchers(t *testing.T) {
|
||||
|
@ -47,7 +47,7 @@ func TestRepository_GetWatchers(t *testing.T) {
|
|||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
||||
watchers, err := repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err := repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, repo.NumWatches)
|
||||
for _, watcher := range watchers {
|
||||
|
@ -55,7 +55,7 @@ func TestRepository_GetWatchers(t *testing.T) {
|
|||
}
|
||||
|
||||
repo = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 9})
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, 0)
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ func TestWatchIfAuto(t *testing.T) {
|
|||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
||||
watchers, err := repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err := repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, repo.NumWatches)
|
||||
|
||||
|
@ -74,13 +74,13 @@ func TestWatchIfAuto(t *testing.T) {
|
|||
|
||||
// Must not add watch
|
||||
assert.NoError(t, repo_model.WatchIfAuto(db.DefaultContext, 8, 1, true))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount)
|
||||
|
||||
// Should not add watch
|
||||
assert.NoError(t, repo_model.WatchIfAuto(db.DefaultContext, 10, 1, true))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount)
|
||||
|
||||
|
@ -88,31 +88,31 @@ func TestWatchIfAuto(t *testing.T) {
|
|||
|
||||
// Must not add watch
|
||||
assert.NoError(t, repo_model.WatchIfAuto(db.DefaultContext, 8, 1, true))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount)
|
||||
|
||||
// Should not add watch
|
||||
assert.NoError(t, repo_model.WatchIfAuto(db.DefaultContext, 12, 1, false))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount)
|
||||
|
||||
// Should add watch
|
||||
assert.NoError(t, repo_model.WatchIfAuto(db.DefaultContext, 12, 1, true))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount+1)
|
||||
|
||||
// Should remove watch, inhibit from adding auto
|
||||
assert.NoError(t, repo_model.WatchRepo(db.DefaultContext, 12, 1, false))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount)
|
||||
|
||||
// Must not add watch
|
||||
assert.NoError(t, repo_model.WatchIfAuto(db.DefaultContext, 12, 1, true))
|
||||
watchers, err = repo_model.GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
|
||||
watchers, err = repo_model.GetRepoWatchers(db.DefaultContext, repo.ID, db.ListOptions{Page: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, watchers, prevCount)
|
||||
}
|
||||
|
|
|
@ -28,9 +28,9 @@ func init() {
|
|||
}
|
||||
|
||||
// GetUserOpenIDs returns all openid addresses that belongs to given user.
|
||||
func GetUserOpenIDs(uid int64) ([]*UserOpenID, error) {
|
||||
func GetUserOpenIDs(ctx context.Context, uid int64) ([]*UserOpenID, error) {
|
||||
openids := make([]*UserOpenID, 0, 5)
|
||||
if err := db.GetEngine(db.DefaultContext).
|
||||
if err := db.GetEngine(ctx).
|
||||
Where("uid=?", uid).
|
||||
Asc("id").
|
||||
Find(&openids); err != nil {
|
||||
|
@ -82,16 +82,16 @@ func AddUserOpenID(ctx context.Context, openid *UserOpenID) error {
|
|||
}
|
||||
|
||||
// DeleteUserOpenID deletes an openid address of given user.
|
||||
func DeleteUserOpenID(openid *UserOpenID) (err error) {
|
||||
func DeleteUserOpenID(ctx context.Context, openid *UserOpenID) (err error) {
|
||||
var deleted int64
|
||||
// ask to check UID
|
||||
address := UserOpenID{
|
||||
UID: openid.UID,
|
||||
}
|
||||
if openid.ID > 0 {
|
||||
deleted, err = db.GetEngine(db.DefaultContext).ID(openid.ID).Delete(&address)
|
||||
deleted, err = db.GetEngine(ctx).ID(openid.ID).Delete(&address)
|
||||
} else {
|
||||
deleted, err = db.GetEngine(db.DefaultContext).
|
||||
deleted, err = db.GetEngine(ctx).
|
||||
Where("openid=?", openid.URI).
|
||||
Delete(&address)
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ func DeleteUserOpenID(openid *UserOpenID) (err error) {
|
|||
}
|
||||
|
||||
// ToggleUserOpenIDVisibility toggles visibility of an openid address of given user.
|
||||
func ToggleUserOpenIDVisibility(id int64) (err error) {
|
||||
_, err = db.GetEngine(db.DefaultContext).Exec("update `user_open_id` set `show` = not `show` where `id` = ?", id)
|
||||
func ToggleUserOpenIDVisibility(ctx context.Context, id int64) (err error) {
|
||||
_, err = db.GetEngine(ctx).Exec("update `user_open_id` set `show` = not `show` where `id` = ?", id)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package user_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
|
||||
|
@ -15,7 +16,7 @@ import (
|
|||
func TestGetUserOpenIDs(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
oids, err := user_model.GetUserOpenIDs(int64(1))
|
||||
oids, err := user_model.GetUserOpenIDs(db.DefaultContext, int64(1))
|
||||
if assert.NoError(t, err) && assert.Len(t, oids, 2) {
|
||||
assert.Equal(t, "https://user1.domain1.tld/", oids[0].URI)
|
||||
assert.False(t, oids[0].Show)
|
||||
|
@ -23,7 +24,7 @@ func TestGetUserOpenIDs(t *testing.T) {
|
|||
assert.True(t, oids[1].Show)
|
||||
}
|
||||
|
||||
oids, err = user_model.GetUserOpenIDs(int64(2))
|
||||
oids, err = user_model.GetUserOpenIDs(db.DefaultContext, int64(2))
|
||||
if assert.NoError(t, err) && assert.Len(t, oids, 1) {
|
||||
assert.Equal(t, "https://domain1.tld/user2/", oids[0].URI)
|
||||
assert.True(t, oids[0].Show)
|
||||
|
@ -32,28 +33,28 @@ func TestGetUserOpenIDs(t *testing.T) {
|
|||
|
||||
func TestToggleUserOpenIDVisibility(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
oids, err := user_model.GetUserOpenIDs(int64(2))
|
||||
oids, err := user_model.GetUserOpenIDs(db.DefaultContext, int64(2))
|
||||
if !assert.NoError(t, err) || !assert.Len(t, oids, 1) {
|
||||
return
|
||||
}
|
||||
assert.True(t, oids[0].Show)
|
||||
|
||||
err = user_model.ToggleUserOpenIDVisibility(oids[0].ID)
|
||||
err = user_model.ToggleUserOpenIDVisibility(db.DefaultContext, oids[0].ID)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
oids, err = user_model.GetUserOpenIDs(int64(2))
|
||||
oids, err = user_model.GetUserOpenIDs(db.DefaultContext, int64(2))
|
||||
if !assert.NoError(t, err) || !assert.Len(t, oids, 1) {
|
||||
return
|
||||
}
|
||||
assert.False(t, oids[0].Show)
|
||||
err = user_model.ToggleUserOpenIDVisibility(oids[0].ID)
|
||||
err = user_model.ToggleUserOpenIDVisibility(db.DefaultContext, oids[0].ID)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
oids, err = user_model.GetUserOpenIDs(int64(2))
|
||||
oids, err = user_model.GetUserOpenIDs(db.DefaultContext, int64(2))
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -59,9 +59,9 @@ func genSettingCacheKey(userID int64, key string) string {
|
|||
}
|
||||
|
||||
// GetSetting returns the setting value via the key
|
||||
func GetSetting(uid int64, key string) (string, error) {
|
||||
func GetSetting(ctx context.Context, uid int64, key string) (string, error) {
|
||||
return cache.GetString(genSettingCacheKey(uid, key), func() (string, error) {
|
||||
res, err := GetSettingNoCache(uid, key)
|
||||
res, err := GetSettingNoCache(ctx, uid, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -70,8 +70,8 @@ func GetSetting(uid int64, key string) (string, error) {
|
|||
}
|
||||
|
||||
// GetSettingNoCache returns specific setting without using the cache
|
||||
func GetSettingNoCache(uid int64, key string) (*Setting, error) {
|
||||
v, err := GetSettings(uid, []string{key})
|
||||
func GetSettingNoCache(ctx context.Context, uid int64, key string) (*Setting, error) {
|
||||
v, err := GetSettings(ctx, uid, []string{key})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -82,9 +82,9 @@ func GetSettingNoCache(uid int64, key string) (*Setting, error) {
|
|||
}
|
||||
|
||||
// GetSettings returns specific settings from user
|
||||
func GetSettings(uid int64, keys []string) (map[string]*Setting, error) {
|
||||
func GetSettings(ctx context.Context, uid int64, keys []string) (map[string]*Setting, error) {
|
||||
settings := make([]*Setting, 0, len(keys))
|
||||
if err := db.GetEngine(db.DefaultContext).
|
||||
if err := db.GetEngine(ctx).
|
||||
Where("user_id=?", uid).
|
||||
And(builder.In("setting_key", keys)).
|
||||
Find(&settings); err != nil {
|
||||
|
@ -98,9 +98,9 @@ func GetSettings(uid int64, keys []string) (map[string]*Setting, error) {
|
|||
}
|
||||
|
||||
// GetUserAllSettings returns all settings from user
|
||||
func GetUserAllSettings(uid int64) (map[string]*Setting, error) {
|
||||
func GetUserAllSettings(ctx context.Context, uid int64) (map[string]*Setting, error) {
|
||||
settings := make([]*Setting, 0, 5)
|
||||
if err := db.GetEngine(db.DefaultContext).
|
||||
if err := db.GetEngine(ctx).
|
||||
Where("user_id=?", uid).
|
||||
Find(&settings); err != nil {
|
||||
return nil, err
|
||||
|
@ -123,13 +123,13 @@ func validateUserSettingKey(key string) error {
|
|||
}
|
||||
|
||||
// GetUserSetting gets a specific setting for a user
|
||||
func GetUserSetting(userID int64, key string, def ...string) (string, error) {
|
||||
func GetUserSetting(ctx context.Context, userID int64, key string, def ...string) (string, error) {
|
||||
if err := validateUserSettingKey(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
setting := &Setting{UserID: userID, SettingKey: key}
|
||||
has, err := db.GetEngine(db.DefaultContext).Get(setting)
|
||||
has, err := db.GetEngine(ctx).Get(setting)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -143,24 +143,24 @@ func GetUserSetting(userID int64, key string, def ...string) (string, error) {
|
|||
}
|
||||
|
||||
// DeleteUserSetting deletes a specific setting for a user
|
||||
func DeleteUserSetting(userID int64, key string) error {
|
||||
func DeleteUserSetting(ctx context.Context, userID int64, key string) error {
|
||||
if err := validateUserSettingKey(key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cache.Remove(genSettingCacheKey(userID, key))
|
||||
_, err := db.GetEngine(db.DefaultContext).Delete(&Setting{UserID: userID, SettingKey: key})
|
||||
_, err := db.GetEngine(ctx).Delete(&Setting{UserID: userID, SettingKey: key})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SetUserSetting updates a users' setting for a specific key
|
||||
func SetUserSetting(userID int64, key, value string) error {
|
||||
func SetUserSetting(ctx context.Context, userID int64, key, value string) error {
|
||||
if err := validateUserSettingKey(key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := upsertUserSettingValue(userID, key, value); err != nil {
|
||||
if err := upsertUserSettingValue(ctx, userID, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -172,8 +172,8 @@ func SetUserSetting(userID int64, key, value string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func upsertUserSettingValue(userID int64, key, value string) error {
|
||||
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
|
||||
func upsertUserSettingValue(ctx context.Context, userID int64, key, value string) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
// here we use a general method to do a safe upsert for different databases (and most transaction levels)
|
||||
|
|
|
@ -6,6 +6,7 @@ package user_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
|
||||
|
@ -19,41 +20,41 @@ func TestSettings(t *testing.T) {
|
|||
newSetting := &user_model.Setting{UserID: 99, SettingKey: keyName, SettingValue: "Gitea User Setting Test"}
|
||||
|
||||
// create setting
|
||||
err := user_model.SetUserSetting(newSetting.UserID, newSetting.SettingKey, newSetting.SettingValue)
|
||||
err := user_model.SetUserSetting(db.DefaultContext, newSetting.UserID, newSetting.SettingKey, newSetting.SettingValue)
|
||||
assert.NoError(t, err)
|
||||
// test about saving unchanged values
|
||||
err = user_model.SetUserSetting(newSetting.UserID, newSetting.SettingKey, newSetting.SettingValue)
|
||||
err = user_model.SetUserSetting(db.DefaultContext, newSetting.UserID, newSetting.SettingKey, newSetting.SettingValue)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// get specific setting
|
||||
settings, err := user_model.GetSettings(99, []string{keyName})
|
||||
settings, err := user_model.GetSettings(db.DefaultContext, 99, []string{keyName})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, settings, 1)
|
||||
assert.EqualValues(t, newSetting.SettingValue, settings[keyName].SettingValue)
|
||||
|
||||
settingValue, err := user_model.GetUserSetting(99, keyName)
|
||||
settingValue, err := user_model.GetUserSetting(db.DefaultContext, 99, keyName)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, newSetting.SettingValue, settingValue)
|
||||
|
||||
settingValue, err = user_model.GetUserSetting(99, "no_such")
|
||||
settingValue, err = user_model.GetUserSetting(db.DefaultContext, 99, "no_such")
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "", settingValue)
|
||||
|
||||
// updated setting
|
||||
updatedSetting := &user_model.Setting{UserID: 99, SettingKey: keyName, SettingValue: "Updated"}
|
||||
err = user_model.SetUserSetting(updatedSetting.UserID, updatedSetting.SettingKey, updatedSetting.SettingValue)
|
||||
err = user_model.SetUserSetting(db.DefaultContext, updatedSetting.UserID, updatedSetting.SettingKey, updatedSetting.SettingValue)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// get all settings
|
||||
settings, err = user_model.GetUserAllSettings(99)
|
||||
settings, err = user_model.GetUserAllSettings(db.DefaultContext, 99)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, settings, 1)
|
||||
assert.EqualValues(t, updatedSetting.SettingValue, settings[updatedSetting.SettingKey].SettingValue)
|
||||
|
||||
// delete setting
|
||||
err = user_model.DeleteUserSetting(99, keyName)
|
||||
err = user_model.DeleteUserSetting(db.DefaultContext, 99, keyName)
|
||||
assert.NoError(t, err)
|
||||
settings, err = user_model.GetUserAllSettings(99)
|
||||
settings, err = user_model.GetUserAllSettings(db.DefaultContext, 99)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, settings, 0)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue