two_factors_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package database
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/stretchr/testify/require"
  8. "gorm.io/gorm"
  9. "gogs.io/gogs/internal/errutil"
  10. )
  11. func TestTwoFactor_BeforeCreate(t *testing.T) {
  12. now := time.Now()
  13. db := &gorm.DB{
  14. Config: &gorm.Config{
  15. SkipDefaultTransaction: true,
  16. NowFunc: func() time.Time {
  17. return now
  18. },
  19. },
  20. }
  21. t.Run("CreatedUnix has been set", func(t *testing.T) {
  22. tf := &TwoFactor{
  23. CreatedUnix: 1,
  24. }
  25. _ = tf.BeforeCreate(db)
  26. assert.Equal(t, int64(1), tf.CreatedUnix)
  27. })
  28. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  29. tf := &TwoFactor{}
  30. _ = tf.BeforeCreate(db)
  31. assert.Equal(t, db.NowFunc().Unix(), tf.CreatedUnix)
  32. })
  33. }
  34. func TestTwoFactor_AfterFind(t *testing.T) {
  35. now := time.Now()
  36. db := &gorm.DB{
  37. Config: &gorm.Config{
  38. SkipDefaultTransaction: true,
  39. NowFunc: func() time.Time {
  40. return now
  41. },
  42. },
  43. }
  44. tf := &TwoFactor{
  45. CreatedUnix: now.Unix(),
  46. }
  47. _ = tf.AfterFind(db)
  48. assert.Equal(t, tf.CreatedUnix, tf.Created.Unix())
  49. }
  50. func TestTwoFactors(t *testing.T) {
  51. if testing.Short() {
  52. t.Skip()
  53. }
  54. t.Parallel()
  55. ctx := context.Background()
  56. s := &TwoFactorsStore{
  57. db: newTestDB(t, "TwoFactorsStore"),
  58. }
  59. for _, tc := range []struct {
  60. name string
  61. test func(t *testing.T, ctx context.Context, s *TwoFactorsStore)
  62. }{
  63. {"Create", twoFactorsCreate},
  64. {"GetByUserID", twoFactorsGetByUserID},
  65. {"IsEnabled", twoFactorsIsEnabled},
  66. } {
  67. t.Run(tc.name, func(t *testing.T) {
  68. t.Cleanup(func() {
  69. err := clearTables(t, s.db)
  70. require.NoError(t, err)
  71. })
  72. tc.test(t, ctx, s)
  73. })
  74. if t.Failed() {
  75. break
  76. }
  77. }
  78. }
  79. func twoFactorsCreate(t *testing.T, ctx context.Context, s *TwoFactorsStore) {
  80. // Create a 2FA token
  81. err := s.Create(ctx, 1, "secure-key", "secure-secret")
  82. require.NoError(t, err)
  83. // Get it back and check the Created field
  84. tf, err := s.GetByUserID(ctx, 1)
  85. require.NoError(t, err)
  86. assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
  87. // Verify there are 10 recover codes generated
  88. var count int64
  89. err = s.db.Model(new(TwoFactorRecoveryCode)).Count(&count).Error
  90. require.NoError(t, err)
  91. assert.Equal(t, int64(10), count)
  92. }
  93. func twoFactorsGetByUserID(t *testing.T, ctx context.Context, s *TwoFactorsStore) {
  94. // Create a 2FA token for user 1
  95. err := s.Create(ctx, 1, "secure-key", "secure-secret")
  96. require.NoError(t, err)
  97. // We should be able to get it back
  98. _, err = s.GetByUserID(ctx, 1)
  99. require.NoError(t, err)
  100. // Try to get a non-existent 2FA token
  101. _, err = s.GetByUserID(ctx, 2)
  102. wantErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
  103. assert.Equal(t, wantErr, err)
  104. }
  105. func twoFactorsIsEnabled(t *testing.T, ctx context.Context, s *TwoFactorsStore) {
  106. // Create a 2FA token for user 1
  107. err := s.Create(ctx, 1, "secure-key", "secure-secret")
  108. require.NoError(t, err)
  109. assert.True(t, s.IsEnabled(ctx, 1))
  110. assert.False(t, s.IsEnabled(ctx, 2))
  111. }