|
@@ -8,6 +8,7 @@ import (
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/cockroachdb/errors"
|
|
|
"github.com/gogs/git-module"
|
|
"github.com/gogs/git-module"
|
|
|
"github.com/unknwon/com"
|
|
"github.com/unknwon/com"
|
|
|
|
|
+ "gorm.io/gorm"
|
|
|
|
|
|
|
|
"gogs.io/gogs/internal/errutil"
|
|
"gogs.io/gogs/internal/errutil"
|
|
|
"gogs.io/gogs/internal/tool"
|
|
"gogs.io/gogs/internal/tool"
|
|
@@ -93,8 +94,9 @@ type ProtectBranchWhitelist struct {
|
|
|
|
|
|
|
|
// IsUserInProtectBranchWhitelist returns true if given user is in the whitelist of a branch in a repository.
|
|
// IsUserInProtectBranchWhitelist returns true if given user is in the whitelist of a branch in a repository.
|
|
|
func IsUserInProtectBranchWhitelist(repoID, userID int64, branch string) bool {
|
|
func IsUserInProtectBranchWhitelist(repoID, userID int64, branch string) bool {
|
|
|
- has, err := x.Where("repo_id = ?", repoID).And("user_id = ?", userID).And("name = ?", branch).Get(new(ProtectBranchWhitelist))
|
|
|
|
|
- return has && err == nil
|
|
|
|
|
|
|
+ var whitelist ProtectBranchWhitelist
|
|
|
|
|
+ err := db.Where("repo_id = ?", repoID).Where("user_id = ?", userID).Where("name = ?", branch).First(&whitelist).Error
|
|
|
|
|
+ return err == nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// ProtectBranch contains options of a protected branch.
|
|
// ProtectBranch contains options of a protected branch.
|
|
@@ -115,11 +117,11 @@ func GetProtectBranchOfRepoByName(repoID int64, name string) (*ProtectBranch, er
|
|
|
RepoID: repoID,
|
|
RepoID: repoID,
|
|
|
Name: name,
|
|
Name: name,
|
|
|
}
|
|
}
|
|
|
- has, err := x.Get(protectBranch)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- } else if !has {
|
|
|
|
|
|
|
+ err := db.Where("repo_id = ? AND name = ?", repoID, name).First(protectBranch).Error
|
|
|
|
|
+ if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
return nil, ErrBranchNotExist{args: map[string]any{"name": name}}
|
|
return nil, ErrBranchNotExist{args: map[string]any{"name": name}}
|
|
|
|
|
+ } else if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
return protectBranch, nil
|
|
return protectBranch, nil
|
|
|
}
|
|
}
|
|
@@ -136,23 +138,19 @@ func IsBranchOfRepoRequirePullRequest(repoID int64, name string) bool {
|
|
|
// UpdateProtectBranch saves branch protection options.
|
|
// UpdateProtectBranch saves branch protection options.
|
|
|
// If ID is 0, it creates a new record. Otherwise, updates existing record.
|
|
// If ID is 0, it creates a new record. Otherwise, updates existing record.
|
|
|
func UpdateProtectBranch(protectBranch *ProtectBranch) (err error) {
|
|
func UpdateProtectBranch(protectBranch *ProtectBranch) (err error) {
|
|
|
- sess := x.NewSession()
|
|
|
|
|
- defer sess.Close()
|
|
|
|
|
- if err = sess.Begin(); err != nil {
|
|
|
|
|
- return err
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if protectBranch.ID == 0 {
|
|
|
|
|
- if _, err = sess.Insert(protectBranch); err != nil {
|
|
|
|
|
- return errors.Newf("insert: %v", err)
|
|
|
|
|
|
|
+ return db.Transaction(func(tx *gorm.DB) error {
|
|
|
|
|
+ if protectBranch.ID == 0 {
|
|
|
|
|
+ if err := tx.Create(protectBranch).Error; err != nil {
|
|
|
|
|
+ return errors.Newf("insert: %v", err)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- if _, err = sess.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
|
|
|
|
|
- return errors.Newf("update: %v", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if err := tx.Model(&ProtectBranch{}).Where("id = ?", protectBranch.ID).Updates(protectBranch).Error; err != nil {
|
|
|
|
|
+ return errors.Newf("update: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- return sess.Commit()
|
|
|
|
|
|
|
+ return nil
|
|
|
|
|
+ })
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// UpdateOrgProtectBranch saves branch protection options of organizational repository.
|
|
// UpdateOrgProtectBranch saves branch protection options of organizational repository.
|
|
@@ -209,7 +207,7 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit
|
|
|
|
|
|
|
|
// Make sure protectBranch.ID is not 0 for whitelists
|
|
// Make sure protectBranch.ID is not 0 for whitelists
|
|
|
if protectBranch.ID == 0 {
|
|
if protectBranch.ID == 0 {
|
|
|
- if _, err = x.Insert(protectBranch); err != nil {
|
|
|
|
|
|
|
+ if err = db.Create(protectBranch).Error; err != nil {
|
|
|
return errors.Newf("insert: %v", err)
|
|
return errors.Newf("insert: %v", err)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -247,30 +245,29 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- sess := x.NewSession()
|
|
|
|
|
- defer sess.Close()
|
|
|
|
|
- if err = sess.Begin(); err != nil {
|
|
|
|
|
- return err
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if _, err = sess.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
|
|
|
|
|
- return errors.Newf("Update: %v", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ return db.Transaction(func(tx *gorm.DB) error {
|
|
|
|
|
+ if err := tx.Model(&ProtectBranch{}).Where("id = ?", protectBranch.ID).Updates(protectBranch).Error; err != nil {
|
|
|
|
|
+ return errors.Newf("Update: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Refresh whitelists
|
|
|
|
|
- if hasUsersChanged || hasTeamsChanged {
|
|
|
|
|
- if _, err = sess.Delete(&ProtectBranchWhitelist{ProtectBranchID: protectBranch.ID}); err != nil {
|
|
|
|
|
- return errors.Newf("delete old protect branch whitelists: %v", err)
|
|
|
|
|
- } else if _, err = sess.Insert(whitelists); err != nil {
|
|
|
|
|
- return errors.Newf("insert new protect branch whitelists: %v", err)
|
|
|
|
|
|
|
+ // Refresh whitelists
|
|
|
|
|
+ if hasUsersChanged || hasTeamsChanged {
|
|
|
|
|
+ if err := tx.Delete(&ProtectBranchWhitelist{}, "protect_branch_id = ?", protectBranch.ID).Error; err != nil {
|
|
|
|
|
+ return errors.Newf("delete old protect branch whitelists: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ if len(whitelists) > 0 {
|
|
|
|
|
+ if err := tx.Create(&whitelists).Error; err != nil {
|
|
|
|
|
+ return errors.Newf("insert new protect branch whitelists: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- return sess.Commit()
|
|
|
|
|
|
|
+ return nil
|
|
|
|
|
+ })
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// GetProtectBranchesByRepoID returns a list of *ProtectBranch in given repository.
|
|
// GetProtectBranchesByRepoID returns a list of *ProtectBranch in given repository.
|
|
|
func GetProtectBranchesByRepoID(repoID int64) ([]*ProtectBranch, error) {
|
|
func GetProtectBranchesByRepoID(repoID int64) ([]*ProtectBranch, error) {
|
|
|
protectBranches := make([]*ProtectBranch, 0, 2)
|
|
protectBranches := make([]*ProtectBranch, 0, 2)
|
|
|
- return protectBranches, x.Where("repo_id = ? and protected = ?", repoID, true).Asc("name").Find(&protectBranches)
|
|
|
|
|
|
|
+ return protectBranches, db.Where("repo_id = ? AND protected = ?", repoID, true).Order("name ASC").Find(&protectBranches).Error
|
|
|
}
|
|
}
|