public_keys.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package database
import (
	"bufio"
	"os"
	"path/filepath"

	"github.com/pkg/errors"
	"gorm.io/gorm"

	"gogs.io/gogs/internal/conf"
	"gogs.io/gogs/internal/osutil"
)
  10. // PublicKeysStore is the storage layer for public keys.
  11. type PublicKeysStore struct {
  12. db *gorm.DB
  13. }
  14. func newPublicKeysStore(db *gorm.DB) *PublicKeysStore {
  15. return &PublicKeysStore{db: db}
  16. }
  17. func authorizedKeysPath() string {
  18. return filepath.Join(conf.SSH.RootPath, "authorized_keys")
  19. }
  20. // RewriteAuthorizedKeys rewrites the "authorized_keys" file under the SSH root
  21. // path with all public keys stored in the database.
  22. func (s *PublicKeysStore) RewriteAuthorizedKeys() error {
  23. sshOpLocker.Lock()
  24. defer sshOpLocker.Unlock()
  25. err := os.MkdirAll(conf.SSH.RootPath, os.ModePerm)
  26. if err != nil {
  27. return errors.Wrap(err, "create SSH root path")
  28. }
  29. fpath := authorizedKeysPath()
  30. tempPath := fpath + ".tmp"
  31. f, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600)
  32. if err != nil {
  33. return errors.Wrap(err, "create temporary file")
  34. }
  35. defer func() {
  36. _ = f.Close()
  37. _ = os.Remove(tempPath)
  38. }()
  39. // NOTE: More recently updated keys are more likely to be used more frequently,
  40. // putting them in the earlier lines could speed up the key lookup by SSHD.
  41. rows, err := s.db.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
  42. if err != nil {
  43. return errors.Wrap(err, "iterate public keys")
  44. }
  45. defer func() { _ = rows.Close() }()
  46. for rows.Next() {
  47. var key PublicKey
  48. err = s.db.ScanRows(rows, &key)
  49. if err != nil {
  50. return errors.Wrap(err, "scan rows")
  51. }
  52. _, err = f.WriteString(key.AuthorizedString())
  53. if err != nil {
  54. return errors.Wrapf(err, "write key %d", key.ID)
  55. }
  56. }
  57. if err = rows.Err(); err != nil {
  58. return errors.Wrap(err, "check rows.Err")
  59. }
  60. err = f.Close()
  61. if err != nil {
  62. return errors.Wrap(err, "close temporary file")
  63. }
  64. if osutil.Exist(fpath) {
  65. err = os.Remove(fpath)
  66. if err != nil {
  67. return errors.Wrap(err, "remove")
  68. }
  69. }
  70. err = os.Rename(tempPath, fpath)
  71. if err != nil {
  72. return errors.Wrap(err, "rename")
  73. }
  74. return nil
  75. }