1
0
Эх сурвалжийг харах

Convert repo_branch.go and repo_collaboration.go from XORM to GORM

Co-authored-by: unknwon <2946214+unknwon@users.noreply.github.com>
copilot-swe-agent[bot] 2 долоо хоног өмнө
parent
commit
c74b48bd2d

+ 1 - 1
internal/database/issue_mail.go

@@ -176,7 +176,7 @@ func mailIssueCommentToParticipants(issue *Issue, doer *User, mentions []string)
 // and mentioned people.
 func (issue *Issue) MailParticipants() (err error) {
 	mentions := markup.FindAllMentions(issue.Content)
-	if err = updateIssueMentions(x, issue.ID, mentions); err != nil {
+	if err = updateIssueMentions(db, issue.ID, mentions); err != nil {
 		return errors.Newf("UpdateIssueMentions [%d]: %v", issue.ID, err)
 	}
 

+ 1 - 1
internal/database/models.go

@@ -97,7 +97,7 @@ func getGormDB(gormLogger logger.Writer) (*gorm.DB, error) {
 
 func NewTestEngine() error {
 	var err error
-	db, err = getGormDB(&dbutil.Logger{Writer: log.NewConsoleWriter()})
+	db, err = getGormDB(&dbutil.Logger{Writer: os.Stdout})
 	if err != nil {
 		return errors.Newf("connect to database: %v", err)
 	}

+ 17 - 21
internal/database/org.go

@@ -484,7 +484,7 @@ func (org *User) GetUserRepositories(userID int64, page, pageSize int) ([]*Repos
 	}
 
 	var teamRepoIDs []int64
-	if err = x.Table("team_repo").In("team_id", teamIDs).Distinct("repo_id").Find(&teamRepoIDs); err != nil {
+	if err = db.Table("team_repo").Where("team_id IN ?", teamIDs).Distinct("repo_id").Find(&teamRepoIDs).Error; err != nil {
 		return nil, 0, errors.Newf("get team repository IDs: %v", err)
 	}
 	if len(teamRepoIDs) == 0 {
@@ -496,22 +496,18 @@ func (org *User) GetUserRepositories(userID int64, page, pageSize int) ([]*Repos
 		page = 1
 	}
 	repos := make([]*Repository, 0, pageSize)
-	if err = x.Where("owner_id = ?", org.ID).
-		And(builder.Or(
-			builder.And(builder.Expr("is_private = ?", false), builder.Expr("is_unlisted = ?", false)),
-			builder.In("id", teamRepoIDs))).
-		Desc("updated_unix").
-		Limit(pageSize, (page-1)*pageSize).
-		Find(&repos); err != nil {
+	if err = db.Where("owner_id = ?", org.ID).
+		Where(db.Where("is_private = ? AND is_unlisted = ?", false, false).Or("id IN ?", teamRepoIDs)).
+		Order("updated_unix DESC").
+		Limit(pageSize).Offset((page - 1) * pageSize).
+		Find(&repos).Error; err != nil {
 		return nil, 0, errors.Newf("get user repositories: %v", err)
 	}
 
-	repoCount, err := x.Where("owner_id = ?", org.ID).
-		And(builder.Or(
-			builder.Expr("is_private = ?", false),
-			builder.In("id", teamRepoIDs))).
-		Count(new(Repository))
-	if err != nil {
+	var repoCount int64
+	if err = db.Model(&Repository{}).Where("owner_id = ?", org.ID).
+		Where(db.Where("is_private = ?", false).Or("id IN ?", teamRepoIDs)).
+		Count(&repoCount).Error; err != nil {
 		return nil, 0, errors.Newf("count user repositories: %v", err)
 	}
 
@@ -529,7 +525,7 @@ func (org *User) GetUserMirrorRepositories(userID int64) ([]*Repository, error)
 	}
 
 	var teamRepoIDs []int64
-	err = x.Table("team_repo").In("team_id", teamIDs).Distinct("repo_id").Find(&teamRepoIDs)
+	err = db.Table("team_repo").Where("team_id IN ?", teamIDs).Distinct("repo_id").Find(&teamRepoIDs).Error
 	if err != nil {
 		return nil, errors.Newf("get team repository ids: %v", err)
 	}
@@ -539,12 +535,12 @@ func (org *User) GetUserMirrorRepositories(userID int64) ([]*Repository, error)
 	}
 
 	repos := make([]*Repository, 0, 10)
-	if err = x.Where("owner_id = ?", org.ID).
-		And("is_private = ?", false).
-		Or(builder.In("id", teamRepoIDs)).
-		And("is_mirror = ?", true). // Don't move up because it's an independent condition
-		Desc("updated_unix").
-		Find(&repos); err != nil {
+	if err = db.Where("owner_id = ?", org.ID).
+		Where("is_private = ?", false).
+		Or("id IN ?", teamRepoIDs).
+		Where("is_mirror = ?", true). // Don't move up because it's an independent condition
+		Order("updated_unix DESC").
+		Find(&repos).Error; err != nil {
 		return nil, errors.Newf("get user repositories: %v", err)
 	}
 	return repos, nil

+ 2 - 2
internal/database/repo.go

@@ -549,7 +549,7 @@ func (r *Repository) GetAssigneeByID(userID int64) (*User, error) {
 
 // GetWriters returns all users that have write access to the repository.
 func (r *Repository) GetWriters() (_ []*User, err error) {
-	return r.getUsersWithAccesMode(x, AccessModeWrite)
+	return r.getUsersWithAccesMode(db, AccessModeWrite)
 }
 
 // GetMilestoneByID returns the milestone belongs to repository by given ID.
@@ -1230,7 +1230,7 @@ func CreateRepository(doer, owner *User, opts CreateRepoOptionsLegacy) (_ *Repos
 		EnablePulls:  true,
 	}
 
-	err := db.Transaction(func(tx *gorm.DB) error {
+	err = db.Transaction(func(tx *gorm.DB) error {
 		if err := createRepository(tx, doer, owner, repo); err != nil {
 			return err
 		}

+ 36 - 39
internal/database/repo_branch.go

@@ -8,6 +8,7 @@ import (
 	"github.com/cockroachdb/errors"
 	"github.com/gogs/git-module"
 	"github.com/unknwon/com"
+	"gorm.io/gorm"
 
 	"gogs.io/gogs/internal/errutil"
 	"gogs.io/gogs/internal/tool"
@@ -93,8 +94,9 @@ type ProtectBranchWhitelist struct {
 
 // IsUserInProtectBranchWhitelist returns true if given user is in the whitelist of a branch in a repository.
 func IsUserInProtectBranchWhitelist(repoID, userID int64, branch string) bool {
-	has, err := x.Where("repo_id = ?", repoID).And("user_id = ?", userID).And("name = ?", branch).Get(new(ProtectBranchWhitelist))
-	return has && err == nil
+	var whitelist ProtectBranchWhitelist
+	err := db.Where("repo_id = ?", repoID).Where("user_id = ?", userID).Where("name = ?", branch).First(&whitelist).Error
+	return err == nil
 }
 
 // ProtectBranch contains options of a protected branch.
@@ -115,11 +117,11 @@ func GetProtectBranchOfRepoByName(repoID int64, name string) (*ProtectBranch, er
 		RepoID: repoID,
 		Name:   name,
 	}
-	has, err := x.Get(protectBranch)
-	if err != nil {
-		return nil, err
-	} else if !has {
+	err := db.Where("repo_id = ? AND name = ?", repoID, name).First(protectBranch).Error
+	if errors.Is(err, gorm.ErrRecordNotFound) {
 		return nil, ErrBranchNotExist{args: map[string]any{"name": name}}
+	} else if err != nil {
+		return nil, err
 	}
 	return protectBranch, nil
 }
@@ -136,23 +138,19 @@ func IsBranchOfRepoRequirePullRequest(repoID int64, name string) bool {
 // UpdateProtectBranch saves branch protection options.
 // If ID is 0, it creates a new record. Otherwise, updates existing record.
 func UpdateProtectBranch(protectBranch *ProtectBranch) (err error) {
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return err
-	}
-
-	if protectBranch.ID == 0 {
-		if _, err = sess.Insert(protectBranch); err != nil {
-			return errors.Newf("insert: %v", err)
+	return db.Transaction(func(tx *gorm.DB) error {
+		if protectBranch.ID == 0 {
+			if err := tx.Create(protectBranch).Error; err != nil {
+				return errors.Newf("insert: %v", err)
+			}
 		}
-	}
 
-	if _, err = sess.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
-		return errors.Newf("update: %v", err)
-	}
+		if err := tx.Model(&ProtectBranch{}).Where("id = ?", protectBranch.ID).Updates(protectBranch).Error; err != nil {
+			return errors.Newf("update: %v", err)
+		}
 
-	return sess.Commit()
+		return nil
+	})
 }
 
 // UpdateOrgProtectBranch saves branch protection options of organizational repository.
@@ -209,7 +207,7 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit
 
 	// Make sure protectBranch.ID is not 0 for whitelists
 	if protectBranch.ID == 0 {
-		if _, err = x.Insert(protectBranch); err != nil {
+		if err = db.Create(protectBranch).Error; err != nil {
 			return errors.Newf("insert: %v", err)
 		}
 	}
@@ -247,30 +245,29 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit
 		}
 	}
 
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return err
-	}
-
-	if _, err = sess.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
-		return errors.Newf("Update: %v", err)
-	}
+	return db.Transaction(func(tx *gorm.DB) error {
+		if err := tx.Model(&ProtectBranch{}).Where("id = ?", protectBranch.ID).Updates(protectBranch).Error; err != nil {
+			return errors.Newf("Update: %v", err)
+		}
 
-	// Refresh whitelists
-	if hasUsersChanged || hasTeamsChanged {
-		if _, err = sess.Delete(&ProtectBranchWhitelist{ProtectBranchID: protectBranch.ID}); err != nil {
-			return errors.Newf("delete old protect branch whitelists: %v", err)
-		} else if _, err = sess.Insert(whitelists); err != nil {
-			return errors.Newf("insert new protect branch whitelists: %v", err)
+		// Refresh whitelists
+		if hasUsersChanged || hasTeamsChanged {
+			if err := tx.Delete(&ProtectBranchWhitelist{}, "protect_branch_id = ?", protectBranch.ID).Error; err != nil {
+				return errors.Newf("delete old protect branch whitelists: %v", err)
+			}
+			if len(whitelists) > 0 {
+				if err := tx.Create(&whitelists).Error; err != nil {
+					return errors.Newf("insert new protect branch whitelists: %v", err)
+				}
+			}
 		}
-	}
 
-	return sess.Commit()
+		return nil
+	})
 }
 
 // GetProtectBranchesByRepoID returns a list of *ProtectBranch in given repository.
 func GetProtectBranchesByRepoID(repoID int64) ([]*ProtectBranch, error) {
 	protectBranches := make([]*ProtectBranch, 0, 2)
-	return protectBranches, x.Where("repo_id = ? and protected = ?", repoID, true).Asc("name").Find(&protectBranches)
+	return protectBranches, db.Where("repo_id = ? AND protected = ?", repoID, true).Order("name ASC").Find(&protectBranches).Error
 }

+ 56 - 62
internal/database/repo_collaboration.go

@@ -35,12 +35,12 @@ func IsCollaborator(repoID, userID int64) bool {
 		RepoID: repoID,
 		UserID: userID,
 	}
-	has, err := x.Get(collaboration)
+	err := db.Where("repo_id = ? AND user_id = ?", repoID, userID).First(collaboration).Error
 	if err != nil {
 		log.Error("get collaboration [repo_id: %d, user_id: %d]: %v", repoID, userID, err)
 		return false
 	}
-	return has
+	return true
 }
 
 func (r *Repository) IsCollaborator(userID int64) bool {
@@ -54,27 +54,24 @@ func (r *Repository) AddCollaborator(u *User) error {
 		UserID: u.ID,
 	}
 
-	has, err := x.Get(collaboration)
-	if err != nil {
-		return err
-	} else if has {
+	var existing Collaboration
+	err := db.Where("repo_id = ? AND user_id = ?", r.ID, u.ID).First(&existing).Error
+	if err == nil {
 		return nil
-	}
-	collaboration.Mode = AccessModeWrite
-
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
+	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
 		return err
 	}
+	collaboration.Mode = AccessModeWrite
 
-	if _, err = sess.Insert(collaboration); err != nil {
-		return err
-	} else if err = r.recalculateAccesses(sess); err != nil {
-		return errors.Newf("recalculateAccesses [repo_id: %v]: %v", r.ID, err)
-	}
-
-	return sess.Commit()
+	return db.Transaction(func(tx *gorm.DB) error {
+		if err := tx.Create(collaboration).Error; err != nil {
+			return err
+		}
+		if err := r.recalculateAccesses(tx); err != nil {
+			return errors.Newf("recalculateAccesses [repo_id: %v]: %v", r.ID, err)
+		}
+		return nil
+	})
 }
 
 func (r *Repository) getCollaborations(e *gorm.DB) ([]*Collaboration, error) {
@@ -121,7 +118,7 @@ func (r *Repository) getCollaborators(e *gorm.DB) ([]*Collaborator, error) {
 
 // GetCollaborators returns the collaborators for a repository
 func (r *Repository) GetCollaborators() ([]*Collaborator, error) {
-	return r.getCollaborators(x)
+	return r.getCollaborators(db)
 }
 
 // ChangeCollaborationAccessMode sets new access mode for the collaboration.
@@ -135,11 +132,11 @@ func (r *Repository) ChangeCollaborationAccessMode(userID int64, mode AccessMode
 		RepoID: r.ID,
 		UserID: userID,
 	}
-	has, err := x.Get(collaboration)
-	if err != nil {
-		return errors.Newf("get collaboration: %v", err)
-	} else if !has {
+	err := db.Where("repo_id = ? AND user_id = ?", r.ID, userID).First(collaboration).Error
+	if errors.Is(err, gorm.ErrRecordNotFound) {
 		return nil
+	} else if err != nil {
+		return errors.Newf("get collaboration: %v", err)
 	}
 
 	if collaboration.Mode == mode {
@@ -160,35 +157,31 @@ func (r *Repository) ChangeCollaborationAccessMode(userID int64, mode AccessMode
 		}
 	}
 
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return err
-	}
-
-	if _, err = sess.ID(collaboration.ID).AllCols().Update(collaboration); err != nil {
-		return errors.Newf("update collaboration: %v", err)
-	}
+	return db.Transaction(func(tx *gorm.DB) error {
+		if err := tx.Model(&Collaboration{}).Where("id = ?", collaboration.ID).Updates(collaboration).Error; err != nil {
+			return errors.Newf("update collaboration: %v", err)
+		}
 
-	access := &Access{
-		UserID: userID,
-		RepoID: r.ID,
-	}
-	has, err = sess.Get(access)
-	if err != nil {
-		return errors.Newf("get access record: %v", err)
-	}
-	if has {
-		_, err = sess.Exec("UPDATE access SET mode = ? WHERE user_id = ? AND repo_id = ?", mode, userID, r.ID)
-	} else {
-		access.Mode = mode
-		_, err = sess.Insert(access)
-	}
-	if err != nil {
-		return errors.Newf("update/insert access table: %v", err)
-	}
+		access := &Access{
+			UserID: userID,
+			RepoID: r.ID,
+		}
+		err := tx.Where("user_id = ? AND repo_id = ?", userID, r.ID).First(access).Error
+		if err == nil {
+			if err := tx.Exec("UPDATE access SET mode = ? WHERE user_id = ? AND repo_id = ?", mode, userID, r.ID).Error; err != nil {
+				return errors.Newf("update access table: %v", err)
+			}
+		} else if errors.Is(err, gorm.ErrRecordNotFound) {
+			access.Mode = mode
+			if err := tx.Create(access).Error; err != nil {
+				return errors.Newf("insert access table: %v", err)
+			}
+		} else {
+			return errors.Newf("get access record: %v", err)
+		}
 
-	return sess.Commit()
+		return nil
+	})
 }
 
 // DeleteCollaboration removes collaboration relation between the user and repository.
@@ -202,19 +195,20 @@ func DeleteCollaboration(repo *Repository, userID int64) (err error) {
 		UserID: userID,
 	}
 
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return err
-	}
+	return db.Transaction(func(tx *gorm.DB) error {
+		result := tx.Delete(collaboration, "repo_id = ? AND user_id = ?", repo.ID, userID)
+		if result.Error != nil {
+			return result.Error
+		} else if result.RowsAffected == 0 {
+			return nil
+		}
 
-	if has, err := sess.Delete(collaboration); err != nil || has == 0 {
-		return err
-	} else if err = repo.recalculateAccesses(sess); err != nil {
-		return err
-	}
+		if err := repo.recalculateAccesses(tx); err != nil {
+			return err
+		}
 
-	return sess.Commit()
+		return nil
+	})
 }
 
 func (r *Repository) DeleteCollaboration(userID int64) error {