postgres_dialect.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package migrator
import (
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/go-xorm/xorm"
	"github.com/lib/pq"
)
  9. type Postgres struct {
  10. BaseDialect
  11. }
  12. func NewPostgresDialect(engine *xorm.Engine) *Postgres {
  13. d := Postgres{}
  14. d.BaseDialect.dialect = &d
  15. d.BaseDialect.engine = engine
  16. d.BaseDialect.driverName = POSTGRES
  17. return &d
  18. }
  19. func (db *Postgres) SupportEngine() bool {
  20. return false
  21. }
  22. func (db *Postgres) Quote(name string) string {
  23. return "\"" + name + "\""
  24. }
  25. func (b *Postgres) LikeStr() string {
  26. return "ILIKE"
  27. }
  28. func (db *Postgres) AutoIncrStr() string {
  29. return ""
  30. }
  31. func (db *Postgres) BooleanStr(value bool) string {
  32. return strconv.FormatBool(value)
  33. }
  34. func (b *Postgres) Default(col *Column) string {
  35. if col.Type == DB_Bool {
  36. if col.Default == "0" {
  37. return "FALSE"
  38. }
  39. return "TRUE"
  40. }
  41. return col.Default
  42. }
  43. func (db *Postgres) SqlType(c *Column) string {
  44. var res string
  45. switch t := c.Type; t {
  46. case DB_TinyInt:
  47. res = DB_SmallInt
  48. return res
  49. case DB_MediumInt, DB_Int, DB_Integer:
  50. if c.IsAutoIncrement {
  51. return DB_Serial
  52. }
  53. return DB_Integer
  54. case DB_Serial, DB_BigSerial:
  55. c.IsAutoIncrement = true
  56. c.Nullable = false
  57. res = t
  58. case DB_Binary, DB_VarBinary:
  59. return DB_Bytea
  60. case DB_DateTime:
  61. res = DB_TimeStamp
  62. case DB_TimeStampz:
  63. return "timestamp with time zone"
  64. case DB_Float:
  65. res = DB_Real
  66. case DB_TinyText, DB_MediumText, DB_LongText:
  67. res = DB_Text
  68. case DB_NVarchar:
  69. res = DB_Varchar
  70. case DB_Uuid:
  71. res = DB_Uuid
  72. case DB_Blob, DB_TinyBlob, DB_MediumBlob, DB_LongBlob:
  73. return DB_Bytea
  74. case DB_Double:
  75. return "DOUBLE PRECISION"
  76. default:
  77. if c.IsAutoIncrement {
  78. return DB_Serial
  79. }
  80. res = t
  81. }
  82. var hasLen1 = (c.Length > 0)
  83. var hasLen2 = (c.Length2 > 0)
  84. if hasLen2 {
  85. res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
  86. } else if hasLen1 {
  87. res += "(" + strconv.Itoa(c.Length) + ")"
  88. }
  89. return res
  90. }
  91. func (db *Postgres) IndexCheckSql(tableName, indexName string) (string, []interface{}) {
  92. args := []interface{}{tableName, indexName}
  93. sql := "SELECT 1 FROM " + db.Quote("pg_indexes") + " WHERE" + db.Quote("tablename") + "=? AND " + db.Quote("indexname") + "=?"
  94. return sql, args
  95. }
  96. func (db *Postgres) DropIndexSql(tableName string, index *Index) string {
  97. quote := db.Quote
  98. idxName := index.XName(tableName)
  99. return fmt.Sprintf("DROP INDEX %v", quote(idxName))
  100. }
  101. func (db *Postgres) UpdateTableSql(tableName string, columns []*Column) string {
  102. var statements = []string{}
  103. for _, col := range columns {
  104. statements = append(statements, "ALTER "+db.Quote(col.Name)+" TYPE "+db.SqlType(col))
  105. }
  106. return "ALTER TABLE " + db.Quote(tableName) + " " + strings.Join(statements, ", ") + ";"
  107. }
  108. func (db *Postgres) CleanDB() error {
  109. sess := db.engine.NewSession()
  110. defer sess.Close()
  111. if _, err := sess.Exec("DROP SCHEMA public CASCADE;"); err != nil {
  112. return fmt.Errorf("Failed to drop schema public")
  113. }
  114. if _, err := sess.Exec("CREATE SCHEMA public;"); err != nil {
  115. return fmt.Errorf("Failed to create schema public")
  116. }
  117. return nil
  118. }
  119. func (db *Postgres) IsUniqueConstraintViolation(err error) bool {
  120. if driverErr, ok := err.(*pq.Error); ok {
  121. if driverErr.Code == "23505" {
  122. return true
  123. }
  124. }
  125. return false
  126. }