login_sources.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. package database
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strconv"
  7. "time"
  8. "github.com/cockroachdb/errors"
  9. "gorm.io/gorm"
  10. "gogs.io/gogs/internal/auth"
  11. "gogs.io/gogs/internal/auth/github"
  12. "gogs.io/gogs/internal/auth/ldap"
  13. "gogs.io/gogs/internal/auth/pam"
  14. "gogs.io/gogs/internal/auth/smtp"
  15. "gogs.io/gogs/internal/errutil"
  16. )
  17. // LoginSource represents an external way for authorizing users.
  18. type LoginSource struct {
  19. ID int64 `gorm:"primaryKey"`
  20. Type auth.Type
  21. Name string `xorm:"UNIQUE" gorm:"unique"`
  22. IsActived bool `xorm:"NOT NULL DEFAULT false" gorm:"not null"`
  23. IsDefault bool `xorm:"DEFAULT false"`
  24. Provider auth.Provider `xorm:"-" gorm:"-"`
  25. Config string `xorm:"TEXT cfg" gorm:"column:cfg;type:TEXT" json:"RawConfig"`
  26. Created time.Time `xorm:"-" gorm:"-" json:"-"`
  27. CreatedUnix int64
  28. Updated time.Time `xorm:"-" gorm:"-" json:"-"`
  29. UpdatedUnix int64
  30. File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
  31. }
  32. // BeforeSave implements the GORM save hook.
  33. func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
  34. if s.Provider == nil {
  35. return nil
  36. }
  37. data, err := json.Marshal(s.Provider.Config())
  38. s.Config = string(data)
  39. return err
  40. }
  41. // BeforeCreate implements the GORM create hook.
  42. func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
  43. if s.CreatedUnix == 0 {
  44. s.CreatedUnix = tx.NowFunc().Unix()
  45. s.UpdatedUnix = s.CreatedUnix
  46. }
  47. return nil
  48. }
  49. // BeforeUpdate implements the GORM update hook.
  50. func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
  51. s.UpdatedUnix = tx.NowFunc().Unix()
  52. return nil
  53. }
  54. type mockProviderConfig struct {
  55. ExternalAccount *auth.ExternalAccount
  56. }
  57. // AfterFind implements the GORM query hook.
  58. func (s *LoginSource) AfterFind(_ *gorm.DB) error {
  59. s.Created = time.Unix(s.CreatedUnix, 0).Local()
  60. s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
  61. switch s.Type {
  62. case auth.LDAP:
  63. var cfg ldap.Config
  64. err := json.Unmarshal([]byte(s.Config), &cfg)
  65. if err != nil {
  66. return err
  67. }
  68. s.Provider = ldap.NewProvider(false, &cfg)
  69. case auth.DLDAP:
  70. var cfg ldap.Config
  71. err := json.Unmarshal([]byte(s.Config), &cfg)
  72. if err != nil {
  73. return err
  74. }
  75. s.Provider = ldap.NewProvider(true, &cfg)
  76. case auth.SMTP:
  77. var cfg smtp.Config
  78. err := json.Unmarshal([]byte(s.Config), &cfg)
  79. if err != nil {
  80. return err
  81. }
  82. s.Provider = smtp.NewProvider(&cfg)
  83. case auth.PAM:
  84. var cfg pam.Config
  85. err := json.Unmarshal([]byte(s.Config), &cfg)
  86. if err != nil {
  87. return err
  88. }
  89. s.Provider = pam.NewProvider(&cfg)
  90. case auth.GitHub:
  91. var cfg github.Config
  92. err := json.Unmarshal([]byte(s.Config), &cfg)
  93. if err != nil {
  94. return err
  95. }
  96. s.Provider = github.NewProvider(&cfg)
  97. case auth.Mock:
  98. var cfg mockProviderConfig
  99. err := json.Unmarshal([]byte(s.Config), &cfg)
  100. if err != nil {
  101. return err
  102. }
  103. mockProvider := NewMockProvider()
  104. mockProvider.AuthenticateFunc.SetDefaultReturn(cfg.ExternalAccount, nil)
  105. s.Provider = mockProvider
  106. default:
  107. return errors.Newf("unrecognized login source type: %v", s.Type)
  108. }
  109. return nil
  110. }
  111. func (s *LoginSource) TypeName() string {
  112. return auth.Name(s.Type)
  113. }
  114. func (s *LoginSource) IsLDAP() bool {
  115. return s.Type == auth.LDAP
  116. }
  117. func (s *LoginSource) IsDLDAP() bool {
  118. return s.Type == auth.DLDAP
  119. }
  120. func (s *LoginSource) IsSMTP() bool {
  121. return s.Type == auth.SMTP
  122. }
  123. func (s *LoginSource) IsPAM() bool {
  124. return s.Type == auth.PAM
  125. }
  126. func (s *LoginSource) IsGitHub() bool {
  127. return s.Type == auth.GitHub
  128. }
  129. func (s *LoginSource) LDAP() *ldap.Config {
  130. return s.Provider.Config().(*ldap.Config)
  131. }
  132. func (s *LoginSource) SMTP() *smtp.Config {
  133. return s.Provider.Config().(*smtp.Config)
  134. }
  135. func (s *LoginSource) PAM() *pam.Config {
  136. return s.Provider.Config().(*pam.Config)
  137. }
  138. func (s *LoginSource) GitHub() *github.Config {
  139. return s.Provider.Config().(*github.Config)
  140. }
  141. // LoginSourcesStore is the storage layer for login sources.
  142. type LoginSourcesStore struct {
  143. db *gorm.DB
  144. files loginSourceFilesStore
  145. }
  146. func newLoginSourcesStore(db *gorm.DB, files loginSourceFilesStore) *LoginSourcesStore {
  147. return &LoginSourcesStore{
  148. db: db,
  149. files: files,
  150. }
  151. }
  152. type CreateLoginSourceOptions struct {
  153. Type auth.Type
  154. Name string
  155. Activated bool
  156. Default bool
  157. Config any
  158. }
  159. type ErrLoginSourceAlreadyExist struct {
  160. args errutil.Args
  161. }
  162. func IsErrLoginSourceAlreadyExist(err error) bool {
  163. return errors.As(err, &ErrLoginSourceAlreadyExist{})
  164. }
  165. func (err ErrLoginSourceAlreadyExist) Error() string {
  166. return fmt.Sprintf("login source already exists: %v", err.args)
  167. }
  168. // Create creates a new login source and persists it to the database. It returns
  169. // ErrLoginSourceAlreadyExist when a login source with same name already exists.
  170. func (s *LoginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
  171. err := s.db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
  172. if err == nil {
  173. return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
  174. } else if !errors.Is(err, gorm.ErrRecordNotFound) {
  175. return nil, err
  176. }
  177. source := &LoginSource{
  178. Type: opts.Type,
  179. Name: opts.Name,
  180. IsActived: opts.Activated,
  181. IsDefault: opts.Default,
  182. }
  183. data, err := json.Marshal(opts.Config)
  184. source.Config = string(data)
  185. if err != nil {
  186. return nil, err
  187. }
  188. return source, s.db.WithContext(ctx).Create(source).Error
  189. }
  190. // Count returns the total number of login sources.
  191. func (s *LoginSourcesStore) Count(ctx context.Context) int64 {
  192. var count int64
  193. s.db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
  194. return count + int64(s.files.Len())
  195. }
  196. type ErrLoginSourceInUse struct {
  197. args errutil.Args
  198. }
  199. func IsErrLoginSourceInUse(err error) bool {
  200. return errors.As(err, &ErrLoginSourceInUse{})
  201. }
  202. func (err ErrLoginSourceInUse) Error() string {
  203. return fmt.Sprintf("login source is still used by some users: %v", err.args)
  204. }
  205. // DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
  206. // if at least one user is associated with the login source.
  207. func (s *LoginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
  208. var count int64
  209. err := s.db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
  210. if err != nil {
  211. return err
  212. } else if count > 0 {
  213. return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
  214. }
  215. return s.db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
  216. }
  217. // GetByID returns the login source with given ID. It returns
  218. // ErrLoginSourceNotExist when not found.
  219. func (s *LoginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
  220. source := new(LoginSource)
  221. err := s.db.WithContext(ctx).Where("id = ?", id).First(source).Error
  222. if err != nil {
  223. if errors.Is(err, gorm.ErrRecordNotFound) {
  224. return s.files.GetByID(id)
  225. }
  226. return nil, err
  227. }
  228. return source, nil
  229. }
  // ListLoginSourceOptions filters the result of LoginSourcesStore.List. The same
// options are also passed to the file-backed store.
type ListLoginSourceOptions struct {
	// OnlyActivated limits the result to activated login sources.
	OnlyActivated bool
}
  234. // List returns a list of login sources filtered by options.
  235. func (s *LoginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
  236. var sources []*LoginSource
  237. query := s.db.WithContext(ctx).Order("id ASC")
  238. if opts.OnlyActivated {
  239. query = query.Where("is_actived = ?", true)
  240. }
  241. err := query.Find(&sources).Error
  242. if err != nil {
  243. return nil, err
  244. }
  245. return append(sources, s.files.List(opts)...), nil
  246. }
  247. // ResetNonDefault clears default flag for all the other login sources.
  248. func (s *LoginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
  249. err := s.db.WithContext(ctx).
  250. Model(new(LoginSource)).
  251. Where("id != ?", dflt.ID).
  252. Updates(map[string]any{"is_default": false}).
  253. Error
  254. if err != nil {
  255. return err
  256. }
  257. for _, source := range s.files.List(ListLoginSourceOptions{}) {
  258. if source.File != nil && source.ID != dflt.ID {
  259. source.File.SetGeneral("is_default", "false")
  260. if err = source.File.Save(); err != nil {
  261. return errors.Wrap(err, "save file")
  262. }
  263. }
  264. }
  265. s.files.Update(dflt)
  266. return nil
  267. }
  268. // Save persists all values of given login source to database or local file. The
  269. // Updated field is set to current time automatically.
  270. func (s *LoginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
  271. if source.File == nil {
  272. return s.db.WithContext(ctx).Save(source).Error
  273. }
  274. source.File.SetGeneral("name", source.Name)
  275. source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
  276. source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
  277. if err := source.File.SetConfig(source.Provider.Config()); err != nil {
  278. return errors.Wrap(err, "set config")
  279. } else if err = source.File.Save(); err != nil {
  280. return errors.Wrap(err, "save file")
  281. }
  282. return nil
  283. }