1
0

dbtest.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package dbtest
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/require"
  10. "gorm.io/gorm"
  11. "gorm.io/gorm/schema"
  12. "gogs.io/gogs/internal/conf"
  13. "gogs.io/gogs/internal/dbutil"
  14. )
  15. // NewDB creates a new test database and initializes the given list of tables
  16. // for the suite. The test database is dropped after testing is completed unless
  17. // failed.
  18. func NewDB(t *testing.T, suite string, tables ...any) *gorm.DB {
  19. dbType := os.Getenv("GOGS_DATABASE_TYPE")
  20. var dbName string
  21. var dbOpts conf.DatabaseOpts
  22. var cleanup func(db *gorm.DB)
  23. switch dbType {
  24. case "mysql":
  25. dbOpts = conf.DatabaseOpts{
  26. Type: "mysql",
  27. Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
  28. Name: dbName,
  29. User: os.Getenv("MYSQL_USER"),
  30. Password: os.Getenv("MYSQL_PASSWORD"),
  31. }
  32. dsn, err := dbutil.NewDSN(dbOpts)
  33. require.NoError(t, err)
  34. sqlDB, err := sql.Open("mysql", dsn)
  35. require.NoError(t, err)
  36. // Set up test database
  37. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  38. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
  39. require.NoError(t, err)
  40. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
  41. require.NoError(t, err)
  42. dbOpts.Name = dbName
  43. cleanup = func(db *gorm.DB) {
  44. testDB, err := db.DB()
  45. if err == nil {
  46. _ = testDB.Close()
  47. }
  48. _, _ = sqlDB.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
  49. _ = sqlDB.Close()
  50. }
  51. case "postgres":
  52. dbOpts = conf.DatabaseOpts{
  53. Type: "postgres",
  54. Host: os.ExpandEnv("$PGHOST:$PGPORT"),
  55. Name: dbName,
  56. Schema: "public",
  57. User: os.Getenv("PGUSER"),
  58. Password: os.Getenv("PGPASSWORD"),
  59. SSLMode: os.Getenv("PGSSLMODE"),
  60. }
  61. dsn, err := dbutil.NewDSN(dbOpts)
  62. require.NoError(t, err)
  63. sqlDB, err := sql.Open("pgx", dsn)
  64. require.NoError(t, err)
  65. // Set up test database
  66. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  67. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
  68. require.NoError(t, err)
  69. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
  70. require.NoError(t, err)
  71. dbOpts.Name = dbName
  72. cleanup = func(db *gorm.DB) {
  73. testDB, err := db.DB()
  74. if err == nil {
  75. _ = testDB.Close()
  76. }
  77. _, _ = sqlDB.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
  78. _ = sqlDB.Close()
  79. }
  80. case "sqlite":
  81. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  82. dbOpts = conf.DatabaseOpts{
  83. Type: "sqlite",
  84. Path: dbName,
  85. }
  86. cleanup = func(db *gorm.DB) {
  87. sqlDB, err := db.DB()
  88. if err == nil {
  89. _ = sqlDB.Close()
  90. }
  91. _ = os.Remove(dbName)
  92. }
  93. default:
  94. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  95. dbOpts = conf.DatabaseOpts{
  96. Type: "sqlite3",
  97. Path: dbName,
  98. }
  99. cleanup = func(db *gorm.DB) {
  100. sqlDB, err := db.DB()
  101. if err == nil {
  102. _ = sqlDB.Close()
  103. }
  104. _ = os.Remove(dbName)
  105. }
  106. }
  107. now := time.Now().UTC().Truncate(time.Second)
  108. db, err := dbutil.OpenDB(
  109. dbOpts,
  110. &gorm.Config{
  111. SkipDefaultTransaction: true,
  112. NamingStrategy: schema.NamingStrategy{
  113. SingularTable: true,
  114. },
  115. NowFunc: func() time.Time {
  116. return now
  117. },
  118. },
  119. )
  120. require.NoError(t, err)
  121. t.Cleanup(func() {
  122. if t.Failed() {
  123. t.Logf("Database %q left intact for inspection", dbName)
  124. return
  125. }
  126. cleanup(db)
  127. })
  128. err = db.Migrator().AutoMigrate(tables...)
  129. require.NoError(t, err)
  130. return db
  131. }