main.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. "sort"
  7. "strings"
  8. "github.com/cockroachdb/errors"
  9. "github.com/glebarez/sqlite"
  10. "github.com/olekukonko/tablewriter"
  11. "gopkg.in/DATA-DOG/go-sqlmock.v2"
  12. "gorm.io/driver/mysql"
  13. "gorm.io/driver/postgres"
  14. "gorm.io/gorm"
  15. "gorm.io/gorm/clause"
  16. "gorm.io/gorm/schema"
  17. "gogs.io/gogs/internal/database"
  18. )
  19. //go:generate go run main.go ../../../docs/dev/database_schema.md
  20. func main() {
  21. w, err := os.Create(os.Args[1])
  22. if err != nil {
  23. log.Fatalf("Failed to create file: %v", err)
  24. }
  25. defer func() { _ = w.Close() }()
  26. conn, _, err := sqlmock.New()
  27. if err != nil {
  28. log.Fatalf("Failed to get mock connection: %v", err)
  29. }
  30. defer func() { _ = conn.Close() }()
  31. dialectors := []gorm.Dialector{
  32. postgres.New(postgres.Config{
  33. Conn: conn,
  34. }),
  35. mysql.New(mysql.Config{
  36. Conn: conn,
  37. SkipInitializeWithVersion: true,
  38. }),
  39. sqlite.Open(""),
  40. }
  41. collected := make([][]*tableInfo, 0, len(dialectors))
  42. for i, dialector := range dialectors {
  43. tableInfos, err := generate(dialector)
  44. if err != nil {
  45. log.Fatalf("Failed to get table info of %d: %v", i, err)
  46. }
  47. collected = append(collected, tableInfos)
  48. }
  49. for i, ti := range collected[0] {
  50. _, _ = w.WriteString(`# Table "` + ti.Name + `"`)
  51. _, _ = w.WriteString("\n\n")
  52. _, _ = w.WriteString("```\n")
  53. table := tablewriter.NewWriter(w)
  54. table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})
  55. table.SetBorder(false)
  56. for j, f := range ti.Fields {
  57. sqlite3Type := strings.ToUpper(collected[2][i].Fields[j].Type)
  58. sqlite3Type = strings.ReplaceAll(sqlite3Type, "PRIMARY KEY ", "")
  59. table.Append([]string{
  60. f.Name, f.Column,
  61. strings.ToUpper(f.Type), // PostgreSQL
  62. strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL
  63. sqlite3Type,
  64. })
  65. }
  66. table.Render()
  67. _, _ = w.WriteString("\n")
  68. _, _ = w.WriteString("Primary keys: ")
  69. _, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))
  70. _, _ = w.WriteString("\n")
  71. if len(ti.Indexes) > 0 {
  72. _, _ = w.WriteString("Indexes: \n")
  73. for _, index := range ti.Indexes {
  74. _, _ = fmt.Fprintf(w, "\t%q", index.Name)
  75. if index.Class != "" {
  76. _, _ = fmt.Fprintf(w, " %s", index.Class)
  77. }
  78. if index.Type != "" {
  79. _, _ = fmt.Fprintf(w, ", %s", index.Type)
  80. }
  81. if len(index.Fields) > 0 {
  82. fields := make([]string, len(index.Fields))
  83. for i := range index.Fields {
  84. fields[i] = index.Fields[i].DBName
  85. }
  86. _, _ = fmt.Fprintf(w, " (%s)", strings.Join(fields, ", "))
  87. }
  88. _, _ = w.WriteString("\n")
  89. }
  90. }
  91. _, _ = w.WriteString("```\n\n")
  92. }
  93. }
  94. type tableField struct {
  95. Name string
  96. Column string
  97. Type string
  98. }
  99. type tableInfo struct {
  100. Name string
  101. Fields []*tableField
  102. PrimaryKeys []string
  103. Indexes []schema.Index
  104. }
  105. // This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.
  106. func generate(dialector gorm.Dialector) ([]*tableInfo, error) {
  107. conn, err := gorm.Open(dialector,
  108. &gorm.Config{
  109. SkipDefaultTransaction: true,
  110. NamingStrategy: schema.NamingStrategy{
  111. SingularTable: true,
  112. },
  113. DryRun: true,
  114. DisableAutomaticPing: true,
  115. },
  116. )
  117. if err != nil {
  118. return nil, errors.Wrap(err, "open database")
  119. }
  120. m := conn.Migrator().(interface {
  121. RunWithValue(value any, fc func(*gorm.Statement) error) error
  122. FullDataTypeOf(*schema.Field) clause.Expr
  123. })
  124. tableInfos := make([]*tableInfo, 0, len(database.Tables))
  125. for _, table := range database.Tables {
  126. err = m.RunWithValue(table, func(stmt *gorm.Statement) error {
  127. fields := make([]*tableField, 0, len(stmt.Schema.DBNames))
  128. for _, field := range stmt.Schema.Fields {
  129. if field.DBName == "" {
  130. continue
  131. }
  132. tags := make([]string, 0)
  133. for tag := range field.TagSettings {
  134. if tag == "UNIQUE" {
  135. tags = append(tags, tag)
  136. }
  137. }
  138. typeSuffix := ""
  139. if len(tags) > 0 {
  140. typeSuffix = " " + strings.Join(tags, " ")
  141. }
  142. fields = append(fields, &tableField{
  143. Name: field.Name,
  144. Column: field.DBName,
  145. Type: m.FullDataTypeOf(field).SQL + typeSuffix,
  146. })
  147. }
  148. primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))
  149. if len(stmt.Schema.PrimaryFields) > 0 {
  150. for _, field := range stmt.Schema.PrimaryFields {
  151. primaryKeys = append(primaryKeys, field.DBName)
  152. }
  153. }
  154. var indexes []schema.Index
  155. for _, index := range stmt.Schema.ParseIndexes() {
  156. indexes = append(indexes, index)
  157. }
  158. sort.Slice(indexes, func(i, j int) bool {
  159. return indexes[i].Name < indexes[j].Name
  160. })
  161. tableInfos = append(tableInfos, &tableInfo{
  162. Name: stmt.Table,
  163. Fields: fields,
  164. PrimaryKeys: primaryKeys,
  165. Indexes: indexes,
  166. })
  167. return nil
  168. })
  169. if err != nil {
  170. return nil, errors.Wrap(err, "gather table information")
  171. }
  172. }
  173. return tableInfos, nil
  174. }