1
Fork 0

Use db.Find instead of writing methods for every object (#28084)

For those simple objects, it's unnecessary to write the find and count
methods again and again.
This commit is contained in:
Lunny Xiao 2023-11-24 11:49:41 +08:00 committed by GitHub
parent d24a8223ce
commit df1e7d0067
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
88 changed files with 611 additions and 685 deletions

View file

@ -22,7 +22,6 @@ import (
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
"xorm.io/xorm"
)
type (
@ -93,7 +92,7 @@ type FindNotificationOptions struct {
}
// ToCond will convert each condition into a xorm-Cond
func (opts *FindNotificationOptions) ToCond() builder.Cond {
func (opts FindNotificationOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.UserID != 0 {
cond = cond.And(builder.Eq{"notification.user_id": opts.UserID})
@ -105,7 +104,11 @@ func (opts *FindNotificationOptions) ToCond() builder.Cond {
cond = cond.And(builder.Eq{"notification.issue_id": opts.IssueID})
}
if len(opts.Status) > 0 {
cond = cond.And(builder.In("notification.status", opts.Status))
if len(opts.Status) == 1 {
cond = cond.And(builder.Eq{"notification.status": opts.Status[0]})
} else {
cond = cond.And(builder.In("notification.status", opts.Status))
}
}
if len(opts.Source) > 0 {
cond = cond.And(builder.In("notification.source", opts.Source))
@ -119,24 +122,8 @@ func (opts *FindNotificationOptions) ToCond() builder.Cond {
return cond
}
// ToSession will convert the given options to a xorm Session by using the conditions from ToCond and joining with issue table if required
func (opts *FindNotificationOptions) ToSession(ctx context.Context) *xorm.Session {
sess := db.GetEngine(ctx).Where(opts.ToCond())
if opts.Page != 0 {
sess = db.SetSessionPagination(sess, opts)
}
return sess
}
// GetNotifications returns all notifications that fit to the given options.
func GetNotifications(ctx context.Context, options *FindNotificationOptions) (nl NotificationList, err error) {
err = options.ToSession(ctx).OrderBy("notification.updated_unix DESC").Find(&nl)
return nl, err
}
// CountNotifications count all notifications that fit to the given options and ignore pagination.
func CountNotifications(ctx context.Context, opts *FindNotificationOptions) (int64, error) {
return db.GetEngine(ctx).Where(opts.ToCond()).Count(&Notification{})
func (opts FindNotificationOptions) ToOrders() string {
return "notification.updated_unix DESC"
}
// CreateRepoTransferNotification creates notification for the user a repository was transferred to
@ -192,7 +179,9 @@ func CreateOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n
func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error {
// init
var toNotify container.Set[int64]
notifications, err := getNotificationsByIssueID(ctx, issueID)
notifications, err := db.Find[Notification](ctx, FindNotificationOptions{
IssueID: issueID,
})
if err != nil {
return err
}
@ -273,23 +262,6 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n
return nil
}
func getNotificationsByIssueID(ctx context.Context, issueID int64) (notifications []*Notification, err error) {
err = db.GetEngine(ctx).
Where("issue_id = ?", issueID).
Find(&notifications)
return notifications, err
}
func notificationExists(notifications []*Notification, issueID, userID int64) bool {
for _, notification := range notifications {
if notification.IssueID == issueID && notification.UserID == userID {
return true
}
}
return false
}
func createIssueNotification(ctx context.Context, userID int64, issue *issues_model.Issue, commentID, updatedByID int64) error {
notification := &Notification{
UserID: userID,
@ -341,35 +313,6 @@ func GetIssueNotification(ctx context.Context, userID, issueID int64) (*Notifica
return notification, err
}
// NotificationsForUser returns notifications for a given user and status
func NotificationsForUser(ctx context.Context, user *user_model.User, statuses []NotificationStatus, page, perPage int) (notifications NotificationList, err error) {
if len(statuses) == 0 {
return nil, nil
}
sess := db.GetEngine(ctx).
Where("user_id = ?", user.ID).
In("status", statuses).
OrderBy("updated_unix DESC")
if page > 0 && perPage > 0 {
sess.Limit(perPage, (page-1)*perPage)
}
err = sess.Find(&notifications)
return notifications, err
}
// CountUnread count unread notifications for a user
func CountUnread(ctx context.Context, userID int64) int64 {
exist, err := db.GetEngine(ctx).Where("user_id = ?", userID).And("status = ?", NotificationStatusUnread).Count(new(Notification))
if err != nil {
log.Error("countUnread", err)
return 0
}
return exist
}
// LoadAttributes load Repo Issue User and Comment if not loaded
func (n *Notification) LoadAttributes(ctx context.Context) (err error) {
if err = n.loadRepo(ctx); err != nil {
@ -481,17 +424,47 @@ func (n *Notification) APIURL() string {
return setting.AppURL + "api/v1/notifications/threads/" + strconv.FormatInt(n.ID, 10)
}
func notificationExists(notifications []*Notification, issueID, userID int64) bool {
for _, notification := range notifications {
if notification.IssueID == issueID && notification.UserID == userID {
return true
}
}
return false
}
// UserIDCount is a simple coalition of UserID and Count
type UserIDCount struct {
UserID int64
Count int64
}
// GetUIDsAndNotificationCounts between the two provided times
func GetUIDsAndNotificationCounts(ctx context.Context, since, until timeutil.TimeStamp) ([]UserIDCount, error) {
sql := `SELECT user_id, count(*) AS count FROM notification ` +
`WHERE user_id IN (SELECT user_id FROM notification WHERE updated_unix >= ? AND ` +
`updated_unix < ?) AND status = ? GROUP BY user_id`
var res []UserIDCount
return res, db.GetEngine(ctx).SQL(sql, since, until, NotificationStatusUnread).Find(&res)
}
// NotificationList contains a list of notifications
type NotificationList []*Notification
// LoadAttributes load Repo Issue User and Comment if not loaded
func (nl NotificationList) LoadAttributes(ctx context.Context) error {
var err error
for i := 0; i < len(nl); i++ {
err = nl[i].LoadAttributes(ctx)
if err != nil && !issues_model.IsErrCommentNotExist(err) {
return err
}
if _, _, err := nl.LoadRepos(ctx); err != nil {
return err
}
if _, err := nl.LoadIssues(ctx); err != nil {
return err
}
if _, err := nl.LoadUsers(ctx); err != nil {
return err
}
if _, err := nl.LoadComments(ctx); err != nil {
return err
}
return nil
}
@ -665,6 +638,68 @@ func (nl NotificationList) getPendingCommentIDs() []int64 {
return ids.Values()
}
func (nl NotificationList) getUserIDs() []int64 {
ids := make(container.Set[int64], len(nl))
for _, notification := range nl {
if notification.UserID == 0 || notification.User != nil {
continue
}
ids.Add(notification.UserID)
}
return ids.Values()
}
// LoadUsers loads users from database
func (nl NotificationList) LoadUsers(ctx context.Context) ([]int, error) {
if len(nl) == 0 {
return []int{}, nil
}
userIDs := nl.getUserIDs()
users := make(map[int64]*user_model.User, len(userIDs))
left := len(userIDs)
for left > 0 {
limit := db.DefaultMaxInSize
if left < limit {
limit = left
}
rows, err := db.GetEngine(ctx).
In("id", userIDs[:limit]).
Rows(new(user_model.User))
if err != nil {
return nil, err
}
for rows.Next() {
var user user_model.User
err = rows.Scan(&user)
if err != nil {
rows.Close()
return nil, err
}
users[user.ID] = &user
}
_ = rows.Close()
left -= limit
userIDs = userIDs[limit:]
}
failures := []int{}
for i, notification := range nl {
if notification.UserID > 0 && notification.User == nil && users[notification.UserID] != nil {
notification.User = users[notification.UserID]
if notification.User == nil {
log.Error("Notification[%d]: UserID[%d] failed to load", notification.ID, notification.UserID)
failures = append(failures, i)
continue
}
}
}
return failures, nil
}
// LoadComments loads comments from database
func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) {
if len(nl) == 0 {
@ -717,30 +752,6 @@ func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) {
return failures, nil
}
// GetNotificationCount returns the notification count for user
func GetNotificationCount(ctx context.Context, user *user_model.User, status NotificationStatus) (count int64, err error) {
count, err = db.GetEngine(ctx).
Where("user_id = ?", user.ID).
And("status = ?", status).
Count(&Notification{})
return count, err
}
// UserIDCount is a simple coalition of UserID and Count
type UserIDCount struct {
UserID int64
Count int64
}
// GetUIDsAndNotificationCounts between the two provided times
func GetUIDsAndNotificationCounts(ctx context.Context, since, until timeutil.TimeStamp) ([]UserIDCount, error) {
sql := `SELECT user_id, count(*) AS count FROM notification ` +
`WHERE user_id IN (SELECT user_id FROM notification WHERE updated_unix >= ? AND ` +
`updated_unix < ?) AND status = ? GROUP BY user_id`
var res []UserIDCount
return res, db.GetEngine(ctx).SQL(sql, since, until, NotificationStatusUnread).Find(&res)
}
// SetIssueReadBy sets issue to be read by given user.
func SetIssueReadBy(ctx context.Context, issueID, userID int64) error {
if err := issues_model.UpdateIssueUserByRead(ctx, userID, issueID); err != nil {

View file

@ -34,8 +34,13 @@ func TestCreateOrUpdateIssueNotifications(t *testing.T) {
func TestNotificationsForUser(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
statuses := []activities_model.NotificationStatus{activities_model.NotificationStatusRead, activities_model.NotificationStatusUnread}
notfs, err := activities_model.NotificationsForUser(db.DefaultContext, user, statuses, 1, 10)
notfs, err := db.Find[activities_model.Notification](db.DefaultContext, activities_model.FindNotificationOptions{
UserID: user.ID,
Status: []activities_model.NotificationStatus{
activities_model.NotificationStatusRead,
activities_model.NotificationStatusUnread,
},
})
assert.NoError(t, err)
if assert.Len(t, notfs, 3) {
assert.EqualValues(t, 5, notfs[0].ID)
@ -68,11 +73,21 @@ func TestNotification_GetIssue(t *testing.T) {
func TestGetNotificationCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
cnt, err := activities_model.GetNotificationCount(db.DefaultContext, user, activities_model.NotificationStatusRead)
cnt, err := db.Count[activities_model.Notification](db.DefaultContext, activities_model.FindNotificationOptions{
UserID: user.ID,
Status: []activities_model.NotificationStatus{
activities_model.NotificationStatusRead,
},
})
assert.NoError(t, err)
assert.EqualValues(t, 0, cnt)
cnt, err = activities_model.GetNotificationCount(db.DefaultContext, user, activities_model.NotificationStatusUnread)
cnt, err = db.Count[activities_model.Notification](db.DefaultContext, activities_model.FindNotificationOptions{
UserID: user.ID,
Status: []activities_model.NotificationStatus{
activities_model.NotificationStatusUnread,
},
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}

View file

@ -52,7 +52,7 @@ type IssueByRepositoryCount struct {
func GetStatistic(ctx context.Context) (stats Statistic) {
e := db.GetEngine(ctx)
stats.Counter.User = user_model.CountUsers(ctx, nil)
stats.Counter.Org, _ = organization.CountOrgs(ctx, organization.FindOrgOptions{IncludePrivate: true})
stats.Counter.Org, _ = db.Count[organization.Organization](ctx, organization.FindOrgOptions{IncludePrivate: true})
stats.Counter.PublicKey, _ = e.Count(new(asymkey_model.PublicKey))
stats.Counter.Repo, _ = repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{})
stats.Counter.Watch, _ = e.Count(new(repo_model.Watch))
@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource = auth.CountSources(ctx, auth.FindSourcesOptions{})
stats.Counter.AuthSource, _ = db.Count[auth.Source](ctx, auth.FindSourcesOptions{})
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label))