builder.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package builder
  5. import (
  6. sql2 "database/sql"
  7. "fmt"
  8. "sort"
  9. )
  // optype identifies which kind of SQL statement a Builder renders.
type optype byte

const (
	// condType is the zero value: the builder only carries conditions.
	// WriteTo has no case for it and returns ErrNotSupportType.
	condType optype = iota
	// selectType is set by Select on a builder that is still condType.
	selectType
	// insertType is set by Insert.
	insertType
	// updateType is set by Update.
	updateType
	// deleteType is set by Delete.
	deleteType
	// unionType marks the wrapper builder created by Union.
	unionType
)
  // Dialect names understood by this package; pass one to Dialect.
// ToSQL rewrites placeholders for POSTGRES, MSSQL and ORACLE and leaves the
// SQL untouched for the others.
const (
	MSSQL    = "mssql"    // Microsoft SQL Server
	MYSQL    = "mysql"    // MySQL
	ORACLE   = "oracle"   // Oracle
	POSTGRES = "postgres" // PostgreSQL
	SQLITE   = "sqlite3"  // SQLite 3
)
  26. type join struct {
  27. joinType string
  28. joinTable string
  29. joinCond Cond
  30. }
  31. type union struct {
  32. unionType string
  33. builder *Builder
  34. }
  // limit holds the arguments given to Builder.Limit.
type limit struct {
	limitN, offset int // row count, and offset (0 when Limit got no offset)
}
  39. // Builder describes a SQL statement
  40. type Builder struct {
  41. optype
  42. dialect string
  43. isNested bool
  44. into string
  45. from string
  46. subQuery *Builder
  47. cond Cond
  48. selects []string
  49. joins []join
  50. unions []union
  51. limitation *limit
  52. insertCols []string
  53. insertVals []interface{}
  54. updates []Eq
  55. orderBy string
  56. groupBy string
  57. having string
  58. }
  59. // Dialect sets the db dialect of Builder.
  60. func Dialect(dialect string) *Builder {
  61. builder := &Builder{cond: NewCond(), dialect: dialect}
  62. return builder
  63. }
  64. // MySQL is shortcut of Dialect(MySQL)
  65. func MySQL() *Builder {
  66. return Dialect(MYSQL)
  67. }
  68. // MsSQL is shortcut of Dialect(MsSQL)
  69. func MsSQL() *Builder {
  70. return Dialect(MSSQL)
  71. }
  72. // Oracle is shortcut of Dialect(Oracle)
  73. func Oracle() *Builder {
  74. return Dialect(ORACLE)
  75. }
  76. // Postgres is shortcut of Dialect(Postgres)
  77. func Postgres() *Builder {
  78. return Dialect(POSTGRES)
  79. }
  80. // SQLite is shortcut of Dialect(SQLITE)
  81. func SQLite() *Builder {
  82. return Dialect(SQLITE)
  83. }
  84. // Where sets where SQL
  85. func (b *Builder) Where(cond Cond) *Builder {
  86. if b.cond.IsValid() {
  87. b.cond = b.cond.And(cond)
  88. } else {
  89. b.cond = cond
  90. }
  91. return b
  92. }
  93. // From sets from subject(can be a table name in string or a builder pointer) and its alias
  94. func (b *Builder) From(subject interface{}, alias ...string) *Builder {
  95. switch subject.(type) {
  96. case *Builder:
  97. b.subQuery = subject.(*Builder)
  98. if len(alias) > 0 {
  99. b.from = alias[0]
  100. } else {
  101. b.isNested = true
  102. }
  103. case string:
  104. b.from = subject.(string)
  105. if len(alias) > 0 {
  106. b.from = b.from + " " + alias[0]
  107. }
  108. }
  109. return b
  110. }
  111. // TableName returns the table name
  112. func (b *Builder) TableName() string {
  113. if b.optype == insertType {
  114. return b.into
  115. }
  116. return b.from
  117. }
  118. // Into sets insert table name
  119. func (b *Builder) Into(tableName string) *Builder {
  120. b.into = tableName
  121. return b
  122. }
  123. // Join sets join table and conditions
  124. func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder {
  125. switch joinCond.(type) {
  126. case Cond:
  127. b.joins = append(b.joins, join{joinType, joinTable, joinCond.(Cond)})
  128. case string:
  129. b.joins = append(b.joins, join{joinType, joinTable, Expr(joinCond.(string))})
  130. }
  131. return b
  132. }
  133. // Union sets union conditions
  134. func (b *Builder) Union(unionTp string, unionCond *Builder) *Builder {
  135. var builder *Builder
  136. if b.optype != unionType {
  137. builder = &Builder{cond: NewCond()}
  138. builder.optype = unionType
  139. builder.dialect = b.dialect
  140. builder.selects = b.selects
  141. currentUnions := b.unions
  142. // erase sub unions (actually append to new Builder.unions)
  143. b.unions = nil
  144. for e := range currentUnions {
  145. currentUnions[e].builder.dialect = b.dialect
  146. }
  147. builder.unions = append(append(builder.unions, union{"", b}), currentUnions...)
  148. } else {
  149. builder = b
  150. }
  151. if unionCond != nil {
  152. if unionCond.dialect == "" && builder.dialect != "" {
  153. unionCond.dialect = builder.dialect
  154. }
  155. builder.unions = append(builder.unions, union{unionTp, unionCond})
  156. }
  157. return builder
  158. }
  159. // Limit sets limitN condition
  160. func (b *Builder) Limit(limitN int, offset ...int) *Builder {
  161. b.limitation = &limit{limitN: limitN}
  162. if len(offset) > 0 {
  163. b.limitation.offset = offset[0]
  164. }
  165. return b
  166. }
  167. // InnerJoin sets inner join
  168. func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder {
  169. return b.Join("INNER", joinTable, joinCond)
  170. }
  171. // LeftJoin sets left join SQL
  172. func (b *Builder) LeftJoin(joinTable string, joinCond interface{}) *Builder {
  173. return b.Join("LEFT", joinTable, joinCond)
  174. }
  175. // RightJoin sets right join SQL
  176. func (b *Builder) RightJoin(joinTable string, joinCond interface{}) *Builder {
  177. return b.Join("RIGHT", joinTable, joinCond)
  178. }
  179. // CrossJoin sets cross join SQL
  180. func (b *Builder) CrossJoin(joinTable string, joinCond interface{}) *Builder {
  181. return b.Join("CROSS", joinTable, joinCond)
  182. }
  183. // FullJoin sets full join SQL
  184. func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
  185. return b.Join("FULL", joinTable, joinCond)
  186. }
  187. // Select sets select SQL
  188. func (b *Builder) Select(cols ...string) *Builder {
  189. b.selects = cols
  190. if b.optype == condType {
  191. b.optype = selectType
  192. }
  193. return b
  194. }
  195. // And sets AND condition
  196. func (b *Builder) And(cond Cond) *Builder {
  197. b.cond = And(b.cond, cond)
  198. return b
  199. }
  200. // Or sets OR condition
  201. func (b *Builder) Or(cond Cond) *Builder {
  202. b.cond = Or(b.cond, cond)
  203. return b
  204. }
  205. // Insert sets insert SQL
  206. func (b *Builder) Insert(eq ...interface{}) *Builder {
  207. if len(eq) > 0 {
  208. var paramType = -1
  209. for _, e := range eq {
  210. switch t := e.(type) {
  211. case Eq:
  212. if paramType == -1 {
  213. paramType = 0
  214. }
  215. if paramType != 0 {
  216. break
  217. }
  218. for k, v := range t {
  219. b.insertCols = append(b.insertCols, k)
  220. b.insertVals = append(b.insertVals, v)
  221. }
  222. case string:
  223. if paramType == -1 {
  224. paramType = 1
  225. }
  226. if paramType != 1 {
  227. break
  228. }
  229. b.insertCols = append(b.insertCols, t)
  230. }
  231. }
  232. }
  233. if len(b.insertCols) == len(b.insertVals) {
  234. sort.Slice(b.insertVals, func(i, j int) bool {
  235. return b.insertCols[i] < b.insertCols[j]
  236. })
  237. sort.Strings(b.insertCols)
  238. }
  239. b.optype = insertType
  240. return b
  241. }
  242. // Update sets update SQL
  243. func (b *Builder) Update(updates ...Eq) *Builder {
  244. b.updates = make([]Eq, 0, len(updates))
  245. for _, update := range updates {
  246. if update.IsValid() {
  247. b.updates = append(b.updates, update)
  248. }
  249. }
  250. b.optype = updateType
  251. return b
  252. }
  253. // Delete sets delete SQL
  254. func (b *Builder) Delete(conds ...Cond) *Builder {
  255. b.cond = b.cond.And(conds...)
  256. b.optype = deleteType
  257. return b
  258. }
  259. // WriteTo implements Writer interface
  260. func (b *Builder) WriteTo(w Writer) error {
  261. switch b.optype {
  262. /*case condType:
  263. return b.cond.WriteTo(w)*/
  264. case selectType:
  265. return b.selectWriteTo(w)
  266. case insertType:
  267. return b.insertWriteTo(w)
  268. case updateType:
  269. return b.updateWriteTo(w)
  270. case deleteType:
  271. return b.deleteWriteTo(w)
  272. case unionType:
  273. return b.unionWriteTo(w)
  274. }
  275. return ErrNotSupportType
  276. }
  277. // ToSQL convert a builder to SQL and args
  278. func (b *Builder) ToSQL() (string, []interface{}, error) {
  279. w := NewWriter()
  280. if err := b.WriteTo(w); err != nil {
  281. return "", nil, err
  282. }
  283. // in case of sql.NamedArg in args
  284. for e := range w.args {
  285. if namedArg, ok := w.args[e].(sql2.NamedArg); ok {
  286. w.args[e] = namedArg.Value
  287. }
  288. }
  289. var sql = w.writer.String()
  290. var err error
  291. switch b.dialect {
  292. case ORACLE, MSSQL:
  293. // This is for compatibility with different sql drivers
  294. for e := range w.args {
  295. w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e])
  296. }
  297. var prefix string
  298. if b.dialect == ORACLE {
  299. prefix = ":p"
  300. } else {
  301. prefix = "@p"
  302. }
  303. if sql, err = ConvertPlaceholder(sql, prefix); err != nil {
  304. return "", nil, err
  305. }
  306. case POSTGRES:
  307. if sql, err = ConvertPlaceholder(sql, "$"); err != nil {
  308. return "", nil, err
  309. }
  310. }
  311. return sql, w.args, nil
  312. }
  313. // ToBoundSQL
  314. func (b *Builder) ToBoundSQL() (string, error) {
  315. w := NewWriter()
  316. if err := b.WriteTo(w); err != nil {
  317. return "", err
  318. }
  319. return ConvertToBoundSQL(w.writer.String(), w.args)
  320. }