Преглед изворни кода

Migrate internal/database/ssh_key.go from XORM to GORM

Co-authored-by: unknwon <2946214+unknwon@users.noreply.github.com>
copilot-swe-agent[bot] пре 2 недеља
родитељ
комит
8d43893af2
1 измењених фајлова са 131 додато и 137 уклоњено
  1. 131 137
      internal/database/ssh_key.go

+ 131 - 137
internal/database/ssh_key.go

@@ -16,8 +16,8 @@ import (
 	"github.com/cockroachdb/errors"
 	"github.com/unknwon/com"
 	"golang.org/x/crypto/ssh"
+	"gorm.io/gorm"
 	log "unknwon.dev/clog/v2"
-	"xorm.io/xorm"
 
 	"gogs.io/gogs/internal/conf"
 	"gogs.io/gogs/internal/errutil"
@@ -40,38 +40,41 @@ const (
 // PublicKey represents a user or deploy SSH public key.
 type PublicKey struct {
 	ID          int64      `gorm:"primaryKey"`
-	OwnerID     int64      `xorm:"INDEX NOT NULL" gorm:"index;not null"`
-	Name        string     `xorm:"NOT NULL" gorm:"not null"`
-	Fingerprint string     `xorm:"NOT NULL" gorm:"not null"`
-	Content     string     `xorm:"TEXT NOT NULL" gorm:"type:TEXT;not null"`
-	Mode        AccessMode `xorm:"NOT NULL DEFAULT 2" gorm:"not null;default:2"`
-	Type        KeyType    `xorm:"NOT NULL DEFAULT 1" gorm:"not null;default:1"`
-
-	Created           time.Time `xorm:"-" json:"-" gorm:"-"`
+	OwnerID     int64      `gorm:"index;not null"`
+	Name        string     `gorm:"not null"`
+	Fingerprint string     `gorm:"not null"`
+	Content     string     `gorm:"type:text;not null"`
+	Mode        AccessMode `gorm:"not null;default:2"`
+	Type        KeyType    `gorm:"not null;default:1"`
+
+	Created           time.Time `gorm:"-" json:"-"`
 	CreatedUnix       int64
-	Updated           time.Time `xorm:"-" json:"-" gorm:"-"` // Note: Updated must below Created for AfterSet.
+	Updated           time.Time `gorm:"-" json:"-"` // Note: Updated must below Created for AfterFind.
 	UpdatedUnix       int64
-	HasRecentActivity bool `xorm:"-" json:"-" gorm:"-"`
-	HasUsed           bool `xorm:"-" json:"-" gorm:"-"`
+	HasRecentActivity bool `gorm:"-" json:"-"`
+	HasUsed           bool `gorm:"-" json:"-"`
 }
 
-func (k *PublicKey) BeforeInsert() {
-	k.CreatedUnix = time.Now().Unix()
+func (k *PublicKey) BeforeCreate(tx *gorm.DB) error {
+	if k.CreatedUnix == 0 {
+		k.CreatedUnix = tx.NowFunc().Unix()
+	}
+	return nil
 }
 
-func (k *PublicKey) BeforeUpdate() {
-	k.UpdatedUnix = time.Now().Unix()
+func (k *PublicKey) BeforeUpdate(tx *gorm.DB) error {
+	k.UpdatedUnix = tx.NowFunc().Unix()
+	return nil
 }
 
-func (k *PublicKey) AfterSet(colName string, _ xorm.Cell) {
-	switch colName {
-	case "created_unix":
-		k.Created = time.Unix(k.CreatedUnix, 0).Local()
-	case "updated_unix":
+func (k *PublicKey) AfterFind(tx *gorm.DB) error {
+	k.Created = time.Unix(k.CreatedUnix, 0).Local()
+	if k.UpdatedUnix > 0 {
 		k.Updated = time.Unix(k.UpdatedUnix, 0).Local()
 		k.HasUsed = k.Updated.After(k.Created)
-		k.HasRecentActivity = k.Updated.Add(7 * 24 * time.Hour).After(time.Now())
+		k.HasRecentActivity = k.Updated.Add(7 * 24 * time.Hour).After(tx.NowFunc())
 	}
+	return nil
 }
 
 // OmitEmail returns content of public key without email address.
@@ -356,19 +359,16 @@ func appendAuthorizedKeysToFile(keys ...*PublicKey) error {
 // checkKeyContent onlys checks if key content has been used as public key,
 // it is OK to use same key as deploy key for multiple repositories/users.
 func checkKeyContent(content string) error {
-	has, err := x.Get(&PublicKey{
-		Content: content,
-		Type:    KeyTypeUser,
-	})
-	if err != nil {
-		return err
-	} else if has {
+	err := db.Where("content = ? AND type = ?", content, KeyTypeUser).First(&PublicKey{}).Error
+	if err == nil {
 		return ErrKeyAlreadyExist{0, content}
+	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
+		return err
 	}
 	return nil
 }
 
-func addKey(e Engine, key *PublicKey) (err error) {
+func addKey(tx *gorm.DB, key *PublicKey) (err error) {
 	// Calculate fingerprint.
 	tmpPath := strings.ReplaceAll(path.Join(os.TempDir(), fmt.Sprintf("%d", time.Now().Nanosecond()), "id_rsa.pub"), "\\", "/")
 	_ = os.MkdirAll(path.Dir(tmpPath), os.ModePerm)
@@ -385,7 +385,7 @@ func addKey(e Engine, key *PublicKey) (err error) {
 	key.Fingerprint = strings.Split(stdout, " ")[1]
 
 	// Save SSH key.
-	if _, err = e.Insert(key); err != nil {
+	if err = tx.Create(key).Error; err != nil {
 		return err
 	}
 
@@ -404,16 +404,10 @@ func AddPublicKey(ownerID int64, name, content string) (*PublicKey, error) {
 	}
 
 	// Key name of same user cannot be duplicated.
-	has, err := x.Where("owner_id = ? AND name = ?", ownerID, name).Get(new(PublicKey))
-	if err != nil {
-		return nil, err
-	} else if has {
+	err := db.Where("owner_id = ? AND name = ?", ownerID, name).First(new(PublicKey)).Error
+	if err == nil {
 		return nil, ErrKeyNameAlreadyUsed{ownerID, name}
-	}
-
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
+	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
 		return nil, err
 	}
 
@@ -424,21 +418,25 @@ func AddPublicKey(ownerID int64, name, content string) (*PublicKey, error) {
 		Mode:    AccessModeWrite,
 		Type:    KeyTypeUser,
 	}
-	if err = addKey(sess, key); err != nil {
+	err = db.Transaction(func(tx *gorm.DB) error {
+		return addKey(tx, key)
+	})
+	if err != nil {
 		return nil, errors.Newf("addKey: %v", err)
 	}
 
-	return key, sess.Commit()
+	return key, nil
 }
 
 // GetPublicKeyByID returns public key by given ID.
 func GetPublicKeyByID(keyID int64) (*PublicKey, error) {
 	key := new(PublicKey)
-	has, err := x.Id(keyID).Get(key)
+	err := db.Where("id = ?", keyID).First(key).Error
 	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, ErrKeyNotExist{keyID}
+		}
 		return nil, err
-	} else if !has {
-		return nil, ErrKeyNotExist{keyID}
 	}
 	return key, nil
 }
@@ -448,11 +446,12 @@ func GetPublicKeyByID(keyID int64) (*PublicKey, error) {
 // exists.
 func SearchPublicKeyByContent(content string) (*PublicKey, error) {
 	key := new(PublicKey)
-	has, err := x.Where("content like ?", content+"%").Get(key)
+	err := db.Where("content LIKE ?", content+"%").First(key).Error
 	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, ErrKeyNotExist{}
+		}
 		return nil, err
-	} else if !has {
-		return nil, ErrKeyNotExist{}
 	}
 	return key, nil
 }
@@ -460,23 +459,21 @@ func SearchPublicKeyByContent(content string) (*PublicKey, error) {
 // ListPublicKeys returns a list of public keys belongs to given user.
 func ListPublicKeys(uid int64) ([]*PublicKey, error) {
 	keys := make([]*PublicKey, 0, 5)
-	return keys, x.Where("owner_id = ?", uid).Find(&keys)
+	return keys, db.Where("owner_id = ?", uid).Find(&keys).Error
 }
 
 // UpdatePublicKey updates given public key.
 func UpdatePublicKey(key *PublicKey) error {
-	_, err := x.Id(key.ID).AllCols().Update(key)
-	return err
+	return db.Model(key).Where("id = ?", key.ID).Updates(key).Error
 }
 
 // deletePublicKeys does the actual key deletion but does not update authorized_keys file.
-func deletePublicKeys(e *xorm.Session, keyIDs ...int64) error {
+func deletePublicKeys(tx *gorm.DB, keyIDs ...int64) error {
 	if len(keyIDs) == 0 {
 		return nil
 	}
 
-	_, err := e.In("id", keyIDs).Delete(new(PublicKey))
-	return err
+	return tx.Where("id IN ?", keyIDs).Delete(new(PublicKey)).Error
 }
 
 // DeletePublicKey deletes SSH key information both in database and authorized_keys file.
@@ -494,17 +491,10 @@ func DeletePublicKey(doer *User, id int64) (err error) {
 		return ErrKeyAccessDenied{doer.ID, key.ID, "public"}
 	}
 
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return err
-	}
-
-	if err = deletePublicKeys(sess, id); err != nil {
-		return err
-	}
-
-	if err = sess.Commit(); err != nil {
+	err = db.Transaction(func(tx *gorm.DB) error {
+		return deletePublicKeys(tx, id)
+	})
+	if err != nil {
 		return err
 	}
 
@@ -562,37 +552,40 @@ func RewriteAuthorizedKeys() error {
 // DeployKey represents deploy key information and its relation with repository.
 type DeployKey struct {
 	ID          int64
-	KeyID       int64 `xorm:"UNIQUE(s) INDEX"`
-	RepoID      int64 `xorm:"UNIQUE(s) INDEX"`
+	KeyID       int64 `gorm:"uniqueIndex:s;index"`
+	RepoID      int64 `gorm:"uniqueIndex:s;index"`
 	Name        string
 	Fingerprint string
-	Content     string `xorm:"-" json:"-" gorm:"-"`
+	Content     string `gorm:"-" json:"-"`
 
-	Created           time.Time `xorm:"-" json:"-" gorm:"-"`
+	Created           time.Time `gorm:"-" json:"-"`
 	CreatedUnix       int64
-	Updated           time.Time `xorm:"-" json:"-" gorm:"-"` // Note: Updated must below Created for AfterSet.
+	Updated           time.Time `gorm:"-" json:"-"` // Note: Updated must below Created for AfterFind.
 	UpdatedUnix       int64
-	HasRecentActivity bool `xorm:"-" json:"-" gorm:"-"`
-	HasUsed           bool `xorm:"-" json:"-" gorm:"-"`
+	HasRecentActivity bool `gorm:"-" json:"-"`
+	HasUsed           bool `gorm:"-" json:"-"`
 }
 
-func (k *DeployKey) BeforeInsert() {
-	k.CreatedUnix = time.Now().Unix()
+func (k *DeployKey) BeforeCreate(tx *gorm.DB) error {
+	if k.CreatedUnix == 0 {
+		k.CreatedUnix = tx.NowFunc().Unix()
+	}
+	return nil
 }
 
-func (k *DeployKey) BeforeUpdate() {
-	k.UpdatedUnix = time.Now().Unix()
+func (k *DeployKey) BeforeUpdate(tx *gorm.DB) error {
+	k.UpdatedUnix = tx.NowFunc().Unix()
+	return nil
 }
 
-func (k *DeployKey) AfterSet(colName string, _ xorm.Cell) {
-	switch colName {
-	case "created_unix":
-		k.Created = time.Unix(k.CreatedUnix, 0).Local()
-	case "updated_unix":
+func (k *DeployKey) AfterFind(tx *gorm.DB) error {
+	k.Created = time.Unix(k.CreatedUnix, 0).Local()
+	if k.UpdatedUnix > 0 {
 		k.Updated = time.Unix(k.UpdatedUnix, 0).Local()
 		k.HasUsed = k.Updated.After(k.Created)
-		k.HasRecentActivity = k.Updated.Add(7 * 24 * time.Hour).After(time.Now())
+		k.HasRecentActivity = k.Updated.Add(7 * 24 * time.Hour).After(tx.NowFunc())
 	}
+	return nil
 }
 
 // GetContent gets associated public key content.
@@ -605,28 +598,28 @@ func (k *DeployKey) GetContent() error {
 	return nil
 }
 
-func checkDeployKey(e Engine, keyID, repoID int64, name string) error {
+func checkDeployKey(tx *gorm.DB, keyID, repoID int64, name string) error {
 	// Note: We want error detail, not just true or false here.
-	has, err := e.Where("key_id = ? AND repo_id = ?", keyID, repoID).Get(new(DeployKey))
-	if err != nil {
-		return err
-	} else if has {
+	err := tx.Where("key_id = ? AND repo_id = ?", keyID, repoID).First(new(DeployKey)).Error
+	if err == nil {
 		return ErrDeployKeyAlreadyExist{keyID, repoID}
+	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
+		return err
 	}
 
-	has, err = e.Where("repo_id = ? AND name = ?", repoID, name).Get(new(DeployKey))
-	if err != nil {
-		return err
-	} else if has {
+	err = tx.Where("repo_id = ? AND name = ?", repoID, name).First(new(DeployKey)).Error
+	if err == nil {
 		return ErrDeployKeyNameAlreadyUsed{repoID, name}
+	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
+		return err
 	}
 
 	return nil
 }
 
 // addDeployKey adds new key-repo relation.
-func addDeployKey(e *xorm.Session, keyID, repoID int64, name, fingerprint string) (*DeployKey, error) {
-	if err := checkDeployKey(e, keyID, repoID, name); err != nil {
+func addDeployKey(tx *gorm.DB, keyID, repoID int64, name, fingerprint string) (*DeployKey, error) {
+	if err := checkDeployKey(tx, keyID, repoID, name); err != nil {
 		return nil, err
 	}
 
@@ -636,14 +629,14 @@ func addDeployKey(e *xorm.Session, keyID, repoID int64, name, fingerprint string
 		Name:        name,
 		Fingerprint: fingerprint,
 	}
-	_, err := e.Insert(key)
+	err := tx.Create(key).Error
 	return key, err
 }
 
 // HasDeployKey returns true if public key is a deploy key of given repository.
 func HasDeployKey(keyID, repoID int64) bool {
-	has, _ := x.Where("key_id = ? AND repo_id = ?", keyID, repoID).Get(new(DeployKey))
-	return has
+	err := db.Where("key_id = ? AND repo_id = ?", keyID, repoID).First(new(DeployKey)).Error
+	return err == nil
 }
 
 // AddDeployKey add new deploy key to database and authorized_keys file.
@@ -657,30 +650,34 @@ func AddDeployKey(repoID int64, name, content string) (*DeployKey, error) {
 		Mode:    AccessModeRead,
 		Type:    KeyTypeDeploy,
 	}
-	has, err := x.Get(pkey)
-	if err != nil {
+	err := db.Where("content = ? AND mode = ? AND type = ?", content, AccessModeRead, KeyTypeDeploy).First(pkey).Error
+	has := err == nil
+	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 		return nil, err
 	}
 
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return nil, err
-	}
+	var key *DeployKey
+	err = db.Transaction(func(tx *gorm.DB) error {
+		// First time use this deploy key.
+		if !has {
+			if err := addKey(tx, pkey); err != nil {
+				return errors.Newf("addKey: %v", err)
+			}
+		}
 
-	// First time use this deploy key.
-	if !has {
-		if err = addKey(sess, pkey); err != nil {
-			return nil, errors.Newf("addKey: %v", err)
+		var err error
+		key, err = addDeployKey(tx, pkey.ID, repoID, name, pkey.Fingerprint)
+		if err != nil {
+			return errors.Newf("addDeployKey: %v", err)
 		}
-	}
 
-	key, err := addDeployKey(sess, pkey.ID, repoID, name, pkey.Fingerprint)
+		return nil
+	})
 	if err != nil {
-		return nil, errors.Newf("addDeployKey: %v", err)
+		return nil, err
 	}
 
-	return key, sess.Commit()
+	return key, nil
 }
 
 var _ errutil.NotFound = (*ErrDeployKeyNotExist)(nil)
@@ -705,11 +702,12 @@ func (ErrDeployKeyNotExist) NotFound() bool {
 // GetDeployKeyByID returns deploy key by given ID.
 func GetDeployKeyByID(id int64) (*DeployKey, error) {
 	key := new(DeployKey)
-	has, err := x.Id(id).Get(key)
+	err := db.Where("id = ?", id).First(key).Error
 	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, ErrDeployKeyNotExist{args: map[string]any{"deployKeyID": id}}
+		}
 		return nil, err
-	} else if !has {
-		return nil, ErrDeployKeyNotExist{args: map[string]any{"deployKeyID": id}}
 	}
 	return key, nil
 }
@@ -720,19 +718,19 @@ func GetDeployKeyByRepo(keyID, repoID int64) (*DeployKey, error) {
 		KeyID:  keyID,
 		RepoID: repoID,
 	}
-	has, err := x.Get(key)
+	err := db.Where("key_id = ? AND repo_id = ?", keyID, repoID).First(key).Error
 	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, ErrDeployKeyNotExist{args: map[string]any{"keyID": keyID, "repoID": repoID}}
+		}
 		return nil, err
-	} else if !has {
-		return nil, ErrDeployKeyNotExist{args: map[string]any{"keyID": keyID, "repoID": repoID}}
 	}
 	return key, nil
 }
 
 // UpdateDeployKey updates deploy key information.
 func UpdateDeployKey(key *DeployKey) error {
-	_, err := x.Id(key.ID).AllCols().Update(key)
-	return err
+	return db.Model(key).Where("id = ?", key.ID).Updates(key).Error
 }
 
 // DeleteDeployKey deletes deploy key from its repository authorized_keys file if needed.
@@ -761,31 +759,27 @@ func DeleteDeployKey(doer *User, id int64) error {
 		}
 	}
 
-	sess := x.NewSession()
-	defer sess.Close()
-	if err = sess.Begin(); err != nil {
-		return err
-	}
-
-	if _, err = sess.ID(key.ID).Delete(new(DeployKey)); err != nil {
-		return errors.Newf("delete deploy key [%d]: %v", key.ID, err)
-	}
+	return db.Transaction(func(tx *gorm.DB) error {
+		if err := tx.Where("id = ?", key.ID).Delete(new(DeployKey)).Error; err != nil {
+			return errors.Newf("delete deploy key [%d]: %v", key.ID, err)
+		}
 
-	// Check if this is the last reference to same key content.
-	has, err := sess.Where("key_id = ?", key.KeyID).Get(new(DeployKey))
-	if err != nil {
-		return err
-	} else if !has {
-		if err = deletePublicKeys(sess, key.KeyID); err != nil {
+		// Check if this is the last reference to same key content.
+		err := tx.Where("key_id = ?", key.KeyID).First(new(DeployKey)).Error
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			if err = deletePublicKeys(tx, key.KeyID); err != nil {
+				return err
+			}
+		} else if err != nil {
 			return err
 		}
-	}
 
-	return sess.Commit()
+		return nil
+	})
 }
 
 // ListDeployKeys returns all deploy keys by given repository ID.
 func ListDeployKeys(repoID int64) ([]*DeployKey, error) {
 	keys := make([]*DeployKey, 0, 5)
-	return keys, x.Where("repo_id = ?", repoID).Find(&keys)
+	return keys, db.Where("repo_id = ?", repoID).Find(&keys).Error
 }